diff --git a/hsml/.github/workflows/mkdocs-main.yml b/hsml/.github/workflows/mkdocs-main.yml
new file mode 100644
index 000000000..001f1fad1
--- /dev/null
+++ b/hsml/.github/workflows/mkdocs-main.yml
@@ -0,0 +1,35 @@
+name: mkdocs-main
+
+on: pull_request
+
+jobs:
+ publish-main:
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: set dev version
+ working-directory: ./java
+ run: echo "DEV_VERSION=$(mvn org.apache.maven.plugins:maven-help-plugin:2.1.1:evaluate -Dexpression=project.version | grep -Ev 'Download|INFO|WARNING')" >> $GITHUB_ENV
+
+ - uses: actions/setup-python@v5
+ with:
+ python-version: "3.10"
+
+ - name: install deps
+ working-directory: ./python
+ run: cp ../README.md . && pip3 install -r ../requirements-docs.txt && pip3 install -e .[dev]
+
+ - name: generate autodoc
+ run: python3 auto_doc.py
+
+ - name: setup git
+ run: |
+ git config --global user.name Mike
+ git config --global user.email mike@docs.hopsworks.ai
+
+ - name: mike deploy docs
+ run: mike deploy ${{ env.DEV_VERSION }} dev -u
diff --git a/hsml/.github/workflows/mkdocs-release.yml b/hsml/.github/workflows/mkdocs-release.yml
new file mode 100644
index 000000000..e2b4b2b3f
--- /dev/null
+++ b/hsml/.github/workflows/mkdocs-release.yml
@@ -0,0 +1,42 @@
+name: mkdocs-release
+
+on:
+ push:
+ branches: [branch-*\.*]
+
+jobs:
+ publish-release:
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: set major/minor/bugfix release version
+ working-directory: ./java
+ run: echo "RELEASE_VERSION=$(mvn org.apache.maven.plugins:maven-help-plugin:2.1.1:evaluate -Dexpression=project.version | grep -Ev 'Download|INFO|WARNING')" >> $GITHUB_ENV
+
+ - name: set major/minor release version
+ run: echo "MAJOR_VERSION=$(echo $RELEASE_VERSION | sed 's/^\([0-9]*\.[0-9]*\).*$/\1/')" >> $GITHUB_ENV
+
+ - uses: actions/setup-python@v5
+ with:
+ python-version: "3.10"
+
+ - name: install deps
+ working-directory: ./python
+ run: cp ../README.md . && pip3 install -r ../requirements-docs.txt && pip3 install -e .[dev]
+
+ - name: generate autodoc
+ run: python3 auto_doc.py
+
+ - name: setup git
+ run: |
+ git config --global user.name Mike
+ git config --global user.email mike@docs.hopsworks.ai
+
+ - name: mike deploy docs
+ run: |
+ mike deploy ${{ env.RELEASE_VERSION }} ${{ env.MAJOR_VERSION }} -u --push
+ mike alias ${{ env.RELEASE_VERSION }} latest -u --push
diff --git a/hsml/.github/workflows/python-lint.yml b/hsml/.github/workflows/python-lint.yml
new file mode 100644
index 000000000..88225add7
--- /dev/null
+++ b/hsml/.github/workflows/python-lint.yml
@@ -0,0 +1,163 @@
+name: python
+
+on: pull_request
+
+jobs:
+ lint_stylecheck:
+ name: Lint and Stylecheck
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v4
+
+ - uses: actions/setup-python@v5
+ with:
+ python-version: "3.11"
+
+ - name: Get all changed files
+ id: get-changed-files
+ uses: tj-actions/changed-files@v44
+ with:
+ files_yaml: |
+ src:
+ - 'python/**/*.py'
+ - '!python/tests/**/*.py'
+ test:
+ - 'python/tests/**/*.py'
+
+ - name: install deps
+ run: pip install ruff==0.4.2
+
+ - name: ruff on python files
+ if: steps.get-changed-files.outputs.src_any_changed == 'true'
+ env:
+ SRC_ALL_CHANGED_FILES: ${{ steps.get-changed-files.outputs.src_all_changed_files }}
+ run: ruff check --output-format=github $SRC_ALL_CHANGED_FILES
+
+ - name: ruff on test files
+ if: steps.get-changed-files.outputs.test_any_changed == 'true'
+ env:
+ TEST_ALL_CHANGED_FILES: ${{ steps.get-changed-files.outputs.test_all_changed_files }}
+ run: ruff check --output-format=github $TEST_ALL_CHANGED_FILES
+
+ - name: ruff format --check $ALL_CHANGED_FILES
+ env:
+ ALL_CHANGED_FILES: ${{ steps.get-changed-files.outputs.all_changed_files }}
+        run: ruff format --check $ALL_CHANGED_FILES
+
+ unit_tests_ubuntu_utc:
+ name: Unit Testing (Ubuntu)
+ needs: lint_stylecheck
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+
+ steps:
+ - name: Set Timezone
+ run: sudo timedatectl set-timezone UTC
+
+ - uses: actions/checkout@v4
+ - name: Copy README
+ run: cp README.md python/
+
+ - uses: actions/setup-python@v5
+ name: Setup Python
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: "pip"
+ cache-dependency-path: "python/setup.py"
+ - run: pip install -e python[dev]
+
+ - name: Display Python version
+ run: python --version
+
+ - name: Run Pytest suite
+ run: pytest python/tests
+
+ unit_tests_ubuntu_local:
+ name: Unit Testing (Ubuntu) (Local TZ)
+ needs: lint_stylecheck
+ runs-on: ubuntu-latest
+
+ steps:
+ - name: Set Timezone
+ run: sudo timedatectl set-timezone Europe/Amsterdam
+
+ - uses: actions/checkout@v4
+ - name: Copy README
+ run: cp README.md python/
+
+ - uses: actions/setup-python@v5
+ name: Setup Python
+ with:
+ python-version: "3.12"
+ cache: "pip"
+ cache-dependency-path: "python/setup.py"
+ - run: pip install -e python[dev]
+
+ - name: Display Python version
+ run: python --version
+
+ - name: Run Pytest suite
+ run: pytest python/tests
+
+ unit_tests_windows_utc:
+ name: Unit Testing (Windows)
+ needs: lint_stylecheck
+ runs-on: windows-latest
+ strategy:
+ matrix:
+ python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+
+ steps:
+ - name: Set Timezone
+ run: tzutil /s "UTC"
+
+ - uses: actions/checkout@v4
+ - name: Copy README
+ run: cp README.md python/
+
+ - uses: actions/setup-python@v5
+ name: Setup Python
+ with:
+ python-version: ${{ matrix.python-version }}
+ cache: "pip"
+ cache-dependency-path: "python/setup.py"
+ - run: pip install -e python[dev]
+
+ - name: Display Python version
+ run: python --version
+
+ - name: Run Pytest suite
+ run: pytest python/tests
+
+ unit_tests_windows_local:
+ name: Unit Testing (Windows) (Local TZ)
+ needs: lint_stylecheck
+ runs-on: windows-latest
+
+ steps:
+ - name: Set Timezone
+ run: tzutil /s "W. Europe Standard Time"
+
+ - uses: actions/checkout@v4
+ - name: Copy README
+ run: cp README.md python/
+
+ - uses: actions/setup-python@v5
+ name: Setup Python
+ with:
+ python-version: "3.12"
+ cache: "pip"
+ cache-dependency-path: "python/setup.py"
+ - run: pip install -e python[dev]
+
+ - name: Display Python version
+ run: python --version
+
+ - name: Display pip freeze
+ run: pip freeze
+
+ - name: Run Pytest suite
+ run: pytest python/tests
diff --git a/hsml/.gitignore b/hsml/.gitignore
new file mode 100644
index 000000000..6e96d8144
--- /dev/null
+++ b/hsml/.gitignore
@@ -0,0 +1,130 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+python/README.md
+python/LICENSE
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+.ruff_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# Java
+.idea
+.vscode
+*.iml
+target/
+
+# Mac
+.DS_Store
+
+# mkdocs intermediate files
+docs/generated
diff --git a/hsml/CONTRIBUTING.md b/hsml/CONTRIBUTING.md
new file mode 100644
index 000000000..b287467c6
--- /dev/null
+++ b/hsml/CONTRIBUTING.md
@@ -0,0 +1,215 @@
+## Python development setup
+---
+
+- Fork and clone the repository
+
+- Create a new Python environment with your favourite environment manager, e.g. virtualenv or conda
+
+- Install the repository in editable mode with development dependencies:
+
+ ```bash
+ cd python
+ pip install -e ".[dev]"
+ ```
+
+- Install [pre-commit](https://pre-commit.com/) and then activate its hooks. pre-commit is a framework for managing and maintaining multi-language pre-commit hooks. The Model Registry uses pre-commit to ensure code-style and code formatting through [ruff](https://docs.astral.sh/ruff/). Run the following commands from the `python` directory:
+
+ ```bash
+ cd python
+ pip install --user pre-commit
+ pre-commit install
+ ```
+
+ Afterwards, pre-commit will run whenever you commit.
+
+- To run formatting and code-style checks separately, you can either configure your IDE, such as VSCode, to use [ruff](https://docs.astral.sh/ruff/tutorial/#getting-started), or run ruff from the command line:
+
+ ```bash
+ cd python
+ ruff check --fix
+ ruff format
+ ```
+
+### Python documentation
+
+We follow a few best practices for writing the Python documentation:
+
+1. Use the Google docstring style:
+
+ ```python
+ """[One Line Summary]
+
+ [Extended Summary]
+
+ [!!! example
+ import xyz
+ ]
+
+ # Arguments
+ arg1: Type[, optional]. Description[, defaults to `default`]
+ arg2: Type[, optional]. Description[, defaults to `default`]
+
+ # Returns
+ Type. Description.
+
+ # Raises
+ Exception. Description.
+ """
+ ```
+
+    If Python 3 type annotations are used, they are inserted automatically. A concrete example of this docstring style is sketched right after this list.
+
+
+2. Model registry entity engine methods (e.g. ModelEngine) only require a single-line docstring.
+3. REST API implementations (e.g. ModelApi) should be fully documented with docstrings, without listing defaults.
+4. Public API classes, such as metadata objects, should be fully documented, including defaults.
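+
+As a concrete illustration of the docstring style described in point 1, a hypothetical model-registry method could be documented as follows (the class, method, and argument names here are made up for the example and are not prescriptive of the actual API):
+
+```python
+class Model:
+    """Metadata object describing a model (illustrative only)."""
+
+    def save(self, model_path: str, await_registration: int = 480):
+        """Persist the model to the model registry.
+
+        Copies the model artifacts found at `model_path` into the registry
+        and creates a new model version.
+
+        # Arguments
+        model_path: str. Local or remote path to the model artifacts.
+        await_registration: int, optional. Seconds to wait for model registration to complete, defaults to `480`.
+
+        # Returns
+        `Model`. The saved model metadata object.
+
+        # Raises
+        `RestAPIError`. Unable to save the model in the model registry.
+        """
+```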
+
+#### Setup and Build Documentation
+
+We use `mkdocs` together with `mike` ([for versioning](https://github.com/jimporter/mike/)) to build the documentation, and a plugin called `keras-autodoc` to auto-generate Python API documentation from docstrings.
+
+**Background about `mike`:**
+ `mike` builds the documentation and commits it as a new directory to the gh-pages branch. Each directory corresponds to one version of the documentation. Additionally, `mike` maintains a `versions.json` file in the root of the gh-pages branch with the mapping of versions/aliases for each of the available directories. With aliases you can define extra names like `dev` or `latest` to indicate unstable and stable releases.
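+
+For illustration, the `versions.json` that `mike` maintains contains entries of roughly the following shape (the version numbers are only examples taken from the versioning scheme described below; the `version`, `title` and `aliases` fields are the ones read by `docs/js/version-select.js`):
+
+```python
+# Illustrative content of versions.json in the root of the gh-pages branch,
+# written here as a Python literal. Each entry maps one docs directory to its
+# version string, display title, and aliases.
+versions = [
+    {"version": "2.2.0-SNAPSHOT", "title": "2.2.0-SNAPSHOT", "aliases": ["dev"]},
+    {"version": "2.1.5", "title": "2.1.5", "aliases": ["latest"]},
+    {"version": "2.1.4", "title": "2.1.4", "aliases": []},
+]
+```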
+
+1. Currently we are using our own version of `keras-autodoc`
+
+ ```bash
+ pip install git+https://github.com/logicalclocks/keras-autodoc
+ ```
+
+2. Install HSML with `docs` extras:
+
+ ```bash
+ pip install -e .[dev,docs]
+ ```
+
+3. To build the docs, first run the auto doc script:
+
+ ```bash
+ cd ..
+ python auto_doc.py
+ ```
+
+##### Option 1: Build only current version of docs
+
+4. Either build the docs, or serve them dynamically:
+
+    Note: links and pictures might not resolve properly when checking the docs with this build.
+    This is because the docs are deployed with versioning on docs.hopsworks.ai, so
+    another level is added to all paths, e.g. `docs.hopsworks.ai/[version-or-alias]`.
+    Relative links should not be affected by this; however, building the docs with versioning
+    (Option 2) is recommended.
+
+ ```bash
+ mkdocs build
+ # or
+ mkdocs serve
+ ```
+
+##### Option 2 (Preferred): Build multi-version doc with `mike`
+
+###### Versioning on docs.hopsworks.ai
+
+On docs.hopsworks.ai we implement the following versioning scheme:
+
+- current master branches (e.g. of hsml corresponding to master of Hopsworks): rendered as current Hopsworks snapshot version, e.g. **2.2.0-SNAPSHOT [dev]**, where `dev` is an alias to indicate that this is an unstable version.
+- the latest release: rendered with full current version, e.g. **2.1.5 [latest]** with `latest` alias to indicate that this is the latest stable release.
+- previous stable releases: rendered without alias, e.g. **2.1.4**.
+
+###### Build Instructions
+
+4. For this you can either check out and make a local copy of the `upstream/gh-pages` branch, where
+`mike` maintains the current state of docs.hopsworks.ai, or just build the documentation for the branch you are updating:
+
+ Building *one* branch:
+
+    Check out your dev branch with the modified docs:
+ ```bash
+ git checkout [dev-branch]
+ ```
+
+ Generate API docs if necessary:
+ ```bash
+ python auto_doc.py
+ ```
+
+ Build docs with a version and alias
+ ```bash
+ mike deploy [version] [alias] --update-alias
+
+ # for example, if you are updating documentation to be merged to master,
+ # which will become the new SNAPSHOT version:
+ mike deploy 2.2.0-SNAPSHOT dev --update-alias
+
+ # if you are updating docs of the latest stable release branch
+ mike deploy [version] latest --update-alias
+
+ # if you are updating docs of a previous stable release branch
+ mike deploy [version]
+ ```
+
+ If no gh-pages branch existed in your local repository, this will have created it.
+
+    **Important**: If no previous docs were built, you will have to choose a version to be loaded as the default index, as follows:
+
+ ```bash
+ mike set-default [version-or-alias]
+ ```
+
+    You can now check out the gh-pages branch and serve the docs:
+ ```bash
+ git checkout gh-pages
+ mike serve
+ ```
+
+ You can also list all available versions/aliases:
+ ```bash
+ mike list
+ ```
+
+ Delete and reset your local gh-pages branch:
+ ```bash
+ mike delete --all
+
+ # or delete single version
+ mike delete [version-or-alias]
+ ```
+
+#### Adding new API documentation
+
+To add new documentation for APIs, you need to add information about the method/class to document to the `auto_doc.py` script:
+
+```python
+PAGES = {
+ "connection.md": [
+ "hsml.connection.Connection.connection",
+ "hsml.connection.Connection.setup_databricks",
+    ],
+ "new_template.md": [
+ "module",
+ "xyz.asd"
+ ]
+}
+```
+
+Now you can add a template markdown file to the `docs/templates` directory with the name you specified in the auto-doc script. The `new_template.md` file should contain a tag to identify the place at which the API documentation should be inserted:
+
+```
+## The XYZ package
+
+{{module}}
+
+Some extra content here.
+
+!!! example
+ ```python
+ import xyz
+ ```
+
+{{xyz.asd}}
+```
+
+Finally, run the `auto_doc.py` script, as described above, to update the documentation.
+
+For information about Markdown syntax and possible admonitions, highlighting, etc., see
+the [Material for MkDocs reference documentation](https://squidfunk.github.io/mkdocs-material/reference/abbreviations/).
diff --git a/hsml/Dockerfile b/hsml/Dockerfile
new file mode 100644
index 000000000..7f87ca293
--- /dev/null
+++ b/hsml/Dockerfile
@@ -0,0 +1,9 @@
+FROM ubuntu:20.04
+
+RUN apt-get update && \
+ apt-get install -y python3-pip git && apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+RUN pip3 install twine
+
+RUN mkdir -p /.local && chmod -R 777 /.local
diff --git a/hsml/Jenkinsfile b/hsml/Jenkinsfile
new file mode 100644
index 000000000..d2014d5cb
--- /dev/null
+++ b/hsml/Jenkinsfile
@@ -0,0 +1,23 @@
+pipeline {
+ agent {
+ docker {
+ label "local"
+ image "docker.hops.works/hopsworks_twine:0.0.1"
+ }
+ }
+ stages {
+ stage("publish") {
+ environment {
+ PYPI = credentials('977daeb0-e1c8-43a0-b35a-fc37bb9eee9b')
+ }
+ steps {
+ dir("python") {
+ sh "rm -f LICENSE README.md"
+ sh "cp -f ../LICENSE ../README.md ./"
+ sh "python3 -m build"
+ sh "twine upload -u $PYPI_USR -p $PYPI_PSW --skip-existing dist/*"
+ }
+ }
+ }
+ }
+}
diff --git a/hsml/LICENSE b/hsml/LICENSE
new file mode 100644
index 000000000..261eeb9e9
--- /dev/null
+++ b/hsml/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/hsml/README.md b/hsml/README.md
new file mode 100644
index 000000000..ee835ddc7
--- /dev/null
+++ b/hsml/README.md
@@ -0,0 +1,141 @@
+# Hopsworks Model Management
+
+
+/* The container <div> - needed to position the dropdown content */
+.dropdown {
+ position: absolute;
+ display: inline-block;
+}
+
+/* Dropdown Content (Hidden by Default) */
+.dropdown-content {
+ display:none;
+ font-size: 13px;
+ position: absolute;
+ background-color: #f9f9f9;
+ min-width: 160px;
+ box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2);
+ z-index: 1000;
+ border-radius: 2px;
+ left:-15px;
+}
+
+/* Links inside the dropdown */
+.dropdown-content a {
+ color: black;
+ padding: 12px 16px;
+ text-decoration: none;
+ display: block;
+}
+
+/* Change color of dropdown links on hover */
+.dropdown-content a:hover {background-color: #f1f1f1}
+
+/* Show the dropdown menu on hover */
+.dropdown:hover .dropdown-content {
+ display: block;
+}
+
+/* Change the background color of the dropdown button when the dropdown content is shown */
+.dropdown:hover .dropbtn {
+}
diff --git a/hsml/docs/css/marctech.css b/hsml/docs/css/marctech.css
new file mode 100644
index 000000000..8bb58c97b
--- /dev/null
+++ b/hsml/docs/css/marctech.css
@@ -0,0 +1,1047 @@
+:root {
+ --md-primary-fg-color: #1EB382;
+ --md-secondary-fg-color: #188a64;
+ --md-tertiary-fg-color: #0d493550;
+ --md-quaternary-fg-color: #fdfdfd;
+ --md-fiftuary-fg-color: #2471cf;
+ --border-radius-variable: 5px;
+ --border-width:1px;
+ }
+
+ .marctech_main a{
+ color: var(--md-fiftuary-fg-color);
+ border-bottom: 1px dotted var(--md-fiftuary-fg-color) !important;
+ text-decoration: dotted !important;}
+
+ .marctech_main a:hover{
+ border-bottom: 1px dotted var(--md-primary-fg-color)!important;
+ }
+
+ .marctech_main a:visited{
+ color: var(--md-tertiary-fg-color);
+ border-bottom: 1px dotted var(--md-tertiary-fg-color) !important;
+
+ }
+
+ .w-layout-grid {
+ display: -ms-grid;
+ display: grid;
+ grid-auto-columns: 1fr;
+ -ms-grid-columns: 1fr 1fr;
+ grid-template-columns: 1fr 1fr;
+ -ms-grid-rows: auto auto;
+ grid-template-rows: auto auto;
+ grid-row-gap: 16px;
+ grid-column-gap: 16px;
+ }
+
+ .image_logo{
+ width: 69%;
+ background-color: white;
+ z-index: 50;
+ padding: 0px 15px 0px 15px;
+ margin-bottom: 10px;
+ }
+
+ .layer_02{
+ pointer-events: none;
+ }
+
+ .round-frame{
+ pointer-events: initial;
+ }
+
+ .marctech_main {
+ margin-top:-20px;
+ display: -webkit-box;
+ display: -webkit-flex;
+ display: -ms-flexbox;
+ display: flex;
+ -webkit-box-orient: vertical;
+ -webkit-box-direction: normal;
+ -webkit-flex-direction: column;
+ -ms-flex-direction: column;
+ flex-direction: column;
+ -webkit-box-align: center;
+ -webkit-align-items: center;
+ -ms-flex-align: center;
+ align-items: center;
+ margin-bottom: 55px;
+ }
+
+ .collumns {
+ display: -webkit-box;
+ display: -webkit-flex;
+ display: -ms-flexbox;
+ display: flex;
+ height: 100%;
+ -webkit-box-align: stretch;
+ -webkit-align-items: stretch;
+ -ms-flex-align: stretch;
+ align-items: stretch;
+ }
+
+ .col_heading {
+ display: -webkit-box;
+ display: -webkit-flex;
+ display: -ms-flexbox;
+ display: flex;
+ -webkit-box-align: center;
+ -webkit-align-items: center;
+ -ms-flex-align: center;
+ align-items: center;
+ }
+
+ .enterprisefs {
+ display: -webkit-box;
+ display: -webkit-flex;
+ display: -ms-flexbox;
+ display: flex;
+ -webkit-box-orient: vertical;
+ -webkit-box-direction: normal;
+ -webkit-flex-direction: column;
+ -ms-flex-direction: column;
+ flex-direction: column;
+ -webkit-box-align: center;
+ -webkit-align-items: center;
+ -ms-flex-align: center;
+ align-items: center;
+ }
+
+ .enterprise_ai {
+ -webkit-align-self: center;
+ -ms-flex-item-align: center;
+ -ms-grid-row-align: center;
+ align-self: center;
+ -webkit-box-flex: 1;
+ -webkit-flex: 1;
+ -ms-flex: 1;
+ flex: 1;
+ }
+
+ .side-content {
+ z-index: 0;
+ display: -webkit-box;
+ display: -webkit-flex;
+ display: -ms-flexbox;
+ display: flex;
+ width: 240px;
+ height: 100%;
+ margin-top: 10px;
+ margin-bottom: 10px;
+ padding: 20px 10px;
+ -webkit-box-orient: vertical;
+ -webkit-box-direction: normal;
+ -webkit-flex-direction: column;
+ -ms-flex-direction: column;
+ flex-direction: column;
+ -webkit-box-pack: center;
+ -webkit-justify-content: center;
+ -ms-flex-pack: center;
+ justify-content: center;
+ -webkit-box-align: center;
+ -webkit-align-items: center;
+ -ms-flex-align: center;
+ align-items: center;
+ -webkit-align-content: flex-start;
+ -ms-flex-line-pack: start;
+ align-content: flex-start;
+ border-style: solid;
+ border-width: var(--border-width);
+ border-color: #585858;
+ border-radius: 10px;
+ background-color:var(--md-quaternary-fg-color);
+ }
+ .body {
+ padding: 40px;
+ font-family: Roboto, sans-serif;
+ }
+
+ .green {
+ color: #1eb182;
+ font-size: 1.2vw;
+ }
+
+ .rec_frame {
+ position: relative;
+ z-index: 1;
+ display: inline-block;
+ min-width: 150px;
+ margin-top: 10px;
+ margin-right: 10px;
+ margin-left: 10px;
+ padding: 10px 10px;
+ border-style: solid;
+ border-width: var(--border-width);
+ border-color: #585858;
+ border-radius: 10px;
+ background-color: #fff;
+ box-shadow: 4px 4px 0 0 rgba(88, 88, 88, 0.16);
+ -webkit-transition: box-shadow 200ms ease, border-color 200ms ease;
+ transition: box-shadow 200ms ease, border-color 200ms ease;
+ color: #585858;
+ text-align: center;
+ cursor: pointer;
+ }
+
+ .rec_frame:hover {
+ border-color: #c2c2c2;
+ box-shadow: none;
+ }
+
+ .name_item {
+ font-size: 0.7rem;
+ line-height: 120%;
+ font-weight: 700;
+ }
+
+ .name_item.db {
+ position: relative;
+ z-index: 3;
+ text-align: left;
+ }
+
+ .name_item.small {
+ font-size: 0.6rem;
+ font-weight: 500;
+ }
+
+ .name_item.ingrey {
+ padding-bottom: 20px;
+ }
+
+ .db_frame-mid {
+ position: relative;
+ z-index: 1;
+ margin-top: -8px;
+ padding: 5px 2px;
+ border-style: solid;
+ border-width: var(--border-width);
+ border-color: #585858;
+ border-radius: 0px 0% 50% 50%;
+ background-color: #fff;
+ color: #585858;
+ text-align: center;
+ }
+
+ .db_frame-top {
+ position: relative;
+ z-index: 2;
+ padding: 5px 2px;
+ border-style: solid;
+ border-width: var(--border-width);
+ border-color: #585858;
+ border-radius: 50%;
+ background-color: #fff;
+ color: #585858;
+ text-align: center;
+ }
+
+ .icondb {
+ position: relative;
+ width: 25px;
+ min-width: 25px;
+ margin-right: 10px;
+ }
+
+ .db_frame {
+ display: -webkit-box;
+ display: -webkit-flex;
+ display: -ms-flexbox;
+ display: flex;
+ width: 150px;
+ height: 55px;
+ padding: 20px 10px;
+ -webkit-box-align: center;
+ -webkit-align-items: center;
+ -ms-flex-align: center;
+ align-items: center;
+ border-style: solid;
+ border-width: var(--border-width);
+ border-color: #585858;
+ border-radius: 10px;
+ background-color: #fff;
+ box-shadow: 4px 4px 0 0 rgba(88, 88, 88, 0.16);
+ -webkit-transition: box-shadow 200ms ease, border-color 200ms ease;
+ transition: box-shadow 200ms ease, border-color 200ms ease;
+ color: #585858;
+ text-align: center;
+ cursor: pointer;
+ }
+
+ .db_frame:hover {
+ border-color: #c2c2c2;
+ box-shadow: none;
+ }
+
+ .grid {
+ -ms-grid-rows: auto auto auto;
+ grid-template-rows: auto auto auto;
+ }
+
+ .arrowdown {
+ position: relative;
+ z-index: 0;
+ display: -webkit-box;
+ display: -webkit-flex;
+ display: -ms-flexbox;
+ display: flex;
+ margin-top: -10px;
+ -webkit-box-pack: center;
+ -webkit-justify-content: center;
+ -ms-flex-pack: center;
+ justify-content: center;
+ }
+
+ .heading_MT {
+ margin-top: 0px !important;
+ margin-bottom: 0px !important;
+ font-size: 1.3rem !important;
+ white-space: nowrap !important;
+ }
+
+ .head_col {
+ padding-left: 10px;
+ }
+
+ .MT_heading3 {
+ margin-top: 0px !important ;
+ font-size: 0.8rem !important;
+ }
+
+ .MT_heading3.green {
+ color: #1eb182 !important;
+ }
+
+ .column_sides {
+ position: relative;
+ z-index: 2;
+ display: -webkit-box;
+ display: -webkit-flex;
+ display: -ms-flexbox;
+ display: flex;
+ -webkit-box-orient: vertical;
+ -webkit-box-direction: normal;
+ -webkit-flex-direction: column;
+ -ms-flex-direction: column;
+ flex-direction: column;
+ -webkit-box-pack: justify;
+ -webkit-justify-content: space-between;
+ -ms-flex-pack: justify;
+ justify-content: space-between;
+ -webkit-box-align: center;
+ -webkit-align-items: center;
+ -ms-flex-align: center;
+ align-items: center;
+ }
+
+ .hopsicon {
+ width: 45px;
+ height: 45px;
+ }
+
+ .column_center {
+ z-index: 10;
+ display: -webkit-box;
+ display: -webkit-flex;
+ display: -ms-flexbox;
+ display: flex;
+ -webkit-box-orient: vertical;
+ -webkit-box-direction: normal;
+ -webkit-flex-direction: column;
+ -ms-flex-direction: column;
+ flex-direction: column;
+ -webkit-box-pack: center;
+ -webkit-justify-content: center;
+ -ms-flex-pack: center;
+ justify-content: center;
+ -webkit-box-align: center;
+ -webkit-align-items: center;
+ -ms-flex-align: center;
+ align-items: center;
+ }
+
+ .center-content {
+ z-index: -50;
+ display: -webkit-box;
+ display: -webkit-flex;
+ display: -ms-flexbox;
+ display: flex;
+ width: 750px;
+ height: 670px;
+ margin-top: 10px;
+ margin-bottom: 10px;
+ padding: 20px 10px;
+ -webkit-box-orient: vertical;
+ -webkit-box-direction: normal;
+ -webkit-flex-direction: column;
+ -ms-flex-direction: column;
+ flex-direction: column;
+ -webkit-box-pack: center;
+ -webkit-justify-content: center;
+ -ms-flex-pack: center;
+ justify-content: center;
+ -webkit-box-align: center;
+ -webkit-align-items: center;
+ -ms-flex-align: center;
+ align-items: center;
+ -webkit-align-content: center;
+ -ms-flex-line-pack: center;
+ align-content: center;
+ border-radius: 10px;
+ background-color: transparent;
+ }
+
+ .image {
+ width: 260px;
+ }
+
+ .layer_01 {
+ display: -webkit-box;
+ display: -webkit-flex;
+ display: -ms-flexbox;
+ display: flex;
+ -webkit-box-orient: vertical;
+ -webkit-box-direction: normal;
+ -webkit-flex-direction: column;
+ -ms-flex-direction: column;
+ flex-direction: column;
+ -webkit-box-pack: center;
+ -webkit-justify-content: center;
+ -ms-flex-pack: center;
+ justify-content: center;
+ -webkit-box-align: stretch;
+ -webkit-align-items: stretch;
+ -ms-flex-align: stretch;
+ align-items: stretch;
+ }
+
+ .name_center {
+ font-size: 1rem;
+ font-weight: 700;
+ }
+
+ .rec_frame_main {
+ position: relative;
+ z-index: 1;
+ margin-top: 10px;
+ margin-right: 10px;
+ margin-left: 10px;
+ padding: 5px 10px;
+ border-style: solid;
+ border-width: var(--border-width);
+ border-color: #1eb182;
+ border-radius: 10px;
+ background-color: #e6fdf6;
+ box-shadow: 4px 4px 0 0 #dcf7ee;
+ -webkit-transition: box-shadow 200ms ease, border-color 200ms ease;
+ transition: box-shadow 200ms ease, border-color 200ms ease;
+ color: #1eb182;
+ text-align: center;
+ cursor: pointer;
+ }
+
+ .rec_frame_main:hover {
+ border-color: #9fecd4;
+ box-shadow: none;
+ }
+
+ .rec_frame_main.no_content {
+ display: -webkit-box;
+ display: -webkit-flex;
+ display: -ms-flexbox;
+ display: flex;
+ height: 100%;
+ -webkit-box-orient: vertical;
+ -webkit-box-direction: normal;
+ -webkit-flex-direction: column;
+ -ms-flex-direction: column;
+ flex-direction: column;
+ -webkit-box-pack: center;
+ -webkit-justify-content: center;
+ -ms-flex-pack: center;
+ justify-content: center;
+ -webkit-box-align: center;
+ -webkit-align-items: center;
+ -ms-flex-align: center;
+ align-items: center;
+ box-shadow: 4px 4px 0 0 #dcf7ee;
+ }
+
+ .rec_frame_main.no_content:hover {
+ border-color: #1eb182;
+ box-shadow: 4px 4px 0 0 rgba(88, 88, 88, 0.16);
+ }
+
+ .name_item_02 {
+ font-size: 0.85rem;
+ font-weight: 700;
+ }
+
+ .grid-infra {
+ padding-top: 20px;
+ -ms-grid-columns: 1fr 1fr 1fr 1fr;
+ grid-template-columns: 1fr 1fr 1fr 1fr;
+ -ms-grid-rows: auto;
+ grid-template-rows: auto;
+ }
+
+ .rec_frame_main-white {
+ position: relative;
+ z-index: 1;
+ display: inline-block;
+ width: 100%;
+ margin-top: 10px;
+ margin-bottom: 10px;
+ padding: 5px 10px;
+ border-style: solid;
+ border-width: var(--border-width);
+ border-color: #1eb182;
+ border-radius: 10px;
+ background-color: #fff;
+ box-shadow: 4px 4px 0 0 rgba(88, 88, 88, 0.16);
+ -webkit-transition: box-shadow 200ms ease, border-color 200ms ease;
+ transition: box-shadow 200ms ease, border-color 200ms ease;
+ color: #1eb182;
+ text-align: center;
+ cursor: pointer;
+ }
+
+ .rec_frame_main-white:hover {
+ border-color: #c2c2c2;
+ box-shadow: none;
+ }
+
+ .rec_frame_main-white.dotted {
+ border-style: dotted;
+ }
+
+ .column {
+ display: -webkit-box;
+ display: -webkit-flex;
+ display: -ms-flexbox;
+ display: flex;
+ -webkit-box-orient: vertical;
+ -webkit-box-direction: normal;
+ -webkit-flex-direction: column;
+ -ms-flex-direction: column;
+ flex-direction: column;
+ -webkit-box-pack: justify;
+ -webkit-justify-content: space-between;
+ -ms-flex-pack: justify;
+ justify-content: space-between;
+ -webkit-box-align: stretch;
+ -webkit-align-items: stretch;
+ -ms-flex-align: stretch;
+ align-items: stretch;
+ }
+
+ .columns_center {
+ display: -webkit-box;
+ display: -webkit-flex;
+ display: -ms-flexbox;
+ display: flex;
+ -webkit-box-orient: horizontal;
+ -webkit-box-direction: normal;
+ -webkit-flex-direction: row;
+ -ms-flex-direction: row;
+ flex-direction: row;
+ -webkit-box-pack: justify;
+ -webkit-justify-content: space-between;
+ -ms-flex-pack: justify;
+ justify-content: space-between;
+ }
+
+ .non-bold {
+ font-weight: 400;
+ }
+
+ .logo-holder {
+ display: -webkit-box;
+ display: -webkit-flex;
+ display: -ms-flexbox;
+ display: flex;
+ -webkit-box-pack: center;
+ -webkit-justify-content: center;
+ -ms-flex-pack: center;
+ justify-content: center;
+ }
+
+ .infra {
+ text-align: center;
+ position: relative;
+ z-index: 30;
+ display: -webkit-box;
+ display: -webkit-flex;
+ display: -ms-flexbox;
+ display: flex;
+ padding: 10px;
+ -webkit-box-orient: vertical;
+ -webkit-box-direction: normal;
+ -webkit-flex-direction: column;
+ -ms-flex-direction: column;
+ flex-direction: column;
+ -webkit-box-align: center;
+ -webkit-align-items: center;
+ -ms-flex-align: center;
+ align-items: center;
+ border: 1px dashed #000;
+ border-radius: 6px;
+ background-color: #fff;
+ cursor: pointer;
+ }
+
+ .infra:hover {
+ border-style: solid;
+ border-color: #585858;
+ }
+
+ .text_and_icon {
+ display: -webkit-box;
+ display: -webkit-flex;
+ display: -ms-flexbox;
+ display: flex;
+ -webkit-box-pack: center;
+ -webkit-justify-content: center;
+ -ms-flex-pack: center;
+ justify-content: center;
+ -webkit-box-align: center;
+ -webkit-align-items: center;
+ -ms-flex-align: center;
+ align-items: center;
+ }
+
+ .svg_icon {
+ width: 33px;
+ margin-right: 10px;
+ margin-left: 10px;
+ }
+
+ .layer_02 {
+ position: absolute;
+ z-index: 10;
+ display: -webkit-box;
+ display: -webkit-flex;
+ display: -ms-flexbox;
+ display: flex;
+ width: 96%;
+ height: 90%;
+ -webkit-box-orient: vertical;
+ -webkit-box-direction: normal;
+ -webkit-flex-direction: column;
+ -ms-flex-direction: column;
+ flex-direction: column;
+ -webkit-box-pack: center;
+ -webkit-justify-content: center;
+ -ms-flex-pack: center;
+ justify-content: center;
+ -webkit-box-align: stretch;
+ -webkit-align-items: stretch;
+ -ms-flex-align: stretch;
+ align-items: stretch;
+ border-style: solid;
+    border-width: calc(var(--border-width) * 2);
+    border-color: #bbbbbb50;
+ border-radius: 100%;
+ background-color: transparent;
+ }
+
+ .round-frame {
+ position: absolute;
+ left: 0%;
+ top: auto;
+ right: auto;
+ bottom: 0%;
+ z-index: 10;
+ display: -webkit-box;
+ display: -webkit-flex;
+ display: -ms-flexbox;
+ display: flex;
+ width: 120px;
+ height: 120px;
+ margin: 10px;
+ padding: 20px;
+ -webkit-box-pack: center;
+ -webkit-justify-content: center;
+ -ms-flex-pack: center;
+ justify-content: center;
+ -webkit-box-align: center;
+ -webkit-align-items: center;
+ -ms-flex-align: center;
+ align-items: center;
+ border-style: solid;
+ border-width: var(--border-width);
+ border-color: #585858;
+ border-radius: 100%;
+ background-color: #fff;
+ outline-color: #fff;
+ outline-offset: 0px;
+ outline-style: solid;
+ outline-width: 7px;
+ -webkit-transition: box-shadow 200ms ease, border-color 200ms ease;
+ transition: box-shadow 200ms ease, border-color 200ms ease;
+ color: #585858;
+ text-align: center;
+ cursor: pointer;
+ }
+
+ .round-frame:hover {
+ border-color: #c2c2c2;
+ box-shadow: none;
+ }
+
+ .round-frame.top-left {
+ left: 4%;
+ top: 15%;
+ right: auto;
+ bottom: auto;
+ }
+
+ .round-frame.bottom-left {
+ left: 4%;
+ bottom: 15%;
+ }
+
+ .round-frame.top-right {
+ left: auto;
+ top: 15%;
+ right: 4%;
+ bottom: auto;
+ }
+
+ .round-frame.bottom-right {
+ left: auto;
+ top: auto;
+ right: 4%;
+ bottom: 15%;
+ padding: 10px;
+ }
+
+ .side-holder {
+ z-index: -1;
+ display: -webkit-box;
+ display: -webkit-flex;
+ display: -ms-flexbox;
+ display: flex;
+ height: 630px;
+ -webkit-box-orient: vertical;
+ -webkit-box-direction: normal;
+ -webkit-flex-direction: column;
+ -ms-flex-direction: column;
+ flex-direction: column;
+ -webkit-box-pack: center;
+ -webkit-justify-content: center;
+ -ms-flex-pack: center;
+ justify-content: center;
+ }
+
+ .infra-icon {
+ width: 25px;
+ height: 25px;
+ }
+
+ .div-block {
+ display: -webkit-box;
+ display: -webkit-flex;
+ display: -ms-flexbox;
+ display: flex;
+ height: 100%;
+ -webkit-box-orient: vertical;
+ -webkit-box-direction: normal;
+ -webkit-flex-direction: column;
+ -ms-flex-direction: column;
+ flex-direction: column;
+ -webkit-box-pack: justify;
+ -webkit-justify-content: space-between;
+ -ms-flex-pack: justify;
+ justify-content: space-between;
+ }
+
+ #w-node-a2a9b648-f5dd-74e5-e1c2-f7aaf4fa1fcd-46672785 {
+ -ms-grid-column: span 1;
+ grid-column-start: span 1;
+ -ms-grid-column-span: 1;
+ grid-column-end: span 1;
+ -ms-grid-row: span 1;
+ grid-row-start: span 1;
+ -ms-grid-row-span: 1;
+ grid-row-end: span 1;
+ }
+
+ #w-node-_466aa2bf-88bf-5a65-eab4-fc1eb95e7384-46672785 {
+ -ms-grid-column: span 1;
+ grid-column-start: span 1;
+ -ms-grid-column-span: 1;
+ grid-column-end: span 1;
+ -ms-grid-row: span 1;
+ grid-row-start: span 1;
+ -ms-grid-row-span: 1;
+ grid-row-end: span 1;
+ }
+
+ #w-node-_87009ba3-d9a6-e0b7-4cce-581190a19cf3-46672785 {
+ -ms-grid-column: span 1;
+ grid-column-start: span 1;
+ -ms-grid-column-span: 1;
+ grid-column-end: span 1;
+ -ms-grid-row: span 1;
+ grid-row-start: span 1;
+ -ms-grid-row-span: 1;
+ grid-row-end: span 1;
+ }
+
+ #w-node-_4a479fbb-90c7-9f47-d439-20aa6a224339-46672785 {
+ -ms-grid-column: span 1;
+ grid-column-start: span 1;
+ -ms-grid-column-span: 1;
+ grid-column-end: span 1;
+ -ms-grid-row: span 1;
+ grid-row-start: span 1;
+ -ms-grid-row-span: 1;
+ grid-row-end: span 1;
+ }
+
+
+ /*
+
+
+ inherited from the original template
+
+ */
+
+ .w-container .w-row {
+ margin-left: -10px;
+ margin-right: -10px;
+ }
+ .w-row:before,
+ .w-row:after {
+ content: " ";
+ display: table;
+ grid-column-start: 1;
+ grid-row-start: 1;
+ grid-column-end: 2;
+ grid-row-end: 2;
+ }
+ .w-row:after {
+ clear: both;
+ }
+ .w-row .w-row {
+ margin-left: 0;
+ margin-right: 0;
+ }
+ .w-col {
+ position: relative;
+ float: left;
+ width: 100%;
+ min-height: 1px;
+ padding-left: 10px;
+ padding-right: 10px;
+ }
+ .w-col .w-col {
+ padding-left: 0;
+ padding-right: 0;
+ }
+ .w-col-1 {
+ width: 8.33333333%;
+ }
+ .w-col-2 {
+ width: 16.66666667%;
+ }
+ .w-col-3 {
+ width: 25%;
+ }
+ .w-col-4 {
+ width: 33.33333333%;
+ }
+ .w-col-5 {
+ width: 41.66666667%;
+ }
+ .w-col-6 {
+ width: 50%;
+ }
+ .w-col-7 {
+ width: 58.33333333%;
+ }
+ .w-col-8 {
+ width: 66.66666667%;
+ }
+ .w-col-9 {
+ width: 75%;
+ }
+ .w-col-10 {
+ width: 83.33333333%;
+ }
+ .w-col-11 {
+ width: 91.66666667%;
+ }
+ .w-col-12 {
+ width: 100%;
+ }
+ .w-hidden-main {
+ display: none !important;
+ }
+ @media screen and (max-width: 991px) {
+ .w-container {
+ max-width: 728px;
+ }
+ .w-hidden-main {
+ display: inherit !important;
+ }
+ .w-hidden-medium {
+ display: none !important;
+ }
+ .w-col-medium-1 {
+ width: 8.33333333%;
+ }
+ .w-col-medium-2 {
+ width: 16.66666667%;
+ }
+ .w-col-medium-3 {
+ width: 25%;
+ }
+ .w-col-medium-4 {
+ width: 33.33333333%;
+ }
+ .w-col-medium-5 {
+ width: 41.66666667%;
+ }
+ .w-col-medium-6 {
+ width: 50%;
+ }
+ .w-col-medium-7 {
+ width: 58.33333333%;
+ }
+ .w-col-medium-8 {
+ width: 66.66666667%;
+ }
+ .w-col-medium-9 {
+ width: 75%;
+ }
+ .w-col-medium-10 {
+ width: 83.33333333%;
+ }
+ .w-col-medium-11 {
+ width: 91.66666667%;
+ }
+ .w-col-medium-12 {
+ width: 100%;
+ }
+ .w-col-stack {
+ width: 100%;
+ left: auto;
+ right: auto;
+ }
+ }
+ @media screen and (max-width: 767px) {
+ .w-hidden-main {
+ display: inherit !important;
+ }
+ .w-hidden-medium {
+ display: inherit !important;
+ }
+ .w-hidden-small {
+ display: none !important;
+ }
+ .w-row,
+ .w-container .w-row {
+ margin-left: 0;
+ margin-right: 0;
+ }
+ .w-col {
+ width: 100%;
+ left: auto;
+ right: auto;
+ }
+ .w-col-small-1 {
+ width: 8.33333333%;
+ }
+ .w-col-small-2 {
+ width: 16.66666667%;
+ }
+ .w-col-small-3 {
+ width: 25%;
+ }
+ .w-col-small-4 {
+ width: 33.33333333%;
+ }
+ .w-col-small-5 {
+ width: 41.66666667%;
+ }
+ .w-col-small-6 {
+ width: 50%;
+ }
+ .w-col-small-7 {
+ width: 58.33333333%;
+ }
+ .w-col-small-8 {
+ width: 66.66666667%;
+ }
+ .w-col-small-9 {
+ width: 75%;
+ }
+ .w-col-small-10 {
+ width: 83.33333333%;
+ }
+ .w-col-small-11 {
+ width: 91.66666667%;
+ }
+ .w-col-small-12 {
+ width: 100%;
+ }
+ }
+ @media screen and (max-width: 479px) {
+ .w-container {
+ max-width: none;
+ }
+ .w-hidden-main {
+ display: inherit !important;
+ }
+ .w-hidden-medium {
+ display: inherit !important;
+ }
+ .w-hidden-small {
+ display: inherit !important;
+ }
+ .w-hidden-tiny {
+ display: none !important;
+ }
+ .w-col {
+ width: 100%;
+ }
+ .w-col-tiny-1 {
+ width: 8.33333333%;
+ }
+ .w-col-tiny-2 {
+ width: 16.66666667%;
+ }
+ .w-col-tiny-3 {
+ width: 25%;
+ }
+ .w-col-tiny-4 {
+ width: 33.33333333%;
+ }
+ .w-col-tiny-5 {
+ width: 41.66666667%;
+ }
+ .w-col-tiny-6 {
+ width: 50%;
+ }
+ .w-col-tiny-7 {
+ width: 58.33333333%;
+ }
+ .w-col-tiny-8 {
+ width: 66.66666667%;
+ }
+ .w-col-tiny-9 {
+ width: 75%;
+ }
+ .w-col-tiny-10 {
+ width: 83.33333333%;
+ }
+ .w-col-tiny-11 {
+ width: 91.66666667%;
+ }
+ .w-col-tiny-12 {
+ width: 100%;
+ }
+ }
diff --git a/hsml/docs/css/version-select.css b/hsml/docs/css/version-select.css
new file mode 100644
index 000000000..3b908ae84
--- /dev/null
+++ b/hsml/docs/css/version-select.css
@@ -0,0 +1,36 @@
+@media only screen and (max-width:76.1875em) {
+}
+
+#version-selector select.form-control {
+ appearance: none;
+ -webkit-appearance: none;
+ -moz-appearance: none;
+
+ background-color: #F5F5F5;
+
+ background-position: center right;
+ background-repeat: no-repeat;
+ border: 0px;
+ border-radius: 2px;
+ /* box-shadow: 0px 1px 3px rgb(0 0 0 / 10%); */
+ color: inherit;
+ width: -webkit-fill-available;
+ width: -moz-available;
+ max-width: 200px;
+ font-size: inherit;
+ /* font-weight: 600; */
+ margin: 10px;
+ overflow: hidden;
+ padding: 7px 10px;
+ text-overflow: ellipsis;
+ white-space: nowrap;
+}
+
+#version-selector::after {
+ content: '⌄';
+ font-family: inherit;
+ font-size: 22px;
+ margin: -35px;
+ vertical-align: 7%;
+ padding-bottom: 10px;
+}
diff --git a/hsml/docs/index.md b/hsml/docs/index.md
new file mode 100644
index 000000000..ee835ddc7
--- /dev/null
+++ b/hsml/docs/index.md
@@ -0,0 +1,141 @@
+# Hopsworks Model Management
+
+HSML is the library to interact with the Hopsworks Model Registry and Model Serving. The library makes it easy to export, manage and deploy models.
+
+However, to connect from an external Python environment, additional connection information, such as host and port, is required.
+
+## Getting Started On Hopsworks
+
+Get started easily by registering an account on [Hopsworks Serverless](https://app.hopsworks.ai/). Create your project and a [new API key](https://docs.hopsworks.ai/latest/user_guides/projects/api_key/create_api_key/). In a new Python environment with Python 3.8 or higher, install the [client library](https://docs.hopsworks.ai/latest/user_guides/client_installation/) using pip:
+
+```bash
+# Get all Hopsworks SDKs: Feature Store, Model Serving and Platform SDK
+pip install hopsworks
+# or just the Model Registry and Model Serving SDK
+pip install hsml
+```
+
+You can start a notebook, instantiate a connection, and get a handle to the project's model registry or model serving.
+
+```python
+import hopsworks
+
+project = hopsworks.login() # you will be prompted for your api key
+
+mr = project.get_model_registry()
+# or
+ms = project.get_model_serving()
+```
+
+or using `hsml` directly:
+
+```python
+import hsml
+
+connection = hsml.connection(
+    host="c.app.hopsworks.ai",  # DNS of your Hopsworks instance
+ project="your-project",
+ api_key_value="your-api-key",
+)
+
+mr = connection.get_model_registry()
+# or
+ms = connection.get_model_serving()
+```
+
+Create a new model
+```python
+model = mr.tensorflow.create_model(name="mnist",
+ version=1,
+ metrics={"accuracy": 0.94},
+ description="mnist model description")
+model.save("/tmp/model_directory") # or /tmp/model_file
+```
+
+Download a model
+```python
+model = mr.get_model("mnist", version=1)
+
+model_path = model.download()
+```
+
+Delete a model
+```python
+model.delete()
+```
+
+Get best performing model
+```python
+best_model = mr.get_best_model('mnist', 'accuracy', 'max')
+
+```
+
+Deploy a model
+```python
+deployment = model.deploy()
+```
+
+Start a deployment
+```python
+deployment.start()
+```
+
+Make predictions with a deployed model
+```python
+data = { "instances": [ model.input_example ] }
+
+predictions = deployment.predict(data)
+```
+
+## Tutorials
+
+You can find more examples on how to use the library in our [tutorials](https://github.com/logicalclocks/hopsworks-tutorials).
+
+## Documentation
+
+Documentation is available at [Hopsworks Model Management Documentation](https://docs.hopsworks.ai/).
+
+## Issues
+
+For general questions about the usage of Hopsworks Machine Learning, please open a topic on the [Hopsworks Community](https://community.hopsworks.ai/).
+Please report any issue using the [GitHub issue tracker](https://github.com/logicalclocks/machine-learning-api/issues).
+
+
+## Contributing
+
+If you would like to contribute to this library, please see the [Contribution Guidelines](CONTRIBUTING.md).
diff --git a/hsml/docs/js/dropdown.js b/hsml/docs/js/dropdown.js
new file mode 100644
index 000000000..b897ba36a
--- /dev/null
+++ b/hsml/docs/js/dropdown.js
@@ -0,0 +1,2 @@
+document.getElementsByClassName("md-tabs__link")[7].style.display = "none";
+document.getElementsByClassName("md-tabs__link")[9].style.display = "none";
\ No newline at end of file
diff --git a/hsml/docs/js/inject-api-links.js b/hsml/docs/js/inject-api-links.js
new file mode 100644
index 000000000..6c8a4a3b3
--- /dev/null
+++ b/hsml/docs/js/inject-api-links.js
@@ -0,0 +1,31 @@
+window.addEventListener("DOMContentLoaded", function () {
+ var windowPathNameSplits = window.location.pathname.split("/");
+ var majorVersionRegex = new RegExp("(\\d+[.]\\d+)")
+ var latestRegex = new RegExp("latest");
+    if (majorVersionRegex.test(windowPathNameSplits[1])) { // On landing page docs.hopsworks.ai/3.0 - URL contains major version
+ // Version API dropdown
+ document.getElementById("hopsworks_api_link").href = "https://docs.hopsworks.ai/hopsworks-api/" + windowPathNameSplits[1] + "/generated/api/login/";
+ document.getElementById("hsfs_api_link").href = "https://docs.hopsworks.ai/feature-store-api/" + windowPathNameSplits[1] + "/generated/api/connection_api/";
+ document.getElementById("hsml_api_link").href = "https://docs.hopsworks.ai/machine-learning-api/" + windowPathNameSplits[1] + "/generated/connection_api/";
+    } else { // on docs.hopsworks.ai/feature-store-api/3.0 / docs.hopsworks.ai/hopsworks-api/3.0 / docs.hopsworks.ai/machine-learning-api/3.0
+ if (latestRegex.test(windowPathNameSplits[2]) || latestRegex.test(windowPathNameSplits[1])) {
+ var majorVersion = "latest";
+ } else {
+ var apiVersion = windowPathNameSplits[2];
+ var majorVersion = apiVersion.match(majorVersionRegex)[0];
+ }
+ // Version main navigation
+ document.getElementsByClassName("md-tabs__link")[0].href = "https://docs.hopsworks.ai/" + majorVersion;
+ document.getElementsByClassName("md-tabs__link")[1].href = "https://colab.research.google.com/github/logicalclocks/hopsworks-tutorials/blob/master/quickstart.ipynb";
+ document.getElementsByClassName("md-tabs__link")[2].href = "https://docs.hopsworks.ai/" + majorVersion + "/tutorials/";
+ document.getElementsByClassName("md-tabs__link")[3].href = "https://docs.hopsworks.ai/" + majorVersion + "/concepts/hopsworks/";
+ document.getElementsByClassName("md-tabs__link")[4].href = "https://docs.hopsworks.ai/" + majorVersion + "/user_guides/";
+ document.getElementsByClassName("md-tabs__link")[5].href = "https://docs.hopsworks.ai/" + majorVersion + "/setup_installation/aws/getting_started/";
+ document.getElementsByClassName("md-tabs__link")[6].href = "https://docs.hopsworks.ai/" + majorVersion + "/admin/";
+ // Version API dropdown
+ document.getElementById("hopsworks_api_link").href = "https://docs.hopsworks.ai/hopsworks-api/" + majorVersion + "/generated/api/login/";
+ document.getElementById("hsfs_api_link").href = "https://docs.hopsworks.ai/feature-store-api/" + majorVersion + "/generated/api/connection_api/";
+ document.getElementById("hsfs_javadoc_link").href = "https://docs.hopsworks.ai/feature-store-api/" + majorVersion + "/javadoc";
+ document.getElementById("hsml_api_link").href = "https://docs.hopsworks.ai/machine-learning-api/" + majorVersion + "/generated/connection_api/";
+ }
+});
diff --git a/hsml/docs/js/version-select.js b/hsml/docs/js/version-select.js
new file mode 100644
index 000000000..9c8331660
--- /dev/null
+++ b/hsml/docs/js/version-select.js
@@ -0,0 +1,64 @@
+window.addEventListener("DOMContentLoaded", function() {
+ // This is a bit hacky. Figure out the base URL from a known CSS file the
+ // template refers to...
+ var ex = new RegExp("/?css/version-select.css$");
+ var sheet = document.querySelector('link[href$="version-select.css"]');
+
+ var ABS_BASE_URL = sheet.href.replace(ex, "");
+ var CURRENT_VERSION = ABS_BASE_URL.split("/").pop();
+
+ function makeSelect(options, selected) {
+ var select = document.createElement("select");
+ select.classList.add("form-control");
+
+ options.forEach(function(i) {
+ var option = new Option(i.text, i.value, undefined,
+ i.value === selected);
+ select.add(option);
+ });
+
+ return select;
+ }
+
+ var xhr = new XMLHttpRequest();
+ xhr.open("GET", ABS_BASE_URL + "/../versions.json");
+ xhr.onload = function() {
+ var versions = JSON.parse(this.responseText);
+
+ var realVersion = versions.find(function(i) {
+ return i.version === CURRENT_VERSION ||
+ i.aliases.includes(CURRENT_VERSION);
+ }).version;
+ var latestVersion = versions.find(function(i) {
+ return i.aliases.includes("latest");
+ }).version;
+ let outdated_banner = document.querySelector('div[data-md-color-scheme="default"][data-md-component="outdated"]');
+ if (realVersion !== latestVersion) {
+ outdated_banner.removeAttribute("hidden");
+ } else {
+ outdated_banner.setAttribute("hidden", "");
+ }
+
+ var select = makeSelect(versions.map(function(i) {
+ var allowedAliases = ["dev", "latest"]
+ if (i.aliases.length > 0) {
+ var aliasString = " [" + i.aliases.filter(function (str) { return allowedAliases.includes(str); }).join(", ") + "]";
+ } else {
+ var aliasString = "";
+ }
+ return {text: i.title + aliasString, value: i.version};
+ }), realVersion);
+ select.addEventListener("change", function(event) {
+ window.location.href = ABS_BASE_URL + "/../" + this.value + "/generated/connection_api/";
+ });
+
+ var container = document.createElement("div");
+ container.id = "version-selector";
+ // container.className = "md-nav__item";
+ container.appendChild(select);
+
+ var sidebar = document.querySelector(".md-nav--primary > .md-nav__list");
+ sidebar.parentNode.insertBefore(container, sidebar.nextSibling);
+ };
+ xhr.send();
+});
diff --git a/hsml/docs/overrides/main.html b/hsml/docs/overrides/main.html
new file mode 100644
index 000000000..a1bc45bb5
--- /dev/null
+++ b/hsml/docs/overrides/main.html
@@ -0,0 +1,8 @@
+{% extends "base.html" %}
+
+{% block outdated %}
+ You're not viewing the latest version of the documentation.
+ <a href="{{ '../' ~ base_url }}">
+ <strong>Click here to go to latest.</strong>
+ </a>
+{% endblock %}
\ No newline at end of file
diff --git a/hsml/docs/templates/connection_api.md b/hsml/docs/templates/connection_api.md
new file mode 100644
index 000000000..19e13f3eb
--- /dev/null
+++ b/hsml/docs/templates/connection_api.md
@@ -0,0 +1,11 @@
+# Connection
+
+{{connection}}
+
+## Properties
+
+{{connection_properties}}
+
+## Methods
+
+{{connection_methods}}
diff --git a/hsml/docs/templates/model-registry/links.md b/hsml/docs/templates/model-registry/links.md
new file mode 100644
index 000000000..07abe3177
--- /dev/null
+++ b/hsml/docs/templates/model-registry/links.md
@@ -0,0 +1,15 @@
+# Provenance Links
+
+Provenance Links are objects returned by methods such as [get_feature_view_provenance](../model_api/#get_feature_view_provenance) and [get_training_dataset_provenance](../model_api/#get_training_dataset_provenance). These methods use the provenance graph to return the parent feature view or training dataset of a model, and they return the actual feature view/training dataset instance if it is available. If the instance was deleted, or it belongs to a feature store that the current project no longer has access to, an Artifact object is returned instead.
+
+There is an additional method that uses the provenance graph: [get_feature_view](../model_api/#get_feature_view). It wraps `get_feature_view_provenance` and always returns a correct, usable Feature View object, or throws an exception if only an Artifact is available. In other words, an exception is thrown if the feature view was deleted or the feature store it belongs to was unshared.
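+
+As a rough usage sketch (assuming a `model` object already retrieved from the Model Registry), the difference between the two approaches looks like this:
+
+``` python
+# returns a provenance Links object; the parent may be a usable feature view or an Artifact
+fv_links = model.get_feature_view_provenance()
+
+# returns a usable feature view object, or raises an exception if only an Artifact is available
+feature_view = model.get_feature_view()
+```
+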
+## Properties
+
+{{links_properties}}
+
+# Artifact
+
+Artifact objects are part of the provenance graph and contain a minimal set of information about the entities (feature views, training datasets) they represent.
+The provenance graph contains Artifact objects when the underlying entities have been deleted, are corrupted, or are no longer accessible by the current project.
+
+{{artifact_properties}}
diff --git a/hsml/docs/templates/model-registry/model_api.md b/hsml/docs/templates/model-registry/model_api.md
new file mode 100644
index 000000000..edb2e5ade
--- /dev/null
+++ b/hsml/docs/templates/model-registry/model_api.md
@@ -0,0 +1,29 @@
+# Model
+
+## Creation of a TensorFlow model
+
+{{ml_create_tf}}
+
+## Creation of a Torch model
+
+{{ml_create_th}}
+
+## Creation of a scikit-learn model
+
+{{ml_create_sl}}
+
+## Creation of a generic model
+
+{{ml_create_py}}
+
+## Retrieval
+
+{{ml_get}}
+
+## Properties
+
+{{ml_properties}}
+
+## Methods
+
+{{ml_methods}}
diff --git a/hsml/docs/templates/model-registry/model_registry_api.md b/hsml/docs/templates/model-registry/model_registry_api.md
new file mode 100644
index 000000000..d577e91e3
--- /dev/null
+++ b/hsml/docs/templates/model-registry/model_registry_api.md
@@ -0,0 +1,17 @@
+# Model Registry
+
+## Retrieval
+
+{{mr_get}}
+
+## Modules
+
+{{mr_modules}}
+
+## Properties
+
+{{mr_properties}}
+
+## Methods
+
+{{mr_methods}}
diff --git a/hsml/docs/templates/model-registry/model_schema_api.md b/hsml/docs/templates/model-registry/model_schema_api.md
new file mode 100644
index 000000000..28170a419
--- /dev/null
+++ b/hsml/docs/templates/model-registry/model_schema_api.md
@@ -0,0 +1,36 @@
+# Model Schema
+
+## Creation
+
+To create a ModelSchema, the schema of the Model inputs and/or Model outputs has to be defined beforehand.
+
+{{schema}}
+
+After defining the schemas of the Model inputs and/or outputs, a ModelSchema can be created using its class constructor.
+
+{{model_schema}}
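+
+As an illustrative sketch (the sample training arrays below are assumptions; the exact constructor signatures are documented above), the two steps could look like this:
+
+``` python
+import numpy as np
+
+from hsml.schema import Schema
+from hsml.model_schema import ModelSchema
+
+# sample data standing in for real training inputs and outputs
+X_train = np.array([[1.0, 2.0], [3.0, 4.0]])
+y_train = np.array([0, 1])
+
+# define the input and output schemas from the sample data
+input_schema = Schema(X_train)
+output_schema = Schema(y_train)
+
+# combine them into a model schema
+model_schema = ModelSchema(input_schema=input_schema, output_schema=output_schema)
+```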
+
+## Retrieval
+
+### Model Schema
+
+Model schemas can be accessed from the model metadata objects.
+
+``` python
+model.model_schema
+```
+
+### Model Input & Output Schemas
+
+The schemas of the Model inputs and outputs can be accessed from the ModelSchema metadata objects.
+
+``` python
+model_schema.input_schema
+model_schema.output_schema
+```
+
+## Methods
+
+{{schema_dict}}
+
+{{model_schema_dict}}
diff --git a/hsml/docs/templates/model-serving/deployment_api.md b/hsml/docs/templates/model-serving/deployment_api.md
new file mode 100644
index 000000000..aebccca55
--- /dev/null
+++ b/hsml/docs/templates/model-serving/deployment_api.md
@@ -0,0 +1,25 @@
+# Deployment
+
+## Handle
+
+{{ms_get_model_serving}}
+
+## Creation
+
+{{ms_create_deployment}}
+
+{{m_deploy}}
+
+{{p_deploy}}
+
+## Retrieval
+
+{{ms_get_deployments}}
+
+## Properties
+
+{{dep_properties}}
+
+## Methods
+
+{{dep_methods}}
diff --git a/hsml/docs/templates/model-serving/inference_batcher_api.md b/hsml/docs/templates/model-serving/inference_batcher_api.md
new file mode 100644
index 000000000..3a2609962
--- /dev/null
+++ b/hsml/docs/templates/model-serving/inference_batcher_api.md
@@ -0,0 +1,25 @@
+# Inference batcher
+
+## Creation
+
+{{ib}}
+
+## Retrieval
+
+### predictor.inference_batcher
+
+Inference batchers can be accessed from the predictor metadata objects.
+
+``` python
+predictor.inference_batcher
+```
+
+Predictors can be found in the deployment metadata objects (see [Predictor Reference](../predictor_api/#retrieval)). To retrieve a deployment, see the [Deployment Reference](../deployment_api/#retrieval).
+
+## Properties
+
+{{ib_properties}}
+
+## Methods
+
+{{ib_methods}}
diff --git a/hsml/docs/templates/model-serving/inference_logger_api.md b/hsml/docs/templates/model-serving/inference_logger_api.md
new file mode 100644
index 000000000..2cf68d652
--- /dev/null
+++ b/hsml/docs/templates/model-serving/inference_logger_api.md
@@ -0,0 +1,25 @@
+# Inference logger
+
+## Creation
+
+{{il}}
+
+## Retrieval
+
+### predictor.inference_logger
+
+Inference loggers can be accessed from the predictor metadata objects.
+
+``` python
+predictor.inference_logger
+```
+
+Predictors can be found in the deployment metadata objects (see [Predictor Reference](../predictor_api/#retrieval)). To retrieve a deployment, see the [Deployment Reference](../deployment_api/#retrieval).
+
+## Properties
+
+{{il_properties}}
+
+## Methods
+
+{{il_methods}}
diff --git a/hsml/docs/templates/model-serving/model_serving_api.md b/hsml/docs/templates/model-serving/model_serving_api.md
new file mode 100644
index 000000000..0eb557213
--- /dev/null
+++ b/hsml/docs/templates/model-serving/model_serving_api.md
@@ -0,0 +1,13 @@
+# Model Serving
+
+## Retrieval
+
+{{ms_get}}
+
+## Properties
+
+{{ms_properties}}
+
+## Methods
+
+{{ms_methods}}
diff --git a/hsml/docs/templates/model-serving/predictor_api.md b/hsml/docs/templates/model-serving/predictor_api.md
new file mode 100644
index 000000000..3dd9df195
--- /dev/null
+++ b/hsml/docs/templates/model-serving/predictor_api.md
@@ -0,0 +1,29 @@
+# Predictor
+
+## Handle
+
+{{ms_get_model_serving}}
+
+## Creation
+
+{{ms_create_predictor}}
+
+## Retrieval
+
+### deployment.predictor
+
+Predictors can be accessed from the deployment metadata objects.
+
+``` python
+deployment.predictor
+```
+
+To retrieve a deployment, see the [Deployment Reference](../deployment_api/#retrieval).
+
+## Properties
+
+{{pred_properties}}
+
+## Methods
+
+{{pred_methods}}
diff --git a/hsml/docs/templates/model-serving/predictor_state_api.md b/hsml/docs/templates/model-serving/predictor_state_api.md
new file mode 100644
index 000000000..2640b9b48
--- /dev/null
+++ b/hsml/docs/templates/model-serving/predictor_state_api.md
@@ -0,0 +1,18 @@
+# Deployment state
+
+The state of a deployment corresponds to the state of the predictor configured in it.
+
+!!! note
+ Currently, only one predictor is supported in a deployment. Support for multiple predictors (inference graphs) is coming soon.
+
+## Retrieval
+
+{{ps_get}}
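+
+As a usage sketch (assuming a `deployment` object already retrieved via the Model Serving handle), the current state can be inspected as follows:
+
+``` python
+state = deployment.get_state()
+print(state.status)  # e.g. "Running" or "Stopped"
+```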
+
+## Properties
+
+{{ps_properties}}
+
+## Methods
+
+{{ps_methods}}
diff --git a/hsml/docs/templates/model-serving/predictor_state_condition_api.md b/hsml/docs/templates/model-serving/predictor_state_condition_api.md
new file mode 100644
index 000000000..e1566d2b1
--- /dev/null
+++ b/hsml/docs/templates/model-serving/predictor_state_condition_api.md
@@ -0,0 +1,15 @@
+# Deployment state condition
+
+The state condition of a deployment is a more detailed representation of a deployment state.
+
+## Retrieval
+
+{{psc_get}}
+
+## Properties
+
+{{psc_properties}}
+
+## Methods
+
+{{psc_methods}}
diff --git a/hsml/docs/templates/model-serving/resources_api.md b/hsml/docs/templates/model-serving/resources_api.md
new file mode 100644
index 000000000..addc7f51e
--- /dev/null
+++ b/hsml/docs/templates/model-serving/resources_api.md
@@ -0,0 +1,35 @@
+# Resources
+
+## Creation
+
+{{res}}
+
+## Retrieval
+
+### predictor.resources
+
+Resources allocated for a predictor can be accessed from the predictor metadata object.
+
+``` python
+predictor.resources
+```
+
+Predictors can be found in the deployment metadata objects (see [Predictor Reference](../predictor_api/#retrieval)). To retrieve a deployment, see the [Deployment Reference](../deployment_api/#retrieval).
+
+### transformer.resources
+
+Resources allocated for a transformer can be accessed from the transformer metadata object.
+
+``` python
+transformer.resources
+```
+
+Transformers can be found in the predictor metadata objects (see [Predictor Reference](../predictor_api/#retrieval)).
+
+## Properties
+
+{{res_properties}}
+
+## Methods
+
+{{res_methods}}
diff --git a/hsml/docs/templates/model-serving/transformer_api.md b/hsml/docs/templates/model-serving/transformer_api.md
new file mode 100644
index 000000000..ae81e84ef
--- /dev/null
+++ b/hsml/docs/templates/model-serving/transformer_api.md
@@ -0,0 +1,29 @@
+# Transformer
+
+## Handle
+
+{{ms_get_model_serving}}
+
+## Creation
+
+{{ms_create_transformer}}
+
+## Retrieval
+
+### predictor.transformer
+
+Transformers can be accessed from the predictor metadata objects.
+
+``` python
+predictor.transformer
+```
+
+Predictors can be found in the deployment metadata objects (see [Predictor Reference](../predictor_api/#retrieval)). To retrieve a deployment, see the [Deployment Reference](../deployment_api/#retrieval).
+
+## Properties
+
+{{trans_properties}}
+
+## Methods
+
+{{trans_methods}}
diff --git a/hsml/java/pom.xml b/hsml/java/pom.xml
new file mode 100644
index 000000000..cb3e60028
--- /dev/null
+++ b/hsml/java/pom.xml
@@ -0,0 +1,109 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+
+  <groupId>com.logicalclocks</groupId>
+  <artifactId>hsml</artifactId>
+  <version>4.0.0-SNAPSHOT</version>
+
+  <properties>
+    <maven.compiler.source>1.8</maven.compiler.source>
+    <maven.compiler.target>1.8</maven.compiler.target>
+  </properties>
+
+  <build>
+    <plugins>
+      <plugin>
+        <groupId>org.scala-tools</groupId>
+        <artifactId>maven-scala-plugin</artifactId>
+        <executions>
+          <execution>
+            <id>scala-compile-first</id>
+            <phase>process-resources</phase>
+            <goals>
+              <goal>add-source</goal>
+              <goal>compile</goal>
+            </goals>
+          </execution>
+          <execution>
+            <id>scala-test-compile</id>
+            <phase>process-test-resources</phase>
+            <goals>
+              <goal>testCompile</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
+
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-assembly-plugin</artifactId>
+        <version>2.4.1</version>
+        <configuration>
+          <descriptorRefs>
+            <descriptorRef>jar-with-dependencies</descriptorRef>
+          </descriptorRefs>
+        </configuration>
+        <executions>
+          <execution>
+            <id>make-assembly</id>
+            <phase>package</phase>
+            <goals>
+              <goal>single</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
+
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-checkstyle-plugin</artifactId>
+        <version>3.1.1</version>
+        <executions>
+          <execution>
+            <id>validate</id>
+            <phase>validate</phase>
+            <goals>
+              <goal>check</goal>
+            </goals>
+          </execution>
+        </executions>
+        <configuration>
+          <configLocation>src/main/resources/checkstyle.xml</configLocation>
+          <suppressionsLocation>src/main/resources/suppressions.xml</suppressionsLocation>
+          <consoleOutput>true</consoleOutput>
+          <failsOnError>true</failsOnError>
+          <failOnViolation>true</failOnViolation>
+          <includeTestSourceDirectory>true</includeTestSourceDirectory>
+          <sourceDirectories>
+            <sourceDirectory>src/main/java</sourceDirectory>
+          </sourceDirectories>
+        </configuration>
+      </plugin>
+    </plugins>
+  </build>
+
+  <repositories>
+    <repository>
+      <id>Hops</id>
+      <name>Hops Repo</name>
+      <url>https://archiva.hops.works/repository/Hops/</url>
+      <releases>
+        <enabled>true</enabled>
+      </releases>
+      <snapshots>
+        <enabled>true</enabled>
+      </snapshots>
+    </repository>
+  </repositories>
+
+  <distributionManagement>
+    <repository>
+      <id>Hops</id>
+      <name>Hops Repo</name>
+      <url>https://archiva.hops.works/repository/Hops/</url>
+    </repository>
+  </distributionManagement>
+</project>
diff --git a/hsml/java/src/main/resources/checkstyle.xml b/hsml/java/src/main/resources/checkstyle.xml
new file mode 100644
index 000000000..5f99eb681
--- /dev/null
+++ b/hsml/java/src/main/resources/checkstyle.xml
@@ -0,0 +1,312 @@
diff --git a/hsml/java/src/main/resources/suppressions.xml b/hsml/java/src/main/resources/suppressions.xml
new file mode 100644
index 000000000..a86fa8219
--- /dev/null
+++ b/hsml/java/src/main/resources/suppressions.xml
@@ -0,0 +1,5 @@
diff --git a/hsml/mkdocs.yml b/hsml/mkdocs.yml
new file mode 100644
index 000000000..f20a7b1c5
--- /dev/null
+++ b/hsml/mkdocs.yml
@@ -0,0 +1,120 @@
+site_name: "Hopsworks Documentation"
+site_description: "Official documentation for Hopsworks and its Feature Store - an open source data-intensive AI platform used for the development and operation of machine learning models at scale."
+site_author: "Logical Clocks"
+site_url: "https://docs.hopsworks.ai/machine-learning-api/latest"
+
+# Repository
+repo_name: logicalclocks/hopsworks
+repo_url: https://github.com/logicalclocks/hopsworks
+edit_uri: ""
+
+nav:
+ - Home: https://docs.hopsworks.ai/
+ - Getting Started ↗: https://docs.hopsworks.ai/
+ - Tutorials: https://docs.hopsworks.ai/
+ - Concepts: https://docs.hopsworks.ai/
+ - Guides: https://docs.hopsworks.ai/
+ - Setup and Installation: https://docs.hopsworks.ai/
+ - Administration: https://docs.hopsworks.ai/
+ - API:
+ - API Reference:
+ - Connection: generated/connection_api.md
+ - Model Registry:
+ - Model Registry: generated/model-registry/model_registry_api.md
+ - Model: generated/model-registry/model_api.md
+ - Model Schema: generated/model-registry/model_schema_api.md
+ - Model Serving:
+ - Model Serving: generated/model-serving/model_serving_api.md
+ - Deployment: generated/model-serving/deployment_api.md
+ - Deployment state: generated/model-serving/predictor_state_api.md
+ - Deployment state condition: generated/model-serving/predictor_state_condition_api.md
+ - Predictor: generated/model-serving/predictor_api.md
+ - Transformer: generated/model-serving/transformer_api.md
+ - Inference Logger: generated/model-serving/inference_logger_api.md
+ - Inference Batcher: generated/model-serving/inference_batcher_api.md
+ - Resources: generated/model-serving/resources_api.md
+ # Added to allow navigation using the side drawer
+ - Hopsworks API: https://docs.hopsworks.ai/
+ - Feature Store API: https://docs.hopsworks.ai/
+ - Feature Store JavaDoc: https://docs.hopsworks.ai/
+ - Contributing: CONTRIBUTING.md
+ - Community ↗: https://community.hopsworks.ai/
+
+theme:
+ name: material
+ custom_dir: docs/overrides
+ favicon: assets/images/favicon.ico
+ logo: assets/images/hops-logo.png
+ icon:
+ repo: fontawesome/brands/github
+ font:
+ text: "Roboto"
+ code: "IBM Plex Mono"
+ palette:
+ accent: teal
+ scheme: hopsworks
+ features:
+ - navigation.tabs
+ - navigation.tabs.sticky
+ - navigation.expand
+
+
+extra:
+ analytics:
+ provider: google
+ property: G-64FEEXPSDN
+ generator: false
+ version:
+ - provider: mike
+ - version: latest
+ social:
+ - icon: fontawesome/brands/twitter
+ link: https://twitter.com/hopsworks
+ - icon: fontawesome/brands/github
+ link: https://github.com/logicalclocks/hopsworks
+ - icon: fontawesome/brands/discourse
+ link: https://community.hopsworks.ai/
+ - icon: fontawesome/brands/linkedin
+ link: https://www.linkedin.com/company/hopsworks/
+
+extra_css:
+ - css/custom.css
+ - css/version-select.css
+ - css/dropdown.css
+ - css/marctech.css
+
+extra_javascript:
+ - js/version-select.js
+ - js/inject-api-links.js
+ - js/dropdown.js
+
+plugins:
+ - search
+ - minify:
+ minify_html: true
+ minify_css: true
+ minify_js: true
+ - mike:
+ canonical_version: latest
+
+markdown_extensions:
+ - admonition
+ - codehilite
+ - footnotes
+ - pymdownx.tabbed:
+ alternate_style: true
+ - pymdownx.arithmatex
+ - pymdownx.superfences
+ - pymdownx.details
+ - pymdownx.caret
+ - pymdownx.mark
+ - pymdownx.tilde
+ - pymdownx.critic
+ - attr_list
+ - md_in_html
+ - toc:
+ permalink: "#"
+ - pymdownx.tasklist:
+ custom_checkbox: true
+ - markdown_include.include:
+ base_path: docs
diff --git a/hsml/python/.pre-commit-config.yaml b/hsml/python/.pre-commit-config.yaml
new file mode 100644
index 000000000..645dcf677
--- /dev/null
+++ b/hsml/python/.pre-commit-config.yaml
@@ -0,0 +1,10 @@
+exclude: setup.py
+repos:
+ - repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.4.2
+ hooks:
+ # Run the linter
+ - id: ruff
+ args: [--fix]
+ # Run the formatter
+ - id: ruff-format
diff --git a/hsml/python/hsml/__init__.py b/hsml/python/hsml/__init__.py
new file mode 100644
index 000000000..4fb8156e3
--- /dev/null
+++ b/hsml/python/hsml/__init__.py
@@ -0,0 +1,35 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import warnings
+
+from hsml import util, version
+from hsml.connection import Connection
+
+
+connection = Connection.connection
+
+__version__ = version.__version__
+
+
+def ml_formatwarning(message, category, filename, lineno, line=None):
+ return "{}: {}\n".format(category.__name__, message)
+
+
+warnings.formatwarning = ml_formatwarning
+warnings.simplefilter("always", util.VersionWarning)
+
+__all__ = ["connection"]
diff --git a/hsml/python/hsml/client/__init__.py b/hsml/python/hsml/client/__init__.py
new file mode 100644
index 000000000..3982f0c56
--- /dev/null
+++ b/hsml/python/hsml/client/__init__.py
@@ -0,0 +1,152 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from hsml.client.hopsworks import base as hw_base
+from hsml.client.hopsworks import external as hw_external
+from hsml.client.hopsworks import internal as hw_internal
+from hsml.client.istio import base as ist_base
+from hsml.client.istio import external as ist_external
+from hsml.client.istio import internal as ist_internal
+from hsml.connection import CONNECTION_SAAS_HOSTNAME
+
+
+_client_type = None
+_saas_connection = None
+
+_hopsworks_client = None
+_istio_client = None
+
+_kserve_installed = None
+_serving_resource_limits = None
+_serving_num_instances_limits = None
+_knative_domain = None
+
+
+def init(
+ client_type,
+ host=None,
+ port=None,
+ project=None,
+ hostname_verification=None,
+ trust_store_path=None,
+ api_key_file=None,
+ api_key_value=None,
+):
+ global _client_type
+ _client_type = client_type
+
+ global _saas_connection
+ _saas_connection = host == CONNECTION_SAAS_HOSTNAME
+
+ global _hopsworks_client
+ if not _hopsworks_client:
+ if client_type == "internal":
+ _hopsworks_client = hw_internal.Client()
+ elif client_type == "external":
+ _hopsworks_client = hw_external.Client(
+ host,
+ port,
+ project,
+ hostname_verification,
+ trust_store_path,
+ api_key_file,
+ api_key_value,
+ )
+
+
+def get_instance() -> hw_base.Client:
+ global _hopsworks_client
+ if _hopsworks_client:
+ return _hopsworks_client
+ raise Exception("Couldn't find client. Try reconnecting to Hopsworks.")
+
+
+def set_istio_client(host, port, project=None, api_key_value=None):
+ global _client_type, _istio_client
+
+ if not _istio_client:
+ if _client_type == "internal":
+ _istio_client = ist_internal.Client(host, port)
+ elif _client_type == "external":
+ _istio_client = ist_external.Client(host, port, project, api_key_value)
+
+
+def get_istio_instance() -> ist_base.Client:
+ global _istio_client
+ return _istio_client
+
+
+def get_client_type() -> str:
+ global _client_type
+ return _client_type
+
+
+def is_saas_connection() -> bool:
+ global _saas_connection
+ return _saas_connection
+
+
+def set_kserve_installed(kserve_installed):
+ global _kserve_installed
+ _kserve_installed = kserve_installed
+
+
+def is_kserve_installed() -> bool:
+ global _kserve_installed
+ return _kserve_installed
+
+
+def set_serving_resource_limits(max_resources):
+ global _serving_resource_limits
+ _serving_resource_limits = max_resources
+
+
+def get_serving_resource_limits():
+ global _serving_resource_limits
+ return _serving_resource_limits
+
+
+def set_serving_num_instances_limits(num_instances_range):
+ global _serving_num_instances_limits
+ _serving_num_instances_limits = num_instances_range
+
+
+def get_serving_num_instances_limits():
+ global _serving_num_instances_limits
+ return _serving_num_instances_limits
+
+
+def is_scale_to_zero_required():
+ # scale-to-zero is required for KServe deployments if the Hopsworks variable `kube_serving_min_num_instances`
+ # is set to 0. Other possible values are -1 (unlimited num instances) or >1 num instances.
+ return get_serving_num_instances_limits()[0] == 0
+
+
+def get_knative_domain():
+ global _knative_domain
+ return _knative_domain
+
+
+def set_knative_domain(knative_domain):
+ global _knative_domain
+ _knative_domain = knative_domain
+
+
+def stop():
+ global _hopsworks_client, _istio_client
+ if _hopsworks_client is not None:
+ _hopsworks_client._close()
+ if _istio_client is not None:
+ _istio_client._close()
+ _hopsworks_client = _istio_client = None
diff --git a/hsml/python/hsml/client/auth.py b/hsml/python/hsml/client/auth.py
new file mode 100644
index 000000000..696aaad2e
--- /dev/null
+++ b/hsml/python/hsml/client/auth.py
@@ -0,0 +1,64 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+
+import requests
+from hsml.client import exceptions
+
+
+class BearerAuth(requests.auth.AuthBase):
+ """Class to encapsulate a Bearer token."""
+
+ def __init__(self, token):
+ self._token = token
+
+ def __call__(self, r):
+ r.headers["Authorization"] = "Bearer " + self._token
+ return r
+
+
+class ApiKeyAuth(requests.auth.AuthBase):
+ """Class to encapsulate an API key."""
+
+ def __init__(self, token):
+ self._token = token
+
+ def __call__(self, r):
+ r.headers["Authorization"] = "ApiKey " + self._token
+ return r
+
+
+def get_api_key(api_key_value, api_key_file):
+ if api_key_value is not None:
+ return api_key_value
+ elif api_key_file is not None:
+ if os.path.exists(api_key_file):
+ # use a context manager so the file handle is always released
+ with open(api_key_file, mode="r") as file:
+ return file.read()
+ else:
+ raise IOError(
+ "Could not find api key file on path: {}".format(api_key_file)
+ )
+ else:
+ raise exceptions.ExternalClientError(
+ "Either api_key_file or api_key_value must be set when connecting to"
+ " hopsworks from an external environment."
+ )
diff --git a/hsml/python/hsml/client/base.py b/hsml/python/hsml/client/base.py
new file mode 100644
index 000000000..d36e366c5
--- /dev/null
+++ b/hsml/python/hsml/client/base.py
@@ -0,0 +1,119 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABC, abstractmethod
+
+import furl
+import requests
+import urllib3
+from hsml.client import exceptions
+from hsml.decorators import connected
+
+
+urllib3.disable_warnings(urllib3.exceptions.SecurityWarning)
+urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
+
+
+class Client(ABC):
+ @abstractmethod
+ def __init__(self):
+ """To be implemented by clients."""
+ pass
+
+ @abstractmethod
+ def _get_verify(self, verify, trust_store_path):
+ """To be implemented by clients."""
+ pass
+
+ @abstractmethod
+ def _get_retry(self, request, response):
+ """To be implemented by clients."""
+ pass
+
+ @abstractmethod
+ def _get_host_port_pair(self):
+ """To be implemented by clients."""
+ pass
+
+ @connected
+ def _send_request(
+ self,
+ method,
+ path_params,
+ query_params=None,
+ headers=None,
+ data=None,
+ stream=False,
+ files=None,
+ ):
+ """Send REST request to a REST endpoint.
+
+ Uses the client it is executed from. Path parameters are url encoded automatically.
+
+ :param method: 'GET', 'PUT' or 'POST'
+ :type method: str
+ :param path_params: a list of path params to build the query url from starting after
+ the api resource, for example `["project", 119]`.
+ :type path_params: list
+ :param query_params: A dictionary of key/value pairs to be added as query parameters,
+ defaults to None
+ :type query_params: dict, optional
+ :param headers: Additional header information, defaults to None
+ :type headers: dict, optional
+ :param data: The payload as a python dictionary to be sent as json, defaults to None
+ :type data: dict, optional
+ :param stream: Set if response should be a stream, defaults to False
+ :type stream: boolean, optional
+ :param files: dictionary for multipart encoding upload
+ :type files: dict, optional
+ :raises RestAPIError: Raised when request wasn't correctly received, understood or accepted
+ :return: Response json
+ :rtype: dict
+ """
+ f_url = furl.furl(self._base_url)
+ f_url.path.segments = self.BASE_PATH_PARAMS + path_params
+ url = str(f_url)
+ request = requests.Request(
+ method,
+ url=url,
+ headers=headers,
+ data=data,
+ params=query_params,
+ auth=self._auth,
+ files=files,
+ )
+
+ prepped = self._session.prepare_request(request)
+ response = self._session.send(prepped, verify=self._verify, stream=stream)
+
+ if self._get_retry(request, response):
+ prepped = self._session.prepare_request(request)
+ response = self._session.send(prepped, verify=self._verify, stream=stream)
+
+ if response.status_code // 100 != 2:
+ raise exceptions.RestAPIError(url, response)
+
+ if stream:
+ return response
+ else:
+ # handle different success response codes
+ if len(response.content) == 0:
+ return None
+ return response.json()
+
+ def _close(self):
+ """Closes a client. Can be implemented for clean up purposes, not mandatory."""
+ self._connected = False
diff --git a/hsml/python/hsml/client/exceptions.py b/hsml/python/hsml/client/exceptions.py
new file mode 100644
index 000000000..6a59909db
--- /dev/null
+++ b/hsml/python/hsml/client/exceptions.py
@@ -0,0 +1,85 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+class RestAPIError(Exception):
+ """REST Exception encapsulating the response object and url."""
+
+ def __init__(self, url, response):
+ try:
+ error_object = response.json()
+ except Exception:
+ self.error_code = error_object = None
+
+ message = (
+ "Metadata operation error: (url: {}). Server response: \n"
+ "HTTP code: {}, HTTP reason: {}, body: {}".format(
+ url,
+ response.status_code,
+ response.reason,
+ response.content,
+ )
+ )
+
+ if error_object is not None:
+ self.error_code = error_object.get("errorCode", "")
+ message += ", error code: {}, error msg: {}, user msg: {}".format(
+ self.error_code,
+ error_object.get("errorMsg", ""),
+ error_object.get("usrMsg", ""),
+ )
+
+ super().__init__(message)
+ self.url = url
+ self.response = response
+
+ STATUS_CODE_BAD_REQUEST = 400
+ STATUS_CODE_UNAUTHORIZED = 401
+ STATUS_CODE_FORBIDDEN = 403
+ STATUS_CODE_NOT_FOUND = 404
+ STATUS_CODE_INTERNAL_SERVER_ERROR = 500
+
+
+class UnknownSecretStorageError(Exception):
+ """This exception will be raised if an unused secrets storage is passed as a parameter."""
+
+
+class ModelRegistryException(Exception):
+ """Generic model registry exception"""
+
+
+class ModelServingException(Exception):
+ """Generic model serving exception"""
+
+ ERROR_CODE_SERVING_NOT_FOUND = 240000
+ ERROR_CODE_ILLEGAL_ARGUMENT = 240001
+ ERROR_CODE_DUPLICATED_ENTRY = 240011
+
+ ERROR_CODE_DEPLOYMENT_NOT_RUNNING = 250001
+
+
+class InternalClientError(TypeError):
+ """Raised when internal client cannot be initialized due to missing arguments."""
+
+ def __init__(self, message):
+ super().__init__(message)
+
+
+class ExternalClientError(TypeError):
+ """Raised when external client cannot be initialized due to missing arguments."""
+
+ def __init__(self, message):
+ super().__init__(message)
diff --git a/hsml/python/hsml/client/hopsworks/__init__.py b/hsml/python/hsml/client/hopsworks/__init__.py
new file mode 100644
index 000000000..7fa8fd556
--- /dev/null
+++ b/hsml/python/hsml/client/hopsworks/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/hsml/python/hsml/client/hopsworks/base.py b/hsml/python/hsml/client/hopsworks/base.py
new file mode 100644
index 000000000..a0326b2d5
--- /dev/null
+++ b/hsml/python/hsml/client/hopsworks/base.py
@@ -0,0 +1,111 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+from abc import abstractmethod
+
+from hsml.client import auth, base
+
+
+class Client(base.Client):
+ TOKEN_FILE = "token.jwt"
+ APIKEY_FILE = "api.key"
+ REST_ENDPOINT = "REST_ENDPOINT"
+ HOPSWORKS_PUBLIC_HOST = "HOPSWORKS_PUBLIC_HOST"
+
+ BASE_PATH_PARAMS = ["hopsworks-api", "api"]
+
+ @abstractmethod
+ def __init__(self):
+ """To be extended by clients."""
+ pass
+
+ def _get_verify(self, verify, trust_store_path):
+ """Get verification method for sending HTTP requests to Hopsworks.
+
+ Credit to https://gist.github.com/gdamjan/55a8b9eec6cf7b771f92021d93b87b2c
+
+ :param verify: perform hostname verification, 'true' or 'false'
+ :type verify: str
+ :param trust_store_path: path of the truststore locally if it was uploaded manually to
+ the external environment
+ :type trust_store_path: str
+ :return: if verify is true and the truststore is provided, then return the trust store location
+ if verify is true but the truststore wasn't provided, then return true
+ if verify is false, then return false
+ :rtype: str or boolean
+ """
+ if verify == "true":
+ if trust_store_path is not None:
+ return trust_store_path
+ else:
+ return True
+
+ return False
+
+ def _get_retry(self, request, response):
+ """Get retry method for resending HTTP requests to Hopsworks
+
+ :param request: original HTTP request already sent
+ :type request: requests.Request
+ :param response: response of the original HTTP request
+ :type response: requests.Response
+ """
+ if response.status_code == 401 and self.REST_ENDPOINT in os.environ:
+ # refresh token and retry request - only on hopsworks
+ self._auth = auth.BearerAuth(self._read_jwt())
+ # Update request with the new token
+ request.auth = self._auth
+ # retry request
+ return True
+ return False
+
+ def _get_host_port_pair(self):
+ """
+ Removes "http or https" from the rest endpoint and returns a list
+ [endpoint, port], where endpoint is on the format /path.. without http://
+
+ :return: a list [endpoint, port]
+ :rtype: list
+ """
+ endpoint = self._base_url
+ if endpoint.startswith("http"):
+ last_index = endpoint.rfind("/")
+ endpoint = endpoint[last_index + 1 :]
+ host, port = endpoint.split(":")
+ return host, port
+
+ def _read_jwt(self):
+ """Retrieve jwt from local container."""
+ return self._read_file(self.TOKEN_FILE)
+
+ def _read_apikey(self):
+ """Retrieve apikey from local container."""
+ return self._read_file(self.APIKEY_FILE)
+
+ def _read_file(self, secret_file):
+ """Retrieve secret from local container."""
+ with open(os.path.join(self._secrets_dir, secret_file), "r") as secret:
+ return secret.read()
+
+ def _close(self):
+ """Closes a client. Can be implemented for clean up purposes, not mandatory."""
+ self._connected = False
+
+ def _replace_public_host(self, url):
+ """replace hostname to public hostname set in HOPSWORKS_PUBLIC_HOST"""
+ ui_url = url._replace(netloc=os.environ[self.HOPSWORKS_PUBLIC_HOST])
+ return ui_url
diff --git a/hsml/python/hsml/client/hopsworks/external.py b/hsml/python/hsml/client/hopsworks/external.py
new file mode 100644
index 000000000..6da14a4d3
--- /dev/null
+++ b/hsml/python/hsml/client/hopsworks/external.py
@@ -0,0 +1,85 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import requests
+from hsml.client import auth, exceptions
+from hsml.client.hopsworks import base as hopsworks
+
+
+class Client(hopsworks.Client):
+ def __init__(
+ self,
+ host,
+ port,
+ project,
+ hostname_verification,
+ trust_store_path,
+ api_key_file,
+ api_key_value,
+ ):
+ """Initializes a client in an external environment."""
+ if not host:
+ raise exceptions.ExternalClientError(
+ "host cannot be of type NoneType, host is a non-optional "
+ "argument to connect to hopsworks from an external environment."
+ )
+ if not project:
+ raise exceptions.ExternalClientError(
+ "project cannot be of type NoneType, project is a non-optional "
+ "argument to connect to hopsworks from an external environment."
+ )
+
+ self._host = host
+ self._port = port
+ self._base_url = "https://" + self._host + ":" + str(self._port)
+ self._project_name = project
+
+ api_key = auth.get_api_key(api_key_value, api_key_file)
+ self._auth = auth.ApiKeyAuth(api_key)
+
+ self._session = requests.session()
+ self._connected = True
+ self._verify = self._get_verify(self._host, trust_store_path)
+
+ if self._project_name is not None:
+ project_info = self._get_project_info(self._project_name)
+ self._project_id = str(project_info["projectId"])
+ else:
+ self._project_id = None
+
+ self._cert_key = None
+
+ def _close(self):
+ """Closes a client."""
+ self._connected = False
+
+ def _get_project_info(self, project_name):
+ """Makes a REST call to hopsworks to get all metadata of a project for the provided project.
+
+ :param project_name: the name of the project
+ :type project_name: str
+ :return: JSON response with project info
+ :rtype: dict
+ """
+ return self._send_request("GET", ["project", "getProjectInfo", project_name])
+
+ def _replace_public_host(self, url):
+ """no need to replace as we are already in external client"""
+ return url
+
+ @property
+ def host(self):
+ return self._host
diff --git a/hsml/python/hsml/client/hopsworks/internal.py b/hsml/python/hsml/client/hopsworks/internal.py
new file mode 100644
index 000000000..760251540
--- /dev/null
+++ b/hsml/python/hsml/client/hopsworks/internal.py
@@ -0,0 +1,208 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import base64
+import os
+import textwrap
+from pathlib import Path
+
+import requests
+from hsml.client import auth
+from hsml.client.hopsworks import base as hopsworks
+
+
+try:
+ import jks
+except ImportError:
+ pass
+
+
+class Client(hopsworks.Client):
+ REQUESTS_VERIFY = "REQUESTS_VERIFY"
+ DOMAIN_CA_TRUSTSTORE_PEM = "DOMAIN_CA_TRUSTSTORE_PEM"
+ PROJECT_ID = "HOPSWORKS_PROJECT_ID"
+ PROJECT_NAME = "HOPSWORKS_PROJECT_NAME"
+ HADOOP_USER_NAME = "HADOOP_USER_NAME"
+ MATERIAL_DIRECTORY = "MATERIAL_DIRECTORY"
+ HDFS_USER = "HDFS_USER"
+ T_CERTIFICATE = "t_certificate"
+ K_CERTIFICATE = "k_certificate"
+ TRUSTSTORE_SUFFIX = "__tstore.jks"
+ KEYSTORE_SUFFIX = "__kstore.jks"
+ PEM_CA_CHAIN = "ca_chain.pem"
+ CERT_KEY_SUFFIX = "__cert.key"
+ MATERIAL_PWD = "material_passwd"
+ SECRETS_DIR = "SECRETS_DIR"
+
+ def __init__(self):
+ """Initializes a client being run from a job/notebook directly on Hopsworks."""
+ self._base_url = self._get_hopsworks_rest_endpoint()
+ self._host, self._port = self._get_host_port_pair()
+ self._secrets_dir = (
+ os.environ[self.SECRETS_DIR] if self.SECRETS_DIR in os.environ else ""
+ )
+ self._cert_key = self._get_cert_pw()
+ trust_store_path = self._get_trust_store_path()
+ hostname_verification = (
+ os.environ[self.REQUESTS_VERIFY]
+ if self.REQUESTS_VERIFY in os.environ
+ else "true"
+ )
+ self._project_id = os.environ[self.PROJECT_ID]
+ self._project_name = self._project_name()
+ try:
+ self._auth = auth.BearerAuth(self._read_jwt())
+ except FileNotFoundError:
+ self._auth = auth.ApiKeyAuth(self._read_apikey())
+ self._verify = self._get_verify(hostname_verification, trust_store_path)
+ self._session = requests.session()
+
+ self._connected = True
+
+ def _get_hopsworks_rest_endpoint(self):
+ """Get the hopsworks REST endpoint for making requests to the REST API."""
+ return os.environ[self.REST_ENDPOINT]
+
+ def _get_trust_store_path(self):
+ """Convert truststore from jks to pem and return the location"""
+ ca_chain_path = Path(self.PEM_CA_CHAIN)
+ if not ca_chain_path.exists():
+ self._write_ca_chain(ca_chain_path)
+ return str(ca_chain_path)
+
+ def _write_ca_chain(self, ca_chain_path):
+ """
+ Converts JKS trustore file into PEM to be compatible with Python libraries
+ """
+ keystore_pw = self._cert_key
+ keystore_ca_cert = self._convert_jks_to_pem(
+ self._get_jks_key_store_path(), keystore_pw
+ )
+ truststore_ca_cert = self._convert_jks_to_pem(
+ self._get_jks_trust_store_path(), keystore_pw
+ )
+
+ with ca_chain_path.open("w") as f:
+ f.write(keystore_ca_cert + truststore_ca_cert)
+
+ def _convert_jks_to_pem(self, jks_path, keystore_pw):
+ """
+ Converts a keystore JKS that contains client private key,
+ client certificate and CA certificate that was used to
+ sign the certificate to PEM format and returns the CA certificate.
+ Args:
+ :jks_path: path to the JKS file
+ :pw: password for decrypting the JKS file
+ Returns:
+ strings: (ca_cert)
+ """
+ # load the keystore and decrypt it with password
+ ks = jks.KeyStore.load(jks_path, keystore_pw, try_decrypt_keys=True)
+ ca_certs = ""
+
+ # Convert CA Certificates into PEM format and append to string
+ for _alias, c in ks.certs.items():
+ ca_certs = ca_certs + self._bytes_to_pem_str(c.cert, "CERTIFICATE")
+ return ca_certs
+
+ def _bytes_to_pem_str(self, der_bytes, pem_type):
+ """
+ Utility function for creating PEM files
+
+ Args:
+ der_bytes: DER encoded bytes
+ pem_type: type of PEM, e.g Certificate, Private key, or RSA private key
+
+ Returns:
+ PEM String for a DER-encoded certificate or private key
+ """
+ pem_str = ""
+ pem_str = pem_str + "-----BEGIN {}-----".format(pem_type) + "\n"
+ pem_str = (
+ pem_str
+ + "\r\n".join(
+ textwrap.wrap(base64.b64encode(der_bytes).decode("ascii"), 64)
+ )
+ + "\n"
+ )
+ pem_str = pem_str + "-----END {}-----".format(pem_type) + "\n"
+ return pem_str
+
+ def _get_jks_trust_store_path(self):
+ """
+ Get truststore location
+
+ Returns:
+ truststore location
+ """
+ t_certificate = Path(self.T_CERTIFICATE)
+ if t_certificate.exists():
+ return str(t_certificate)
+ else:
+ username = os.environ[self.HADOOP_USER_NAME]
+ material_directory = Path(os.environ[self.MATERIAL_DIRECTORY])
+ return str(material_directory.joinpath(username + self.TRUSTSTORE_SUFFIX))
+
+ def _get_jks_key_store_path(self):
+ """
+ Get keystore location
+
+ Returns:
+ keystore location
+ """
+ k_certificate = Path(self.K_CERTIFICATE)
+ if k_certificate.exists():
+ return str(k_certificate)
+ else:
+ username = os.environ[self.HADOOP_USER_NAME]
+ material_directory = Path(os.environ[self.MATERIAL_DIRECTORY])
+ return str(material_directory.joinpath(username + self.KEYSTORE_SUFFIX))
+
+ def _project_name(self):
+ try:
+ return os.environ[self.PROJECT_NAME]
+ except KeyError:
+ pass
+
+ hops_user = self._project_user()
+ hops_user_split = hops_user.split(
+ "__"
+ ) # project users have username project__user
+ project = hops_user_split[0]
+ return project
+
+ def _project_user(self):
+ try:
+ hops_user = os.environ[self.HADOOP_USER_NAME]
+ except KeyError:
+ hops_user = os.environ[self.HDFS_USER]
+ return hops_user
+
+ def _get_cert_pw(self):
+ """
+ Get keystore password from local container
+
+ Returns:
+ Certificate password
+ """
+ pwd_path = Path(self.MATERIAL_PWD)
+ if not pwd_path.exists():
+ username = os.environ[self.HADOOP_USER_NAME]
+ material_directory = Path(os.environ[self.MATERIAL_DIRECTORY])
+ pwd_path = material_directory.joinpath(username + self.CERT_KEY_SUFFIX)
+
+ with pwd_path.open() as f:
+ return f.read()
diff --git a/hsml/python/hsml/client/istio/__init__.py b/hsml/python/hsml/client/istio/__init__.py
new file mode 100644
index 000000000..7fa8fd556
--- /dev/null
+++ b/hsml/python/hsml/client/istio/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/hsml/python/hsml/client/istio/base.py b/hsml/python/hsml/client/istio/base.py
new file mode 100644
index 000000000..9aaab9ba0
--- /dev/null
+++ b/hsml/python/hsml/client/istio/base.py
@@ -0,0 +1,97 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+from abc import abstractmethod
+
+from hsml.client import base
+from hsml.client.istio.grpc.inference_client import GRPCInferenceServerClient
+
+
+class Client(base.Client):
+ SERVING_API_KEY = "SERVING_API_KEY"
+ HOPSWORKS_PUBLIC_HOST = "HOPSWORKS_PUBLIC_HOST"
+
+ BASE_PATH_PARAMS = []
+
+ @abstractmethod
+ def __init__(self):
+ """To be implemented by clients."""
+ pass
+
+ def _get_verify(self, verify, trust_store_path):
+ """Get verification method for sending inference requests to Istio.
+
+ Credit to https://gist.github.com/gdamjan/55a8b9eec6cf7b771f92021d93b87b2c
+
+ :param verify: perform hostname verification, 'true' or 'false'
+ :type verify: str
+ :param trust_store_path: path of the truststore locally if it was uploaded manually to
+ the external environment such as EKS or AKS
+ :type trust_store_path: str
+ :return: if verify is true and the truststore is provided, then return the trust store location
+ if verify is true but the truststore wasn't provided, then return true
+ if verify is false, then return false
+ :rtype: str or boolean
+ """
+ if verify == "true":
+ if trust_store_path is not None:
+ return trust_store_path
+ else:
+ return True
+
+ return False
+
+ def _get_retry(self, request, response):
+ """Get retry method for resending HTTP requests to Istio
+
+ :param request: original HTTP request already sent
+ :type request: requests.Request
+ :param response: response of the original HTTP request
+ :type response: requests.Response
+ """
+ return False
+
+ def _get_host_port_pair(self):
+ """
+ Removes "http or https" from the rest endpoint and returns a list
+ [endpoint, port], where endpoint is on the format /path.. without http://
+
+ :return: a list [endpoint, port]
+ :rtype: list
+ """
+ endpoint = self._base_url
+ if endpoint.startswith("http"):
+ last_index = endpoint.rfind("/")
+ endpoint = endpoint[last_index + 1 :]
+ host, port = endpoint.split(":")
+ return host, port
+
+ def _close(self):
+ """Closes a client. Can be implemented for clean up purposes, not mandatory."""
+ self._connected = False
+
+ def _replace_public_host(self, url):
+ """replace hostname to public hostname set in HOPSWORKS_PUBLIC_HOST"""
+ ui_url = url._replace(netloc=os.environ[self.HOPSWORKS_PUBLIC_HOST])
+ return ui_url
+
+ def _create_grpc_channel(self, service_hostname: str) -> GRPCInferenceServerClient:
+ return GRPCInferenceServerClient(
+ url=self._host + ":" + str(self._port),
+ channel_args=(("grpc.ssl_target_name_override", service_hostname),),
+ serving_api_key=self._auth._token,
+ )
diff --git a/hsml/python/hsml/client/istio/external.py b/hsml/python/hsml/client/istio/external.py
new file mode 100644
index 000000000..c4fd89787
--- /dev/null
+++ b/hsml/python/hsml/client/istio/external.py
@@ -0,0 +1,56 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import requests
+from hsml.client import auth
+from hsml.client.istio import base as istio
+
+
+class Client(istio.Client):
+ def __init__(
+ self,
+ host,
+ port,
+ project,
+ api_key_value,
+ hostname_verification=None,
+ trust_store_path=None,
+ ):
+ """Initializes a client in an external environment such as AWS Sagemaker."""
+ self._host = host
+ self._port = port
+ self._base_url = "http://" + self._host + ":" + str(self._port)
+ self._project_name = project
+
+ self._auth = auth.ApiKeyAuth(api_key_value)
+
+ self._session = requests.session()
+ self._connected = True
+ self._verify = self._get_verify(hostname_verification, trust_store_path)
+
+ self._cert_key = None
+
+ def _close(self):
+ """Closes a client."""
+ self._connected = False
+
+ def _replace_public_host(self, url):
+ """no need to replace as we are already in external client"""
+ return url
+
+ @property
+ def host(self):
+ return self._host
diff --git a/hsml/python/hsml/client/istio/grpc/__init__.py b/hsml/python/hsml/client/istio/grpc/__init__.py
new file mode 100644
index 000000000..ff8055b9b
--- /dev/null
+++ b/hsml/python/hsml/client/istio/grpc/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/hsml/python/hsml/client/istio/grpc/errors.py b/hsml/python/hsml/client/istio/grpc/errors.py
new file mode 100644
index 000000000..062630bea
--- /dev/null
+++ b/hsml/python/hsml/client/istio/grpc/errors.py
@@ -0,0 +1,30 @@
+# Copyright 2022 The KServe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# This implementation has been borrowed from the kserve/kserve repository
+# https://github.com/kserve/kserve/blob/release-0.11/python/kserve/kserve/errors.py
+
+
+class InvalidInput(ValueError):
+ """
+ Exception class indicating invalid input arguments.
+ HTTP Servers should return HTTP_400 (Bad Request).
+ """
+
+ def __init__(self, reason):
+ self.reason = reason
+
+ def __str__(self):
+ return self.reason
diff --git a/hsml/python/hsml/client/istio/grpc/exceptions.py b/hsml/python/hsml/client/istio/grpc/exceptions.py
new file mode 100644
index 000000000..6477c9488
--- /dev/null
+++ b/hsml/python/hsml/client/istio/grpc/exceptions.py
@@ -0,0 +1,123 @@
+# Copyright 2023 The KServe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# coding: utf-8
+
+# This implementation has been borrowed from kserve/kserve repository
+# https://github.com/kserve/kserve/blob/release-0.11/python/kserve/kserve/exceptions.py
+
+import six
+
+
+class OpenApiException(Exception):
+ """The base exception class for all OpenAPIExceptions"""
+
+
+class ApiTypeError(OpenApiException, TypeError):
+ def __init__(self, msg, path_to_item=None, valid_classes=None, key_type=None):
+ """Raises an exception for TypeErrors
+
+ Args:
+ msg (str): the exception message
+
+ Keyword Args:
+ path_to_item (list): a list of keys an indices to get to the
+ current_item
+ None if unset
+ valid_classes (tuple): the primitive classes that current item
+ should be an instance of
+ None if unset
+ key_type (bool): False if our value is a value in a dict
+ True if it is a key in a dict
+ False if our item is an item in a list
+ None if unset
+ """
+ self.path_to_item = path_to_item
+ self.valid_classes = valid_classes
+ self.key_type = key_type
+ full_msg = msg
+ if path_to_item:
+ full_msg = "{0} at {1}".format(msg, render_path(path_to_item))
+ super(ApiTypeError, self).__init__(full_msg)
+
+
+class ApiValueError(OpenApiException, ValueError):
+ def __init__(self, msg, path_to_item=None):
+ """
+ Args:
+ msg (str): the exception message
+
+ Keyword Args:
+ path_to_item (list) the path to the exception in the
+ received_data dict. None if unset
+ """
+
+ self.path_to_item = path_to_item
+ full_msg = msg
+ if path_to_item:
+ full_msg = "{0} at {1}".format(msg, render_path(path_to_item))
+ super(ApiValueError, self).__init__(full_msg)
+
+
+class ApiKeyError(OpenApiException, KeyError):
+ def __init__(self, msg, path_to_item=None):
+ """
+ Args:
+ msg (str): the exception message
+
+ Keyword Args:
+ path_to_item (None/list) the path to the exception in the
+ received_data dict
+ """
+ self.path_to_item = path_to_item
+ full_msg = msg
+ if path_to_item:
+ full_msg = "{0} at {1}".format(msg, render_path(path_to_item))
+ super(ApiKeyError, self).__init__(full_msg)
+
+
+class ApiException(OpenApiException):
+ def __init__(self, status=None, reason=None, http_resp=None):
+ if http_resp:
+ self.status = http_resp.status
+ self.reason = http_resp.reason
+ self.body = http_resp.data
+ self.headers = http_resp.getheaders()
+ else:
+ self.status = status
+ self.reason = reason
+ self.body = None
+ self.headers = None
+
+ def __str__(self):
+ """Custom error messages for exception"""
+ error_message = "({0})\n" "Reason: {1}\n".format(self.status, self.reason)
+ if self.headers:
+ error_message += "HTTP response headers: {0}\n".format(self.headers)
+
+ if self.body:
+ error_message += "HTTP response body: {0}\n".format(self.body)
+
+ return error_message
+
+
+def render_path(path_to_item):
+ """Returns a string representation of a path"""
+ result = ""
+ for pth in path_to_item:
+ if isinstance(pth, six.integer_types):
+ result += "[{0}]".format(pth)
+ else:
+ result += "['{0}']".format(pth)
+ return result
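+
+
+# Illustrative note (not part of the upstream KServe module): render_path turns a
+# path such as ["inputs", 0, "name"] into "['inputs'][0]['name']", so, for example,
+# ApiTypeError("Invalid type", path_to_item=["inputs", 0, "name"]) would carry the
+# message "Invalid type at ['inputs'][0]['name']".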
diff --git a/hsml/python/hsml/client/istio/grpc/inference_client.py b/hsml/python/hsml/client/istio/grpc/inference_client.py
new file mode 100644
index 000000000..3cc3164c5
--- /dev/null
+++ b/hsml/python/hsml/client/istio/grpc/inference_client.py
@@ -0,0 +1,74 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import grpc
+from hsml.client.istio.grpc.proto.grpc_predict_v2_pb2_grpc import (
+ GRPCInferenceServiceStub,
+)
+from hsml.client.istio.utils.infer_type import InferRequest, InferResponse
+
+
+class GRPCInferenceServerClient:
+ def __init__(
+ self,
+ url,
+ serving_api_key,
+ channel_args=None,
+ ):
+ if channel_args is not None:
+ channel_opt = channel_args
+ else:
+ channel_opt = [
+ ("grpc.max_send_message_length", -1),
+ ("grpc.max_receive_message_length", -1),
+ ]
+
+ # Authentication is done via API Key in the Authorization header
+ self._channel = grpc.insecure_channel(url, options=channel_opt)
+ self._client_stub = GRPCInferenceServiceStub(self._channel)
+ self._serving_api_key = serving_api_key
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.close()
+
+ def __del__(self):
+ """It is called during object garbage collection."""
+ self.close()
+
+ def close(self):
+ """Close the client. Future calls to server will result in an Error."""
+ self._channel.close()
+
+ def infer(self, infer_request: InferRequest, headers=None, client_timeout=None):
+ headers = {} if headers is None else headers
+ headers["authorization"] = "ApiKey " + self._serving_api_key
+ metadata = headers.items()
+
+ # convert the InferRequest to a ModelInferRequest message
+ request = infer_request.to_grpc()
+
+ try:
+ # send request
+ model_infer_response = self._client_stub.ModelInfer(
+ request=request, metadata=metadata, timeout=client_timeout
+ )
+ except grpc.RpcError as rpc_error:
+ raise rpc_error
+
+ # convert back the ModelInferResponse message to InferResponse
+ return InferResponse.from_grpc(model_infer_response)
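+
+
+# Illustrative usage sketch (the endpoint and API key below are placeholders, not
+# values defined in this repository; building the InferRequest is elided):
+#
+#   request = InferRequest(...)  # constructed elsewhere from the model inputs
+#   with GRPCInferenceServerClient(url="host:port", serving_api_key="...") as client:
+#       response = client.infer(request, client_timeout=10)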
diff --git a/hsml/python/hsml/client/istio/grpc/proto/__init__.py b/hsml/python/hsml/client/istio/grpc/proto/__init__.py
new file mode 100644
index 000000000..ff8055b9b
--- /dev/null
+++ b/hsml/python/hsml/client/istio/grpc/proto/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/hsml/python/hsml/client/istio/grpc/proto/grpc_predict_v2.proto b/hsml/python/hsml/client/istio/grpc/proto/grpc_predict_v2.proto
new file mode 100644
index 000000000..c05221d73
--- /dev/null
+++ b/hsml/python/hsml/client/istio/grpc/proto/grpc_predict_v2.proto
@@ -0,0 +1,362 @@
+// Copyright 2022 The KServe Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+syntax = "proto3";
+package inference;
+
+// Inference Server GRPC endpoints.
+service GRPCInferenceService
+{
+ // The ServerLive API indicates if the inference server is able to receive
+ // and respond to metadata and inference requests.
+ rpc ServerLive(ServerLiveRequest) returns (ServerLiveResponse) {}
+
+ // The ServerReady API indicates if the server is ready for inferencing.
+ rpc ServerReady(ServerReadyRequest) returns (ServerReadyResponse) {}
+
+ // The ModelReady API indicates if a specific model is ready for inferencing.
+ rpc ModelReady(ModelReadyRequest) returns (ModelReadyResponse) {}
+
+ // The ServerMetadata API provides information about the server. Errors are
+ // indicated by the google.rpc.Status returned for the request. The OK code
+ // indicates success and other codes indicate failure.
+ rpc ServerMetadata(ServerMetadataRequest) returns (ServerMetadataResponse) {}
+
+ // The per-model metadata API provides information about a model. Errors are
+ // indicated by the google.rpc.Status returned for the request. The OK code
+ // indicates success and other codes indicate failure.
+ rpc ModelMetadata(ModelMetadataRequest) returns (ModelMetadataResponse) {}
+
+ // The ModelInfer API performs inference using the specified model. Errors are
+ // indicated by the google.rpc.Status returned for the request. The OK code
+ // indicates success and other codes indicate failure.
+ rpc ModelInfer(ModelInferRequest) returns (ModelInferResponse) {}
+
+ // Load or reload a model from a repository.
+ rpc RepositoryModelLoad(RepositoryModelLoadRequest) returns (RepositoryModelLoadResponse) {}
+
+ // Unload a model.
+ rpc RepositoryModelUnload(RepositoryModelUnloadRequest) returns (RepositoryModelUnloadResponse) {}
+}
+
+message ServerLiveRequest {}
+
+message ServerLiveResponse
+{
+ // True if the inference server is live, false if not live.
+ bool live = 1;
+}
+
+message ServerReadyRequest {}
+
+message ServerReadyResponse
+{
+ // True if the inference server is ready, false if not ready.
+ bool ready = 1;
+}
+
+message ModelReadyRequest
+{
+ // The name of the model to check for readiness.
+ string name = 1;
+
+ // The version of the model to check for readiness. If not given the
+ // server will choose a version based on the model and internal policy.
+ string version = 2;
+}
+
+message ModelReadyResponse
+{
+ // True if the model is ready, false if not ready.
+ bool ready = 1;
+}
+
+message ServerMetadataRequest {}
+
+message ServerMetadataResponse
+{
+ // The server name.
+ string name = 1;
+
+ // The server version.
+ string version = 2;
+
+ // The extensions supported by the server.
+ repeated string extensions = 3;
+}
+
+message ModelMetadataRequest
+{
+ // The name of the model.
+ string name = 1;
+
+ // The version of the model to check for readiness. If not given the
+ // server will choose a version based on the model and internal policy.
+ string version = 2;
+}
+
+message ModelMetadataResponse
+{
+ // Metadata for a tensor.
+ message TensorMetadata
+ {
+ // The tensor name.
+ string name = 1;
+
+ // The tensor data type.
+ string datatype = 2;
+
+ // The tensor shape. A variable-size dimension is represented
+ // by a -1 value.
+ repeated int64 shape = 3;
+ }
+
+ // The model name.
+ string name = 1;
+
+ // The versions of the model available on the server.
+ repeated string versions = 2;
+
+ // The model's platform. See Platforms.
+ string platform = 3;
+
+ // The model's inputs.
+ repeated TensorMetadata inputs = 4;
+
+ // The model's outputs.
+ repeated TensorMetadata outputs = 5;
+}
+
+message ModelInferRequest
+{
+ // An input tensor for an inference request.
+ message InferInputTensor
+ {
+ // The tensor name.
+ string name = 1;
+
+ // The tensor data type.
+ string datatype = 2;
+
+ // The tensor shape.
+ repeated int64 shape = 3;
+
+ // Optional inference input tensor parameters.
+ map<string, InferParameter> parameters = 4;
+
+ // The tensor contents using a data-type format. This field must
+ // not be specified if "raw" tensor contents are being used for
+ // the inference request.
+ InferTensorContents contents = 5;
+ }
+
+ // An output tensor requested for an inference request.
+ message InferRequestedOutputTensor
+ {
+ // The tensor name.
+ string name = 1;
+
+ // Optional requested output tensor parameters.
+ map<string, InferParameter> parameters = 2;
+ }
+
+ // The name of the model to use for inferencing.
+ string model_name = 1;
+
+ // The version of the model to use for inference. If not given the
+ // server will choose a version based on the model and internal policy.
+ string model_version = 2;
+
+ // Optional identifier for the request. If specified will be
+ // returned in the response.
+ string id = 3;
+
+ // Optional inference parameters.
+ map<string, InferParameter> parameters = 4;
+
+ // The input tensors for the inference.
+ repeated InferInputTensor inputs = 5;
+
+ // The requested output tensors for the inference. Optional, if not
+ // specified all outputs produced by the model will be returned.
+ repeated InferRequestedOutputTensor outputs = 6;
+
+ // The data contained in an input tensor can be represented in "raw"
+ // bytes form or in the repeated type that matches the tensor's data
+ // type. To use the raw representation 'raw_input_contents' must be
+ // initialized with data for each tensor in the same order as
+ // 'inputs'. For each tensor, the size of this content must match
+ // what is expected by the tensor's shape and data type. The raw
+ // data must be the flattened, one-dimensional, row-major order of
+ // the tensor elements without any stride or padding between the
+ // elements. Note that the FP16 and BF16 data types must be represented as
+ // raw content as there is no specific data type for a 16-bit float type.
+ //
+ // If this field is specified then InferInputTensor::contents must
+ // not be specified for any input tensor.
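+ //
+ // For example, under this layout an FP32 tensor with shape [2, 3] would contribute
+ // a single 24-byte entry (6 elements x 4 bytes) to 'raw_input_contents'.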
+ repeated bytes raw_input_contents = 7;
+}
+
+message ModelInferResponse
+{
+ // An output tensor returned for an inference request.
+ message InferOutputTensor
+ {
+ // The tensor name.
+ string name = 1;
+
+ // The tensor data type.
+ string datatype = 2;
+
+ // The tensor shape.
+ repeated int64 shape = 3;
+
+ // Optional output tensor parameters.
+ map<string, InferParameter> parameters = 4;
+
+ // The tensor contents using a data-type format. This field must
+ // not be specified if "raw" tensor contents are being used for
+ // the inference response.
+ InferTensorContents contents = 5;
+ }
+
+ // The name of the model used for inference.
+ string model_name = 1;
+
+ // The version of the model used for inference.
+ string model_version = 2;
+
+ // The id of the inference request if one was specified.
+ string id = 3;
+
+ // Optional inference response parameters.
+ map<string, InferParameter> parameters = 4;
+
+ // The output tensors holding inference results.
+ repeated InferOutputTensor outputs = 5;
+
+ // The data contained in an output tensor can be represented in
+ // "raw" bytes form or in the repeated type that matches the
+ // tensor's data type. To use the raw representation 'raw_output_contents'
+ // must be initialized with data for each tensor in the same order as
+ // 'outputs'. For each tensor, the size of this content must match
+ // what is expected by the tensor's shape and data type. The raw
+ // data must be the flattened, one-dimensional, row-major order of
+ // the tensor elements without any stride or padding between the
+ // elements. Note that the FP16 and BF16 data types must be represented as
+ // raw content as there is no specific data type for a 16-bit float type.
+ //
+ // If this field is specified then InferOutputTensor::contents must
+ // not be specified for any output tensor.
+ repeated bytes raw_output_contents = 6;
+}
+
+// An inference parameter value. The Parameters message describes a
+// “name”/”value” pair, where the “name” is the name of the parameter
+// and the “value” is a boolean, integer, or string corresponding to
+// the parameter.
+message InferParameter
+{
+ // The parameter value can be a string, an int64, a boolean
+ // or a message specific to a predefined parameter.
+ oneof parameter_choice
+ {
+ // A boolean parameter value.
+ bool bool_param = 1;
+
+ // An int64 parameter value.
+ int64 int64_param = 2;
+
+ // A string parameter value.
+ string string_param = 3;
+ }
+}
+
+// The data contained in a tensor represented by the repeated type
+// that matches the tensor's data type. Protobuf oneof is not used
+// because oneofs cannot contain repeated fields.
+message InferTensorContents
+{
+ // Representation for BOOL data type. The size must match what is
+ // expected by the tensor's shape. The contents must be the flattened,
+ // one-dimensional, row-major order of the tensor elements.
+ repeated bool bool_contents = 1;
+
+ // Representation for INT8, INT16, and INT32 data types. The size
+ // must match what is expected by the tensor's shape. The contents
+ // must be the flattened, one-dimensional, row-major order of the
+ // tensor elements.
+ repeated int32 int_contents = 2;
+
+ // Representation for INT64 data types. The size must match what
+ // is expected by the tensor's shape. The contents must be the
+ // flattened, one-dimensional, row-major order of the tensor elements.
+ repeated int64 int64_contents = 3;
+
+ // Representation for UINT8, UINT16, and UINT32 data types. The size
+ // must match what is expected by the tensor's shape. The contents
+ // must be the flattened, one-dimensional, row-major order of the
+ // tensor elements.
+ repeated uint32 uint_contents = 4;
+
+ // Representation for UINT64 data types. The size must match what
+ // is expected by the tensor's shape. The contents must be the
+ // flattened, one-dimensional, row-major order of the tensor elements.
+ repeated uint64 uint64_contents = 5;
+
+ // Representation for FP32 data type. The size must match what is
+ // expected by the tensor's shape. The contents must be the flattened,
+ // one-dimensional, row-major order of the tensor elements.
+ repeated float fp32_contents = 6;
+
+ // Representation for FP64 data type. The size must match what is
+ // expected by the tensor's shape. The contents must be the flattened,
+ // one-dimensional, row-major order of the tensor elements.
+ repeated double fp64_contents = 7;
+
+ // Representation for BYTES data type. The size must match what is
+ // expected by the tensor's shape. The contents must be the flattened,
+ // one-dimensional, row-major order of the tensor elements.
+ repeated bytes bytes_contents = 8;
+}
+
+message RepositoryModelLoadRequest
+{
+ // The name of the model to load, or reload.
+ string model_name = 1;
+}
+
+message RepositoryModelLoadResponse
+{
+ // The name of the model trying to load or reload.
+ string model_name = 1;
+
+ // Boolean parameter to indicate whether the model is loaded or not.
+ bool isLoaded = 2;
+}
+
+message RepositoryModelUnloadRequest
+{
+ // The name of the model to unload.
+ string model_name = 1;
+}
+
+message RepositoryModelUnloadResponse
+{
+ // The name of the model trying to unload.
+ string model_name = 1;
+
+ // Boolean parameter to indicate whether the model is unloaded or not.
+ bool isUnloaded = 2;
+}
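+
+// Note (not part of the upstream KServe file): the accompanying *_pb2.py and
+// *_pb2_grpc.py modules are generated from this definition; with grpcio-tools
+// installed they can typically be regenerated with a command along the lines of:
+//
+//   python -m grpc_tools.protoc -I. --python_out=. --grpc_python_out=. grpc_predict_v2.proto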
diff --git a/hsml/python/hsml/client/istio/grpc/proto/grpc_predict_v2_pb2.py b/hsml/python/hsml/client/istio/grpc/proto/grpc_predict_v2_pb2.py
new file mode 100644
index 000000000..07af5f1c5
--- /dev/null
+++ b/hsml/python/hsml/client/istio/grpc/proto/grpc_predict_v2_pb2.py
@@ -0,0 +1,451 @@
+# Copyright 2022 The KServe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# -*- coding: utf-8 -*-
+# Generated by the protocol buffer compiler. DO NOT EDIT!
+# source: grpc_predict_v2.proto
+"""Generated protocol buffer code."""
+
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import descriptor_pool as _descriptor_pool
+from google.protobuf import message as _message
+from google.protobuf import reflection as _reflection
+from google.protobuf import symbol_database as _symbol_database
+
+
+# @@protoc_insertion_point(imports)
+
+_sym_db = _symbol_database.Default()
+
+
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
+ b'\n\x15grpc_predict_v2.proto\x12\tinference"\x13\n\x11ServerLiveRequest""\n\x12ServerLiveResponse\x12\x0c\n\x04live\x18\x01 \x01(\x08"\x14\n\x12ServerReadyRequest"$\n\x13ServerReadyResponse\x12\r\n\x05ready\x18\x01 \x01(\x08"2\n\x11ModelReadyRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t"#\n\x12ModelReadyResponse\x12\r\n\x05ready\x18\x01 \x01(\x08"\x17\n\x15ServerMetadataRequest"K\n\x16ServerMetadataResponse\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t\x12\x12\n\nextensions\x18\x03 \x03(\t"5\n\x14ModelMetadataRequest\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0f\n\x07version\x18\x02 \x01(\t"\x8d\x02\n\x15ModelMetadataResponse\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08versions\x18\x02 \x03(\t\x12\x10\n\x08platform\x18\x03 \x01(\t\x12?\n\x06inputs\x18\x04 \x03(\x0b\x32/.inference.ModelMetadataResponse.TensorMetadata\x12@\n\x07outputs\x18\x05 \x03(\x0b\x32/.inference.ModelMetadataResponse.TensorMetadata\x1a?\n\x0eTensorMetadata\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08\x64\x61tatype\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x03"\xee\x06\n\x11ModelInferRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x15\n\rmodel_version\x18\x02 \x01(\t\x12\n\n\x02id\x18\x03 \x01(\t\x12@\n\nparameters\x18\x04 \x03(\x0b\x32,.inference.ModelInferRequest.ParametersEntry\x12=\n\x06inputs\x18\x05 \x03(\x0b\x32-.inference.ModelInferRequest.InferInputTensor\x12H\n\x07outputs\x18\x06 \x03(\x0b\x32\x37.inference.ModelInferRequest.InferRequestedOutputTensor\x12\x1a\n\x12raw_input_contents\x18\x07 \x03(\x0c\x1a\x94\x02\n\x10InferInputTensor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08\x64\x61tatype\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x03\x12Q\n\nparameters\x18\x04 \x03(\x0b\x32=.inference.ModelInferRequest.InferInputTensor.ParametersEntry\x12\x30\n\x08\x63ontents\x18\x05 \x01(\x0b\x32\x1e.inference.InferTensorContents\x1aL\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.inference.InferParameter:\x02\x38\x01\x1a\xd5\x01\n\x1aInferRequestedOutputTensor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12[\n\nparameters\x18\x02 \x03(\x0b\x32G.inference.ModelInferRequest.InferRequestedOutputTensor.ParametersEntry\x1aL\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.inference.InferParameter:\x02\x38\x01\x1aL\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.inference.InferParameter:\x02\x38\x01"\xd5\x04\n\x12ModelInferResponse\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x15\n\rmodel_version\x18\x02 \x01(\t\x12\n\n\x02id\x18\x03 \x01(\t\x12\x41\n\nparameters\x18\x04 \x03(\x0b\x32-.inference.ModelInferResponse.ParametersEntry\x12@\n\x07outputs\x18\x05 \x03(\x0b\x32/.inference.ModelInferResponse.InferOutputTensor\x12\x1b\n\x13raw_output_contents\x18\x06 \x03(\x0c\x1a\x97\x02\n\x11InferOutputTensor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x10\n\x08\x64\x61tatype\x18\x02 \x01(\t\x12\r\n\x05shape\x18\x03 \x03(\x03\x12S\n\nparameters\x18\x04 \x03(\x0b\x32?.inference.ModelInferResponse.InferOutputTensor.ParametersEntry\x12\x30\n\x08\x63ontents\x18\x05 \x01(\x0b\x32\x1e.inference.InferTensorContents\x1aL\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.inference.InferParameter:\x02\x38\x01\x1aL\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.inference.InferParameter:\x02\x38\x01"i\n\x0eInferParameter\x12\x14\n\nbool_param\x18\x01 
\x01(\x08H\x00\x12\x15\n\x0bint64_param\x18\x02 \x01(\x03H\x00\x12\x16\n\x0cstring_param\x18\x03 \x01(\tH\x00\x42\x12\n\x10parameter_choice"\xd0\x01\n\x13InferTensorContents\x12\x15\n\rbool_contents\x18\x01 \x03(\x08\x12\x14\n\x0cint_contents\x18\x02 \x03(\x05\x12\x16\n\x0eint64_contents\x18\x03 \x03(\x03\x12\x15\n\ruint_contents\x18\x04 \x03(\r\x12\x17\n\x0fuint64_contents\x18\x05 \x03(\x04\x12\x15\n\rfp32_contents\x18\x06 \x03(\x02\x12\x15\n\rfp64_contents\x18\x07 \x03(\x01\x12\x16\n\x0e\x62ytes_contents\x18\x08 \x03(\x0c"0\n\x1aRepositoryModelLoadRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t"C\n\x1bRepositoryModelLoadResponse\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x10\n\x08isLoaded\x18\x02 \x01(\x08"2\n\x1cRepositoryModelUnloadRequest\x12\x12\n\nmodel_name\x18\x01 \x01(\t"G\n\x1dRepositoryModelUnloadResponse\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x12\n\nisUnloaded\x18\x02 \x01(\x08\x32\xd2\x05\n\x14GRPCInferenceService\x12K\n\nServerLive\x12\x1c.inference.ServerLiveRequest\x1a\x1d.inference.ServerLiveResponse"\x00\x12N\n\x0bServerReady\x12\x1d.inference.ServerReadyRequest\x1a\x1e.inference.ServerReadyResponse"\x00\x12K\n\nModelReady\x12\x1c.inference.ModelReadyRequest\x1a\x1d.inference.ModelReadyResponse"\x00\x12W\n\x0eServerMetadata\x12 .inference.ServerMetadataRequest\x1a!.inference.ServerMetadataResponse"\x00\x12T\n\rModelMetadata\x12\x1f.inference.ModelMetadataRequest\x1a .inference.ModelMetadataResponse"\x00\x12K\n\nModelInfer\x12\x1c.inference.ModelInferRequest\x1a\x1d.inference.ModelInferResponse"\x00\x12\x66\n\x13RepositoryModelLoad\x12%.inference.RepositoryModelLoadRequest\x1a&.inference.RepositoryModelLoadResponse"\x00\x12l\n\x15RepositoryModelUnload\x12\'.inference.RepositoryModelUnloadRequest\x1a(.inference.RepositoryModelUnloadResponse"\x00\x62\x06proto3'
+)
+
+
+_SERVERLIVEREQUEST = DESCRIPTOR.message_types_by_name["ServerLiveRequest"]
+_SERVERLIVERESPONSE = DESCRIPTOR.message_types_by_name["ServerLiveResponse"]
+_SERVERREADYREQUEST = DESCRIPTOR.message_types_by_name["ServerReadyRequest"]
+_SERVERREADYRESPONSE = DESCRIPTOR.message_types_by_name["ServerReadyResponse"]
+_MODELREADYREQUEST = DESCRIPTOR.message_types_by_name["ModelReadyRequest"]
+_MODELREADYRESPONSE = DESCRIPTOR.message_types_by_name["ModelReadyResponse"]
+_SERVERMETADATAREQUEST = DESCRIPTOR.message_types_by_name["ServerMetadataRequest"]
+_SERVERMETADATARESPONSE = DESCRIPTOR.message_types_by_name["ServerMetadataResponse"]
+_MODELMETADATAREQUEST = DESCRIPTOR.message_types_by_name["ModelMetadataRequest"]
+_MODELMETADATARESPONSE = DESCRIPTOR.message_types_by_name["ModelMetadataResponse"]
+_MODELMETADATARESPONSE_TENSORMETADATA = _MODELMETADATARESPONSE.nested_types_by_name[
+ "TensorMetadata"
+]
+_MODELINFERREQUEST = DESCRIPTOR.message_types_by_name["ModelInferRequest"]
+_MODELINFERREQUEST_INFERINPUTTENSOR = _MODELINFERREQUEST.nested_types_by_name[
+ "InferInputTensor"
+]
+_MODELINFERREQUEST_INFERINPUTTENSOR_PARAMETERSENTRY = (
+ _MODELINFERREQUEST_INFERINPUTTENSOR.nested_types_by_name["ParametersEntry"]
+)
+_MODELINFERREQUEST_INFERREQUESTEDOUTPUTTENSOR = _MODELINFERREQUEST.nested_types_by_name[
+ "InferRequestedOutputTensor"
+]
+_MODELINFERREQUEST_INFERREQUESTEDOUTPUTTENSOR_PARAMETERSENTRY = (
+ _MODELINFERREQUEST_INFERREQUESTEDOUTPUTTENSOR.nested_types_by_name[
+ "ParametersEntry"
+ ]
+)
+_MODELINFERREQUEST_PARAMETERSENTRY = _MODELINFERREQUEST.nested_types_by_name[
+ "ParametersEntry"
+]
+_MODELINFERRESPONSE = DESCRIPTOR.message_types_by_name["ModelInferResponse"]
+_MODELINFERRESPONSE_INFEROUTPUTTENSOR = _MODELINFERRESPONSE.nested_types_by_name[
+ "InferOutputTensor"
+]
+_MODELINFERRESPONSE_INFEROUTPUTTENSOR_PARAMETERSENTRY = (
+ _MODELINFERRESPONSE_INFEROUTPUTTENSOR.nested_types_by_name["ParametersEntry"]
+)
+_MODELINFERRESPONSE_PARAMETERSENTRY = _MODELINFERRESPONSE.nested_types_by_name[
+ "ParametersEntry"
+]
+_INFERPARAMETER = DESCRIPTOR.message_types_by_name["InferParameter"]
+_INFERTENSORCONTENTS = DESCRIPTOR.message_types_by_name["InferTensorContents"]
+_REPOSITORYMODELLOADREQUEST = DESCRIPTOR.message_types_by_name[
+ "RepositoryModelLoadRequest"
+]
+_REPOSITORYMODELLOADRESPONSE = DESCRIPTOR.message_types_by_name[
+ "RepositoryModelLoadResponse"
+]
+_REPOSITORYMODELUNLOADREQUEST = DESCRIPTOR.message_types_by_name[
+ "RepositoryModelUnloadRequest"
+]
+_REPOSITORYMODELUNLOADRESPONSE = DESCRIPTOR.message_types_by_name[
+ "RepositoryModelUnloadResponse"
+]
+ServerLiveRequest = _reflection.GeneratedProtocolMessageType(
+ "ServerLiveRequest",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _SERVERLIVEREQUEST,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.ServerLiveRequest)
+ },
+)
+_sym_db.RegisterMessage(ServerLiveRequest)
+
+ServerLiveResponse = _reflection.GeneratedProtocolMessageType(
+ "ServerLiveResponse",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _SERVERLIVERESPONSE,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.ServerLiveResponse)
+ },
+)
+_sym_db.RegisterMessage(ServerLiveResponse)
+
+ServerReadyRequest = _reflection.GeneratedProtocolMessageType(
+ "ServerReadyRequest",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _SERVERREADYREQUEST,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.ServerReadyRequest)
+ },
+)
+_sym_db.RegisterMessage(ServerReadyRequest)
+
+ServerReadyResponse = _reflection.GeneratedProtocolMessageType(
+ "ServerReadyResponse",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _SERVERREADYRESPONSE,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.ServerReadyResponse)
+ },
+)
+_sym_db.RegisterMessage(ServerReadyResponse)
+
+ModelReadyRequest = _reflection.GeneratedProtocolMessageType(
+ "ModelReadyRequest",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _MODELREADYREQUEST,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.ModelReadyRequest)
+ },
+)
+_sym_db.RegisterMessage(ModelReadyRequest)
+
+ModelReadyResponse = _reflection.GeneratedProtocolMessageType(
+ "ModelReadyResponse",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _MODELREADYRESPONSE,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.ModelReadyResponse)
+ },
+)
+_sym_db.RegisterMessage(ModelReadyResponse)
+
+ServerMetadataRequest = _reflection.GeneratedProtocolMessageType(
+ "ServerMetadataRequest",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _SERVERMETADATAREQUEST,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.ServerMetadataRequest)
+ },
+)
+_sym_db.RegisterMessage(ServerMetadataRequest)
+
+ServerMetadataResponse = _reflection.GeneratedProtocolMessageType(
+ "ServerMetadataResponse",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _SERVERMETADATARESPONSE,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.ServerMetadataResponse)
+ },
+)
+_sym_db.RegisterMessage(ServerMetadataResponse)
+
+ModelMetadataRequest = _reflection.GeneratedProtocolMessageType(
+ "ModelMetadataRequest",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _MODELMETADATAREQUEST,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.ModelMetadataRequest)
+ },
+)
+_sym_db.RegisterMessage(ModelMetadataRequest)
+
+ModelMetadataResponse = _reflection.GeneratedProtocolMessageType(
+ "ModelMetadataResponse",
+ (_message.Message,),
+ {
+ "TensorMetadata": _reflection.GeneratedProtocolMessageType(
+ "TensorMetadata",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _MODELMETADATARESPONSE_TENSORMETADATA,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.ModelMetadataResponse.TensorMetadata)
+ },
+ ),
+ "DESCRIPTOR": _MODELMETADATARESPONSE,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.ModelMetadataResponse)
+ },
+)
+_sym_db.RegisterMessage(ModelMetadataResponse)
+_sym_db.RegisterMessage(ModelMetadataResponse.TensorMetadata)
+
+ModelInferRequest = _reflection.GeneratedProtocolMessageType(
+ "ModelInferRequest",
+ (_message.Message,),
+ {
+ "InferInputTensor": _reflection.GeneratedProtocolMessageType(
+ "InferInputTensor",
+ (_message.Message,),
+ {
+ "ParametersEntry": _reflection.GeneratedProtocolMessageType(
+ "ParametersEntry",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _MODELINFERREQUEST_INFERINPUTTENSOR_PARAMETERSENTRY,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.ModelInferRequest.InferInputTensor.ParametersEntry)
+ },
+ ),
+ "DESCRIPTOR": _MODELINFERREQUEST_INFERINPUTTENSOR,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.ModelInferRequest.InferInputTensor)
+ },
+ ),
+ "InferRequestedOutputTensor": _reflection.GeneratedProtocolMessageType(
+ "InferRequestedOutputTensor",
+ (_message.Message,),
+ {
+ "ParametersEntry": _reflection.GeneratedProtocolMessageType(
+ "ParametersEntry",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _MODELINFERREQUEST_INFERREQUESTEDOUTPUTTENSOR_PARAMETERSENTRY,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.ModelInferRequest.InferRequestedOutputTensor.ParametersEntry)
+ },
+ ),
+ "DESCRIPTOR": _MODELINFERREQUEST_INFERREQUESTEDOUTPUTTENSOR,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.ModelInferRequest.InferRequestedOutputTensor)
+ },
+ ),
+ "ParametersEntry": _reflection.GeneratedProtocolMessageType(
+ "ParametersEntry",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _MODELINFERREQUEST_PARAMETERSENTRY,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.ModelInferRequest.ParametersEntry)
+ },
+ ),
+ "DESCRIPTOR": _MODELINFERREQUEST,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.ModelInferRequest)
+ },
+)
+_sym_db.RegisterMessage(ModelInferRequest)
+_sym_db.RegisterMessage(ModelInferRequest.InferInputTensor)
+_sym_db.RegisterMessage(ModelInferRequest.InferInputTensor.ParametersEntry)
+_sym_db.RegisterMessage(ModelInferRequest.InferRequestedOutputTensor)
+_sym_db.RegisterMessage(ModelInferRequest.InferRequestedOutputTensor.ParametersEntry)
+_sym_db.RegisterMessage(ModelInferRequest.ParametersEntry)
+
+ModelInferResponse = _reflection.GeneratedProtocolMessageType(
+ "ModelInferResponse",
+ (_message.Message,),
+ {
+ "InferOutputTensor": _reflection.GeneratedProtocolMessageType(
+ "InferOutputTensor",
+ (_message.Message,),
+ {
+ "ParametersEntry": _reflection.GeneratedProtocolMessageType(
+ "ParametersEntry",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _MODELINFERRESPONSE_INFEROUTPUTTENSOR_PARAMETERSENTRY,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.ModelInferResponse.InferOutputTensor.ParametersEntry)
+ },
+ ),
+ "DESCRIPTOR": _MODELINFERRESPONSE_INFEROUTPUTTENSOR,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.ModelInferResponse.InferOutputTensor)
+ },
+ ),
+ "ParametersEntry": _reflection.GeneratedProtocolMessageType(
+ "ParametersEntry",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _MODELINFERRESPONSE_PARAMETERSENTRY,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.ModelInferResponse.ParametersEntry)
+ },
+ ),
+ "DESCRIPTOR": _MODELINFERRESPONSE,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.ModelInferResponse)
+ },
+)
+_sym_db.RegisterMessage(ModelInferResponse)
+_sym_db.RegisterMessage(ModelInferResponse.InferOutputTensor)
+_sym_db.RegisterMessage(ModelInferResponse.InferOutputTensor.ParametersEntry)
+_sym_db.RegisterMessage(ModelInferResponse.ParametersEntry)
+
+InferParameter = _reflection.GeneratedProtocolMessageType(
+ "InferParameter",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _INFERPARAMETER,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.InferParameter)
+ },
+)
+_sym_db.RegisterMessage(InferParameter)
+
+InferTensorContents = _reflection.GeneratedProtocolMessageType(
+ "InferTensorContents",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _INFERTENSORCONTENTS,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.InferTensorContents)
+ },
+)
+_sym_db.RegisterMessage(InferTensorContents)
+
+RepositoryModelLoadRequest = _reflection.GeneratedProtocolMessageType(
+ "RepositoryModelLoadRequest",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _REPOSITORYMODELLOADREQUEST,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.RepositoryModelLoadRequest)
+ },
+)
+_sym_db.RegisterMessage(RepositoryModelLoadRequest)
+
+RepositoryModelLoadResponse = _reflection.GeneratedProtocolMessageType(
+ "RepositoryModelLoadResponse",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _REPOSITORYMODELLOADRESPONSE,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.RepositoryModelLoadResponse)
+ },
+)
+_sym_db.RegisterMessage(RepositoryModelLoadResponse)
+
+RepositoryModelUnloadRequest = _reflection.GeneratedProtocolMessageType(
+ "RepositoryModelUnloadRequest",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _REPOSITORYMODELUNLOADREQUEST,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.RepositoryModelUnloadRequest)
+ },
+)
+_sym_db.RegisterMessage(RepositoryModelUnloadRequest)
+
+RepositoryModelUnloadResponse = _reflection.GeneratedProtocolMessageType(
+ "RepositoryModelUnloadResponse",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _REPOSITORYMODELUNLOADRESPONSE,
+ "__module__": "grpc_predict_v2_pb2",
+ # @@protoc_insertion_point(class_scope:inference.RepositoryModelUnloadResponse)
+ },
+)
+_sym_db.RegisterMessage(RepositoryModelUnloadResponse)
+
+_GRPCINFERENCESERVICE = DESCRIPTOR.services_by_name["GRPCInferenceService"]
+if _descriptor._USE_C_DESCRIPTORS == False: # noqa: E712
+ DESCRIPTOR._options = None
+ _MODELINFERREQUEST_INFERINPUTTENSOR_PARAMETERSENTRY._options = None
+ _MODELINFERREQUEST_INFERINPUTTENSOR_PARAMETERSENTRY._serialized_options = b"8\001"
+ _MODELINFERREQUEST_INFERREQUESTEDOUTPUTTENSOR_PARAMETERSENTRY._options = None
+ _MODELINFERREQUEST_INFERREQUESTEDOUTPUTTENSOR_PARAMETERSENTRY._serialized_options = b"8\001"
+ _MODELINFERREQUEST_PARAMETERSENTRY._options = None
+ _MODELINFERREQUEST_PARAMETERSENTRY._serialized_options = b"8\001"
+ _MODELINFERRESPONSE_INFEROUTPUTTENSOR_PARAMETERSENTRY._options = None
+ _MODELINFERRESPONSE_INFEROUTPUTTENSOR_PARAMETERSENTRY._serialized_options = b"8\001"
+ _MODELINFERRESPONSE_PARAMETERSENTRY._options = None
+ _MODELINFERRESPONSE_PARAMETERSENTRY._serialized_options = b"8\001"
+ _SERVERLIVEREQUEST._serialized_start = 36
+ _SERVERLIVEREQUEST._serialized_end = 55
+ _SERVERLIVERESPONSE._serialized_start = 57
+ _SERVERLIVERESPONSE._serialized_end = 91
+ _SERVERREADYREQUEST._serialized_start = 93
+ _SERVERREADYREQUEST._serialized_end = 113
+ _SERVERREADYRESPONSE._serialized_start = 115
+ _SERVERREADYRESPONSE._serialized_end = 151
+ _MODELREADYREQUEST._serialized_start = 153
+ _MODELREADYREQUEST._serialized_end = 203
+ _MODELREADYRESPONSE._serialized_start = 205
+ _MODELREADYRESPONSE._serialized_end = 240
+ _SERVERMETADATAREQUEST._serialized_start = 242
+ _SERVERMETADATAREQUEST._serialized_end = 265
+ _SERVERMETADATARESPONSE._serialized_start = 267
+ _SERVERMETADATARESPONSE._serialized_end = 342
+ _MODELMETADATAREQUEST._serialized_start = 344
+ _MODELMETADATAREQUEST._serialized_end = 397
+ _MODELMETADATARESPONSE._serialized_start = 400
+ _MODELMETADATARESPONSE._serialized_end = 669
+ _MODELMETADATARESPONSE_TENSORMETADATA._serialized_start = 606
+ _MODELMETADATARESPONSE_TENSORMETADATA._serialized_end = 669
+ _MODELINFERREQUEST._serialized_start = 672
+ _MODELINFERREQUEST._serialized_end = 1550
+ _MODELINFERREQUEST_INFERINPUTTENSOR._serialized_start = 980
+ _MODELINFERREQUEST_INFERINPUTTENSOR._serialized_end = 1256
+ _MODELINFERREQUEST_INFERINPUTTENSOR_PARAMETERSENTRY._serialized_start = 1180
+ _MODELINFERREQUEST_INFERINPUTTENSOR_PARAMETERSENTRY._serialized_end = 1256
+ _MODELINFERREQUEST_INFERREQUESTEDOUTPUTTENSOR._serialized_start = 1259
+ _MODELINFERREQUEST_INFERREQUESTEDOUTPUTTENSOR._serialized_end = 1472
+ _MODELINFERREQUEST_INFERREQUESTEDOUTPUTTENSOR_PARAMETERSENTRY._serialized_start = (
+ 1180
+ )
+ _MODELINFERREQUEST_INFERREQUESTEDOUTPUTTENSOR_PARAMETERSENTRY._serialized_end = 1256
+ _MODELINFERREQUEST_PARAMETERSENTRY._serialized_start = 1180
+ _MODELINFERREQUEST_PARAMETERSENTRY._serialized_end = 1256
+ _MODELINFERRESPONSE._serialized_start = 1553
+ _MODELINFERRESPONSE._serialized_end = 2150
+ _MODELINFERRESPONSE_INFEROUTPUTTENSOR._serialized_start = 1793
+ _MODELINFERRESPONSE_INFEROUTPUTTENSOR._serialized_end = 2072
+ _MODELINFERRESPONSE_INFEROUTPUTTENSOR_PARAMETERSENTRY._serialized_start = 1180
+ _MODELINFERRESPONSE_INFEROUTPUTTENSOR_PARAMETERSENTRY._serialized_end = 1256
+ _MODELINFERRESPONSE_PARAMETERSENTRY._serialized_start = 1180
+ _MODELINFERRESPONSE_PARAMETERSENTRY._serialized_end = 1256
+ _INFERPARAMETER._serialized_start = 2152
+ _INFERPARAMETER._serialized_end = 2257
+ _INFERTENSORCONTENTS._serialized_start = 2260
+ _INFERTENSORCONTENTS._serialized_end = 2468
+ _REPOSITORYMODELLOADREQUEST._serialized_start = 2470
+ _REPOSITORYMODELLOADREQUEST._serialized_end = 2518
+ _REPOSITORYMODELLOADRESPONSE._serialized_start = 2520
+ _REPOSITORYMODELLOADRESPONSE._serialized_end = 2587
+ _REPOSITORYMODELUNLOADREQUEST._serialized_start = 2589
+ _REPOSITORYMODELUNLOADREQUEST._serialized_end = 2639
+ _REPOSITORYMODELUNLOADRESPONSE._serialized_start = 2641
+ _REPOSITORYMODELUNLOADRESPONSE._serialized_end = 2712
+ _GRPCINFERENCESERVICE._serialized_start = 2715
+ _GRPCINFERENCESERVICE._serialized_end = 3437
+# @@protoc_insertion_point(module_scope)
diff --git a/hsml/python/hsml/client/istio/grpc/proto/grpc_predict_v2_pb2.pyi b/hsml/python/hsml/client/istio/grpc/proto/grpc_predict_v2_pb2.pyi
new file mode 100644
index 000000000..dcaac5eb4
--- /dev/null
+++ b/hsml/python/hsml/client/istio/grpc/proto/grpc_predict_v2_pb2.pyi
@@ -0,0 +1,414 @@
+from typing import (
+ ClassVar as _ClassVar,
+)
+from typing import (
+ Iterable as _Iterable,
+)
+from typing import (
+ Mapping as _Mapping,
+)
+from typing import (
+ Optional as _Optional,
+)
+from typing import (
+ Union as _Union,
+)
+
+from google.protobuf import descriptor as _descriptor
+from google.protobuf import message as _message
+from google.protobuf.internal import containers as _containers
+
+DESCRIPTOR: _descriptor.FileDescriptor
+
+class InferParameter(_message.Message):
+ __slots__ = ["bool_param", "int64_param", "string_param"]
+ BOOL_PARAM_FIELD_NUMBER: _ClassVar[int]
+ INT64_PARAM_FIELD_NUMBER: _ClassVar[int]
+ STRING_PARAM_FIELD_NUMBER: _ClassVar[int]
+ bool_param: bool
+ int64_param: int
+ string_param: str
+ def __init__(
+ self,
+ bool_param: bool = ...,
+ int64_param: _Optional[int] = ...,
+ string_param: _Optional[str] = ...,
+ ) -> None: ...
+
+class InferTensorContents(_message.Message):
+ __slots__ = [
+ "bool_contents",
+ "bytes_contents",
+ "fp32_contents",
+ "fp64_contents",
+ "int64_contents",
+ "int_contents",
+ "uint64_contents",
+ "uint_contents",
+ ]
+ BOOL_CONTENTS_FIELD_NUMBER: _ClassVar[int]
+ BYTES_CONTENTS_FIELD_NUMBER: _ClassVar[int]
+ FP32_CONTENTS_FIELD_NUMBER: _ClassVar[int]
+ FP64_CONTENTS_FIELD_NUMBER: _ClassVar[int]
+ INT64_CONTENTS_FIELD_NUMBER: _ClassVar[int]
+ INT_CONTENTS_FIELD_NUMBER: _ClassVar[int]
+ UINT64_CONTENTS_FIELD_NUMBER: _ClassVar[int]
+ UINT_CONTENTS_FIELD_NUMBER: _ClassVar[int]
+ bool_contents: _containers.RepeatedScalarFieldContainer[bool]
+ bytes_contents: _containers.RepeatedScalarFieldContainer[bytes]
+ fp32_contents: _containers.RepeatedScalarFieldContainer[float]
+ fp64_contents: _containers.RepeatedScalarFieldContainer[float]
+ int64_contents: _containers.RepeatedScalarFieldContainer[int]
+ int_contents: _containers.RepeatedScalarFieldContainer[int]
+ uint64_contents: _containers.RepeatedScalarFieldContainer[int]
+ uint_contents: _containers.RepeatedScalarFieldContainer[int]
+ def __init__(
+ self,
+ bool_contents: _Optional[_Iterable[bool]] = ...,
+ int_contents: _Optional[_Iterable[int]] = ...,
+ int64_contents: _Optional[_Iterable[int]] = ...,
+ uint_contents: _Optional[_Iterable[int]] = ...,
+ uint64_contents: _Optional[_Iterable[int]] = ...,
+ fp32_contents: _Optional[_Iterable[float]] = ...,
+ fp64_contents: _Optional[_Iterable[float]] = ...,
+ bytes_contents: _Optional[_Iterable[bytes]] = ...,
+ ) -> None: ...
+
+class ModelInferRequest(_message.Message):
+ __slots__ = [
+ "id",
+ "inputs",
+ "model_name",
+ "model_version",
+ "outputs",
+ "parameters",
+ "raw_input_contents",
+ ]
+
+ class InferInputTensor(_message.Message):
+ __slots__ = ["contents", "datatype", "name", "parameters", "shape"]
+
+ class ParametersEntry(_message.Message):
+ __slots__ = ["key", "value"]
+ KEY_FIELD_NUMBER: _ClassVar[int]
+ VALUE_FIELD_NUMBER: _ClassVar[int]
+ key: str
+ value: InferParameter
+ def __init__(
+ self,
+ key: _Optional[str] = ...,
+ value: _Optional[_Union[InferParameter, _Mapping]] = ...,
+ ) -> None: ...
+
+ CONTENTS_FIELD_NUMBER: _ClassVar[int]
+ DATATYPE_FIELD_NUMBER: _ClassVar[int]
+ NAME_FIELD_NUMBER: _ClassVar[int]
+ PARAMETERS_FIELD_NUMBER: _ClassVar[int]
+ SHAPE_FIELD_NUMBER: _ClassVar[int]
+ contents: InferTensorContents
+ datatype: str
+ name: str
+ parameters: _containers.MessageMap[str, InferParameter]
+ shape: _containers.RepeatedScalarFieldContainer[int]
+ def __init__(
+ self,
+ name: _Optional[str] = ...,
+ datatype: _Optional[str] = ...,
+ shape: _Optional[_Iterable[int]] = ...,
+ parameters: _Optional[_Mapping[str, InferParameter]] = ...,
+ contents: _Optional[_Union[InferTensorContents, _Mapping]] = ...,
+ ) -> None: ...
+
+ class InferRequestedOutputTensor(_message.Message):
+ __slots__ = ["name", "parameters"]
+
+ class ParametersEntry(_message.Message):
+ __slots__ = ["key", "value"]
+ KEY_FIELD_NUMBER: _ClassVar[int]
+ VALUE_FIELD_NUMBER: _ClassVar[int]
+ key: str
+ value: InferParameter
+ def __init__(
+ self,
+ key: _Optional[str] = ...,
+ value: _Optional[_Union[InferParameter, _Mapping]] = ...,
+ ) -> None: ...
+
+ NAME_FIELD_NUMBER: _ClassVar[int]
+ PARAMETERS_FIELD_NUMBER: _ClassVar[int]
+ name: str
+ parameters: _containers.MessageMap[str, InferParameter]
+ def __init__(
+ self,
+ name: _Optional[str] = ...,
+ parameters: _Optional[_Mapping[str, InferParameter]] = ...,
+ ) -> None: ...
+
+ class ParametersEntry(_message.Message):
+ __slots__ = ["key", "value"]
+ KEY_FIELD_NUMBER: _ClassVar[int]
+ VALUE_FIELD_NUMBER: _ClassVar[int]
+ key: str
+ value: InferParameter
+ def __init__(
+ self,
+ key: _Optional[str] = ...,
+ value: _Optional[_Union[InferParameter, _Mapping]] = ...,
+ ) -> None: ...
+
+ ID_FIELD_NUMBER: _ClassVar[int]
+ INPUTS_FIELD_NUMBER: _ClassVar[int]
+ MODEL_NAME_FIELD_NUMBER: _ClassVar[int]
+ MODEL_VERSION_FIELD_NUMBER: _ClassVar[int]
+ OUTPUTS_FIELD_NUMBER: _ClassVar[int]
+ PARAMETERS_FIELD_NUMBER: _ClassVar[int]
+ RAW_INPUT_CONTENTS_FIELD_NUMBER: _ClassVar[int]
+ id: str
+ inputs: _containers.RepeatedCompositeFieldContainer[
+ ModelInferRequest.InferInputTensor
+ ]
+ model_name: str
+ model_version: str
+ outputs: _containers.RepeatedCompositeFieldContainer[
+ ModelInferRequest.InferRequestedOutputTensor
+ ]
+ parameters: _containers.MessageMap[str, InferParameter]
+ raw_input_contents: _containers.RepeatedScalarFieldContainer[bytes]
+ def __init__(
+ self,
+ model_name: _Optional[str] = ...,
+ model_version: _Optional[str] = ...,
+ id: _Optional[str] = ...,
+ parameters: _Optional[_Mapping[str, InferParameter]] = ...,
+ inputs: _Optional[
+ _Iterable[_Union[ModelInferRequest.InferInputTensor, _Mapping]]
+ ] = ...,
+ outputs: _Optional[
+ _Iterable[_Union[ModelInferRequest.InferRequestedOutputTensor, _Mapping]]
+ ] = ...,
+ raw_input_contents: _Optional[_Iterable[bytes]] = ...,
+ ) -> None: ...
+
+class ModelInferResponse(_message.Message):
+ __slots__ = [
+ "id",
+ "model_name",
+ "model_version",
+ "outputs",
+ "parameters",
+ "raw_output_contents",
+ ]
+
+ class InferOutputTensor(_message.Message):
+ __slots__ = ["contents", "datatype", "name", "parameters", "shape"]
+
+ class ParametersEntry(_message.Message):
+ __slots__ = ["key", "value"]
+ KEY_FIELD_NUMBER: _ClassVar[int]
+ VALUE_FIELD_NUMBER: _ClassVar[int]
+ key: str
+ value: InferParameter
+ def __init__(
+ self,
+ key: _Optional[str] = ...,
+ value: _Optional[_Union[InferParameter, _Mapping]] = ...,
+ ) -> None: ...
+
+ CONTENTS_FIELD_NUMBER: _ClassVar[int]
+ DATATYPE_FIELD_NUMBER: _ClassVar[int]
+ NAME_FIELD_NUMBER: _ClassVar[int]
+ PARAMETERS_FIELD_NUMBER: _ClassVar[int]
+ SHAPE_FIELD_NUMBER: _ClassVar[int]
+ contents: InferTensorContents
+ datatype: str
+ name: str
+ parameters: _containers.MessageMap[str, InferParameter]
+ shape: _containers.RepeatedScalarFieldContainer[int]
+ def __init__(
+ self,
+ name: _Optional[str] = ...,
+ datatype: _Optional[str] = ...,
+ shape: _Optional[_Iterable[int]] = ...,
+ parameters: _Optional[_Mapping[str, InferParameter]] = ...,
+ contents: _Optional[_Union[InferTensorContents, _Mapping]] = ...,
+ ) -> None: ...
+
+ class ParametersEntry(_message.Message):
+ __slots__ = ["key", "value"]
+ KEY_FIELD_NUMBER: _ClassVar[int]
+ VALUE_FIELD_NUMBER: _ClassVar[int]
+ key: str
+ value: InferParameter
+ def __init__(
+ self,
+ key: _Optional[str] = ...,
+ value: _Optional[_Union[InferParameter, _Mapping]] = ...,
+ ) -> None: ...
+
+ ID_FIELD_NUMBER: _ClassVar[int]
+ MODEL_NAME_FIELD_NUMBER: _ClassVar[int]
+ MODEL_VERSION_FIELD_NUMBER: _ClassVar[int]
+ OUTPUTS_FIELD_NUMBER: _ClassVar[int]
+ PARAMETERS_FIELD_NUMBER: _ClassVar[int]
+ RAW_OUTPUT_CONTENTS_FIELD_NUMBER: _ClassVar[int]
+ id: str
+ model_name: str
+ model_version: str
+ outputs: _containers.RepeatedCompositeFieldContainer[
+ ModelInferResponse.InferOutputTensor
+ ]
+ parameters: _containers.MessageMap[str, InferParameter]
+ raw_output_contents: _containers.RepeatedScalarFieldContainer[bytes]
+ def __init__(
+ self,
+ model_name: _Optional[str] = ...,
+ model_version: _Optional[str] = ...,
+ id: _Optional[str] = ...,
+ parameters: _Optional[_Mapping[str, InferParameter]] = ...,
+ outputs: _Optional[
+ _Iterable[_Union[ModelInferResponse.InferOutputTensor, _Mapping]]
+ ] = ...,
+ raw_output_contents: _Optional[_Iterable[bytes]] = ...,
+ ) -> None: ...
+
+class ModelMetadataRequest(_message.Message):
+ __slots__ = ["name", "version"]
+ NAME_FIELD_NUMBER: _ClassVar[int]
+ VERSION_FIELD_NUMBER: _ClassVar[int]
+ name: str
+ version: str
+ def __init__(
+ self, name: _Optional[str] = ..., version: _Optional[str] = ...
+ ) -> None: ...
+
+class ModelMetadataResponse(_message.Message):
+ __slots__ = ["inputs", "name", "outputs", "platform", "versions"]
+
+ class TensorMetadata(_message.Message):
+ __slots__ = ["datatype", "name", "shape"]
+ DATATYPE_FIELD_NUMBER: _ClassVar[int]
+ NAME_FIELD_NUMBER: _ClassVar[int]
+ SHAPE_FIELD_NUMBER: _ClassVar[int]
+ datatype: str
+ name: str
+ shape: _containers.RepeatedScalarFieldContainer[int]
+ def __init__(
+ self,
+ name: _Optional[str] = ...,
+ datatype: _Optional[str] = ...,
+ shape: _Optional[_Iterable[int]] = ...,
+ ) -> None: ...
+
+ INPUTS_FIELD_NUMBER: _ClassVar[int]
+ NAME_FIELD_NUMBER: _ClassVar[int]
+ OUTPUTS_FIELD_NUMBER: _ClassVar[int]
+ PLATFORM_FIELD_NUMBER: _ClassVar[int]
+ VERSIONS_FIELD_NUMBER: _ClassVar[int]
+ inputs: _containers.RepeatedCompositeFieldContainer[
+ ModelMetadataResponse.TensorMetadata
+ ]
+ name: str
+ outputs: _containers.RepeatedCompositeFieldContainer[
+ ModelMetadataResponse.TensorMetadata
+ ]
+ platform: str
+ versions: _containers.RepeatedScalarFieldContainer[str]
+ def __init__(
+ self,
+ name: _Optional[str] = ...,
+ versions: _Optional[_Iterable[str]] = ...,
+ platform: _Optional[str] = ...,
+ inputs: _Optional[
+ _Iterable[_Union[ModelMetadataResponse.TensorMetadata, _Mapping]]
+ ] = ...,
+ outputs: _Optional[
+ _Iterable[_Union[ModelMetadataResponse.TensorMetadata, _Mapping]]
+ ] = ...,
+ ) -> None: ...
+
+class ModelReadyRequest(_message.Message):
+ __slots__ = ["name", "version"]
+ NAME_FIELD_NUMBER: _ClassVar[int]
+ VERSION_FIELD_NUMBER: _ClassVar[int]
+ name: str
+ version: str
+ def __init__(
+ self, name: _Optional[str] = ..., version: _Optional[str] = ...
+ ) -> None: ...
+
+class ModelReadyResponse(_message.Message):
+ __slots__ = ["ready"]
+ READY_FIELD_NUMBER: _ClassVar[int]
+ ready: bool
+ def __init__(self, ready: bool = ...) -> None: ...
+
+class RepositoryModelLoadRequest(_message.Message):
+ __slots__ = ["model_name"]
+ MODEL_NAME_FIELD_NUMBER: _ClassVar[int]
+ model_name: str
+ def __init__(self, model_name: _Optional[str] = ...) -> None: ...
+
+class RepositoryModelLoadResponse(_message.Message):
+ __slots__ = ["isLoaded", "model_name"]
+ ISLOADED_FIELD_NUMBER: _ClassVar[int]
+ MODEL_NAME_FIELD_NUMBER: _ClassVar[int]
+ isLoaded: bool
+ model_name: str
+ def __init__(
+ self, model_name: _Optional[str] = ..., isLoaded: bool = ...
+ ) -> None: ...
+
+class RepositoryModelUnloadRequest(_message.Message):
+ __slots__ = ["model_name"]
+ MODEL_NAME_FIELD_NUMBER: _ClassVar[int]
+ model_name: str
+ def __init__(self, model_name: _Optional[str] = ...) -> None: ...
+
+class RepositoryModelUnloadResponse(_message.Message):
+ __slots__ = ["isUnloaded", "model_name"]
+ ISUNLOADED_FIELD_NUMBER: _ClassVar[int]
+ MODEL_NAME_FIELD_NUMBER: _ClassVar[int]
+ isUnloaded: bool
+ model_name: str
+ def __init__(
+ self, model_name: _Optional[str] = ..., isUnloaded: bool = ...
+ ) -> None: ...
+
+class ServerLiveRequest(_message.Message):
+ __slots__ = []
+ def __init__(self) -> None: ...
+
+class ServerLiveResponse(_message.Message):
+ __slots__ = ["live"]
+ LIVE_FIELD_NUMBER: _ClassVar[int]
+ live: bool
+ def __init__(self, live: bool = ...) -> None: ...
+
+class ServerMetadataRequest(_message.Message):
+ __slots__ = []
+ def __init__(self) -> None: ...
+
+class ServerMetadataResponse(_message.Message):
+ __slots__ = ["extensions", "name", "version"]
+ EXTENSIONS_FIELD_NUMBER: _ClassVar[int]
+ NAME_FIELD_NUMBER: _ClassVar[int]
+ VERSION_FIELD_NUMBER: _ClassVar[int]
+ extensions: _containers.RepeatedScalarFieldContainer[str]
+ name: str
+ version: str
+ def __init__(
+ self,
+ name: _Optional[str] = ...,
+ version: _Optional[str] = ...,
+ extensions: _Optional[_Iterable[str]] = ...,
+ ) -> None: ...
+
+class ServerReadyRequest(_message.Message):
+ __slots__ = []
+ def __init__(self) -> None: ...
+
+class ServerReadyResponse(_message.Message):
+ __slots__ = ["ready"]
+ READY_FIELD_NUMBER: _ClassVar[int]
+ ready: bool
+ def __init__(self, ready: bool = ...) -> None: ...
diff --git a/hsml/python/hsml/client/istio/grpc/proto/grpc_predict_v2_pb2_grpc.py b/hsml/python/hsml/client/istio/grpc/proto/grpc_predict_v2_pb2_grpc.py
new file mode 100644
index 000000000..a5f986c20
--- /dev/null
+++ b/hsml/python/hsml/client/istio/grpc/proto/grpc_predict_v2_pb2_grpc.py
@@ -0,0 +1,419 @@
+# Copyright 2022 The KServe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
+"""Client and server classes corresponding to protobuf-defined services."""
+
+import hsml.client.istio.grpc.inference_client as inference_client
+import hsml.client.istio.grpc.proto.grpc_predict_v2_pb2 as grpc__predict__v2__pb2
+
+
+class GRPCInferenceServiceStub(object):
+ """Inference Server GRPC endpoints."""
+
+ def __init__(self, channel):
+ """Constructor.
+
+ Args:
+ channel: A grpc.Channel.
+ """
+ self.ServerLive = channel.unary_unary(
+ "/inference.GRPCInferenceService/ServerLive",
+ request_serializer=grpc__predict__v2__pb2.ServerLiveRequest.SerializeToString,
+ response_deserializer=grpc__predict__v2__pb2.ServerLiveResponse.FromString,
+ )
+ self.ServerReady = channel.unary_unary(
+ "/inference.GRPCInferenceService/ServerReady",
+ request_serializer=grpc__predict__v2__pb2.ServerReadyRequest.SerializeToString,
+ response_deserializer=grpc__predict__v2__pb2.ServerReadyResponse.FromString,
+ )
+ self.ModelReady = channel.unary_unary(
+ "/inference.GRPCInferenceService/ModelReady",
+ request_serializer=grpc__predict__v2__pb2.ModelReadyRequest.SerializeToString,
+ response_deserializer=grpc__predict__v2__pb2.ModelReadyResponse.FromString,
+ )
+ self.ServerMetadata = channel.unary_unary(
+ "/inference.GRPCInferenceService/ServerMetadata",
+ request_serializer=grpc__predict__v2__pb2.ServerMetadataRequest.SerializeToString,
+ response_deserializer=grpc__predict__v2__pb2.ServerMetadataResponse.FromString,
+ )
+ self.ModelMetadata = channel.unary_unary(
+ "/inference.GRPCInferenceService/ModelMetadata",
+ request_serializer=grpc__predict__v2__pb2.ModelMetadataRequest.SerializeToString,
+ response_deserializer=grpc__predict__v2__pb2.ModelMetadataResponse.FromString,
+ )
+ self.ModelInfer = channel.unary_unary(
+ "/inference.GRPCInferenceService/ModelInfer",
+ request_serializer=grpc__predict__v2__pb2.ModelInferRequest.SerializeToString,
+ response_deserializer=grpc__predict__v2__pb2.ModelInferResponse.FromString,
+ )
+ self.RepositoryModelLoad = channel.unary_unary(
+ "/inference.GRPCInferenceService/RepositoryModelLoad",
+ request_serializer=grpc__predict__v2__pb2.RepositoryModelLoadRequest.SerializeToString,
+ response_deserializer=grpc__predict__v2__pb2.RepositoryModelLoadResponse.FromString,
+ )
+ self.RepositoryModelUnload = channel.unary_unary(
+ "/inference.GRPCInferenceService/RepositoryModelUnload",
+ request_serializer=grpc__predict__v2__pb2.RepositoryModelUnloadRequest.SerializeToString,
+ response_deserializer=grpc__predict__v2__pb2.RepositoryModelUnloadResponse.FromString,
+ )
+
+
+class GRPCInferenceServiceServicer(object):
+ """Inference Server GRPC endpoints."""
+
+ def ServerLive(self, request, context):
+ """The ServerLive API indicates if the inference server is able to receive
+ and respond to metadata and inference requests.
+ """
+ context.set_code(inference_client.StatusCode.UNIMPLEMENTED)
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
+
+ def ServerReady(self, request, context):
+ """The ServerReady API indicates if the server is ready for inferencing."""
+ context.set_code(inference_client.StatusCode.UNIMPLEMENTED)
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
+
+ def ModelReady(self, request, context):
+ """The ModelReady API indicates if a specific model is ready for inferencing."""
+ context.set_code(inference_client.StatusCode.UNIMPLEMENTED)
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
+
+ def ServerMetadata(self, request, context):
+ """The ServerMetadata API provides information about the server. Errors are
+ indicated by the google.rpc.Status returned for the request. The OK code
+ indicates success and other codes indicate failure.
+ """
+ context.set_code(inference_client.StatusCode.UNIMPLEMENTED)
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
+
+ def ModelMetadata(self, request, context):
+ """The per-model metadata API provides information about a model. Errors are
+ indicated by the google.rpc.Status returned for the request. The OK code
+ indicates success and other codes indicate failure.
+ """
+ context.set_code(inference_client.StatusCode.UNIMPLEMENTED)
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
+
+ def ModelInfer(self, request, context):
+ """The ModelInfer API performs inference using the specified model. Errors are
+ indicated by the google.rpc.Status returned for the request. The OK code
+ indicates success and other codes indicate failure.
+ """
+ context.set_code(inference_client.StatusCode.UNIMPLEMENTED)
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
+
+ def RepositoryModelLoad(self, request, context):
+ """Load or reload a model from a repository."""
+ context.set_code(inference_client.StatusCode.UNIMPLEMENTED)
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
+
+ def RepositoryModelUnload(self, request, context):
+ """Unload a model."""
+ context.set_code(inference_client.StatusCode.UNIMPLEMENTED)
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
+
+
+def add_GRPCInferenceServiceServicer_to_server(servicer, server):
+ rpc_method_handlers = {
+ "ServerLive": inference_client.unary_unary_rpc_method_handler(
+ servicer.ServerLive,
+ request_deserializer=grpc__predict__v2__pb2.ServerLiveRequest.FromString,
+ response_serializer=grpc__predict__v2__pb2.ServerLiveResponse.SerializeToString,
+ ),
+ "ServerReady": inference_client.unary_unary_rpc_method_handler(
+ servicer.ServerReady,
+ request_deserializer=grpc__predict__v2__pb2.ServerReadyRequest.FromString,
+ response_serializer=grpc__predict__v2__pb2.ServerReadyResponse.SerializeToString,
+ ),
+ "ModelReady": inference_client.unary_unary_rpc_method_handler(
+ servicer.ModelReady,
+ request_deserializer=grpc__predict__v2__pb2.ModelReadyRequest.FromString,
+ response_serializer=grpc__predict__v2__pb2.ModelReadyResponse.SerializeToString,
+ ),
+ "ServerMetadata": inference_client.unary_unary_rpc_method_handler(
+ servicer.ServerMetadata,
+ request_deserializer=grpc__predict__v2__pb2.ServerMetadataRequest.FromString,
+ response_serializer=grpc__predict__v2__pb2.ServerMetadataResponse.SerializeToString,
+ ),
+ "ModelMetadata": inference_client.unary_unary_rpc_method_handler(
+ servicer.ModelMetadata,
+ request_deserializer=grpc__predict__v2__pb2.ModelMetadataRequest.FromString,
+ response_serializer=grpc__predict__v2__pb2.ModelMetadataResponse.SerializeToString,
+ ),
+ "ModelInfer": inference_client.unary_unary_rpc_method_handler(
+ servicer.ModelInfer,
+ request_deserializer=grpc__predict__v2__pb2.ModelInferRequest.FromString,
+ response_serializer=grpc__predict__v2__pb2.ModelInferResponse.SerializeToString,
+ ),
+ "RepositoryModelLoad": inference_client.unary_unary_rpc_method_handler(
+ servicer.RepositoryModelLoad,
+ request_deserializer=grpc__predict__v2__pb2.RepositoryModelLoadRequest.FromString,
+ response_serializer=grpc__predict__v2__pb2.RepositoryModelLoadResponse.SerializeToString,
+ ),
+ "RepositoryModelUnload": inference_client.unary_unary_rpc_method_handler(
+ servicer.RepositoryModelUnload,
+ request_deserializer=grpc__predict__v2__pb2.RepositoryModelUnloadRequest.FromString,
+ response_serializer=grpc__predict__v2__pb2.RepositoryModelUnloadResponse.SerializeToString,
+ ),
+ }
+ generic_handler = inference_client.method_handlers_generic_handler(
+ "inference.GRPCInferenceService", rpc_method_handlers
+ )
+ server.add_generic_rpc_handlers((generic_handler,))
+
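For orientation, a concrete server would subclass the servicer above, override the RPCs it supports, and register itself with a grpc.Server via the helper above. A minimal sketch, assuming the names defined in this generated module are in scope; the service class name and port are illustrative, not part of this file:

    import grpc
    from concurrent import futures

    class LiveOnlyService(GRPCInferenceServiceServicer):
        # Only ServerLive is overridden; the other RPCs keep the UNIMPLEMENTED default.
        def ServerLive(self, request, context):
            return grpc__predict__v2__pb2.ServerLiveResponse(live=True)

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=4))
    add_GRPCInferenceServiceServicer_to_server(LiveOnlyService(), server)
    server.add_insecure_port("[::]:9000")  # illustrative port
    server.start()
    server.wait_for_termination()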
+
+# This class is part of an EXPERIMENTAL API.
+class GRPCInferenceService(object):
+ """Inference Server GRPC endpoints."""
+
+ @staticmethod
+ def ServerLive(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return inference_client.experimental.unary_unary(
+ request,
+ target,
+ "/inference.GRPCInferenceService/ServerLive",
+ grpc__predict__v2__pb2.ServerLiveRequest.SerializeToString,
+ grpc__predict__v2__pb2.ServerLiveResponse.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
+
+ @staticmethod
+ def ServerReady(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return inference_client.experimental.unary_unary(
+ request,
+ target,
+ "/inference.GRPCInferenceService/ServerReady",
+ grpc__predict__v2__pb2.ServerReadyRequest.SerializeToString,
+ grpc__predict__v2__pb2.ServerReadyResponse.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
+
+ @staticmethod
+ def ModelReady(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return inference_client.experimental.unary_unary(
+ request,
+ target,
+ "/inference.GRPCInferenceService/ModelReady",
+ grpc__predict__v2__pb2.ModelReadyRequest.SerializeToString,
+ grpc__predict__v2__pb2.ModelReadyResponse.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
+
+ @staticmethod
+ def ServerMetadata(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return inference_client.experimental.unary_unary(
+ request,
+ target,
+ "/inference.GRPCInferenceService/ServerMetadata",
+ grpc__predict__v2__pb2.ServerMetadataRequest.SerializeToString,
+ grpc__predict__v2__pb2.ServerMetadataResponse.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
+
+ @staticmethod
+ def ModelMetadata(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return inference_client.experimental.unary_unary(
+ request,
+ target,
+ "/inference.GRPCInferenceService/ModelMetadata",
+ grpc__predict__v2__pb2.ModelMetadataRequest.SerializeToString,
+ grpc__predict__v2__pb2.ModelMetadataResponse.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
+
+ @staticmethod
+ def ModelInfer(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return inference_client.experimental.unary_unary(
+ request,
+ target,
+ "/inference.GRPCInferenceService/ModelInfer",
+ grpc__predict__v2__pb2.ModelInferRequest.SerializeToString,
+ grpc__predict__v2__pb2.ModelInferResponse.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
+
+ @staticmethod
+ def RepositoryModelLoad(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return inference_client.experimental.unary_unary(
+ request,
+ target,
+ "/inference.GRPCInferenceService/RepositoryModelLoad",
+ grpc__predict__v2__pb2.RepositoryModelLoadRequest.SerializeToString,
+ grpc__predict__v2__pb2.RepositoryModelLoadResponse.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
+
+ @staticmethod
+ def RepositoryModelUnload(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return inference_client.experimental.unary_unary(
+ request,
+ target,
+ "/inference.GRPCInferenceService/RepositoryModelUnload",
+ grpc__predict__v2__pb2.RepositoryModelUnloadRequest.SerializeToString,
+ grpc__predict__v2__pb2.RepositoryModelUnloadResponse.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
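A minimal client-side sketch of driving these endpoints through the stub generated earlier in this module (assuming it carries the standard codegen name GRPCInferenceServiceStub; the address and model name are placeholders):

    import grpc

    channel = grpc.insecure_channel("localhost:9000")  # placeholder address
    stub = GRPCInferenceServiceStub(channel)

    live = stub.ServerLive(grpc__predict__v2__pb2.ServerLiveRequest())
    ready = stub.ModelReady(
        grpc__predict__v2__pb2.ModelReadyRequest(name="my_model")  # placeholder model
    )
    print(live.live, ready.ready)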
diff --git a/hsml/python/hsml/client/istio/internal.py b/hsml/python/hsml/client/istio/internal.py
new file mode 100644
index 000000000..b1befd39d
--- /dev/null
+++ b/hsml/python/hsml/client/istio/internal.py
@@ -0,0 +1,206 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import base64
+import os
+import textwrap
+from pathlib import Path
+
+import requests
+from hsml.client import auth, exceptions
+from hsml.client.istio import base as istio
+
+
+try:
+ import jks
+except ImportError:
+ pass
+
+
+class Client(istio.Client):
+ REQUESTS_VERIFY = "REQUESTS_VERIFY"
+ PROJECT_ID = "HOPSWORKS_PROJECT_ID"
+ PROJECT_NAME = "HOPSWORKS_PROJECT_NAME"
+ HADOOP_USER_NAME = "HADOOP_USER_NAME"
+ HDFS_USER = "HDFS_USER"
+
+ DOMAIN_CA_TRUSTSTORE_PEM = "DOMAIN_CA_TRUSTSTORE_PEM"
+ MATERIAL_DIRECTORY = "MATERIAL_DIRECTORY"
+ T_CERTIFICATE = "t_certificate"
+ K_CERTIFICATE = "k_certificate"
+ TRUSTSTORE_SUFFIX = "__tstore.jks"
+ KEYSTORE_SUFFIX = "__kstore.jks"
+ PEM_CA_CHAIN = "ca_chain.pem"
+ CERT_KEY_SUFFIX = "__cert.key"
+ MATERIAL_PWD = "material_passwd"
+    SECRETS_DIR = "SECRETS_DIR"
+    SERVING_API_KEY = "SERVING_API_KEY"  # env var read by _get_serving_api_key below
+
+ def __init__(self, host, port):
+ """Initializes a client being run from a job/notebook directly on Hopsworks."""
+ self._host = host
+ self._port = port
+ self._base_url = "http://" + self._host + ":" + str(self._port)
+
+        # keystore password, needed by _write_ca_chain when converting the JKS stores
+        self._cert_key = self._get_cert_pw()
+        trust_store_path = self._get_trust_store_path()
+ hostname_verification = (
+ os.environ[self.REQUESTS_VERIFY]
+ if self.REQUESTS_VERIFY in os.environ
+ else "true"
+ )
+ self._project_id = os.environ[self.PROJECT_ID]
+ self._project_name = self._project_name()
+ self._auth = auth.ApiKeyAuth(self._get_serving_api_key())
+ self._verify = self._get_verify(hostname_verification, trust_store_path)
+ self._session = requests.session()
+
+ self._connected = True
+
+ def _project_name(self):
+ try:
+ return os.environ[self.PROJECT_NAME]
+ except KeyError:
+ pass
+
+ hops_user = self._project_user()
+ hops_user_split = hops_user.split(
+ "__"
+ ) # project users have username project__user
+ project = hops_user_split[0]
+ return project
+
+ def _project_user(self):
+ try:
+ hops_user = os.environ[self.HADOOP_USER_NAME]
+ except KeyError:
+ hops_user = os.environ[self.HDFS_USER]
+ return hops_user
+
+ def _get_trust_store_path(self):
+ """Convert truststore from jks to pem and return the location"""
+ ca_chain_path = Path(self.PEM_CA_CHAIN)
+ if not ca_chain_path.exists():
+ self._write_ca_chain(ca_chain_path)
+ return str(ca_chain_path)
+
+ def _write_ca_chain(self, ca_chain_path):
+ """
+        Converts the JKS truststore file into PEM to be compatible with Python libraries
+ """
+ keystore_pw = self._cert_key
+ keystore_ca_cert = self._convert_jks_to_pem(
+ self._get_jks_key_store_path(), keystore_pw
+ )
+ truststore_ca_cert = self._convert_jks_to_pem(
+ self._get_jks_trust_store_path(), keystore_pw
+ )
+
+ with ca_chain_path.open("w") as f:
+ f.write(keystore_ca_cert + truststore_ca_cert)
+
+ def _convert_jks_to_pem(self, jks_path, keystore_pw):
+ """
+        Converts a JKS keystore that contains the client private key,
+        the client certificate and the CA certificate used to sign it
+        to PEM format, and returns the CA certificate(s).
+        Args:
+            :jks_path: path to the JKS file
+            :keystore_pw: password for decrypting the JKS file
+        Returns:
+            str: the CA certificate(s) in PEM format
+ """
+ # load the keystore and decrypt it with password
+ ks = jks.KeyStore.load(jks_path, keystore_pw, try_decrypt_keys=True)
+ ca_certs = ""
+
+ # Convert CA Certificates into PEM format and append to string
+ for _alias, c in ks.certs.items():
+ ca_certs = ca_certs + self._bytes_to_pem_str(c.cert, "CERTIFICATE")
+ return ca_certs
+
+ def _bytes_to_pem_str(self, der_bytes, pem_type):
+ """
+ Utility function for creating PEM files
+
+ Args:
+ der_bytes: DER encoded bytes
+            pem_type: type of PEM, e.g. Certificate, Private key, or RSA private key
+
+ Returns:
+ PEM String for a DER-encoded certificate or private key
+ """
+ pem_str = ""
+ pem_str = pem_str + "-----BEGIN {}-----".format(pem_type) + "\n"
+ pem_str = (
+ pem_str
+ + "\r\n".join(
+ textwrap.wrap(base64.b64encode(der_bytes).decode("ascii"), 64)
+ )
+ + "\n"
+ )
+ pem_str = pem_str + "-----END {}-----".format(pem_type) + "\n"
+ return pem_str
+
+ def _get_jks_trust_store_path(self):
+ """
+ Get truststore location
+
+ Returns:
+ truststore location
+ """
+ t_certificate = Path(self.T_CERTIFICATE)
+ if t_certificate.exists():
+ return str(t_certificate)
+ else:
+ username = os.environ[self.HADOOP_USER_NAME]
+ material_directory = Path(os.environ[self.MATERIAL_DIRECTORY])
+ return str(material_directory.joinpath(username + self.TRUSTSTORE_SUFFIX))
+
+ def _get_jks_key_store_path(self):
+ """
+ Get keystore location
+
+ Returns:
+ keystore location
+ """
+ k_certificate = Path(self.K_CERTIFICATE)
+ if k_certificate.exists():
+ return str(k_certificate)
+ else:
+ username = os.environ[self.HADOOP_USER_NAME]
+ material_directory = Path(os.environ[self.MATERIAL_DIRECTORY])
+ return str(material_directory.joinpath(username + self.KEYSTORE_SUFFIX))
+
+ def _get_cert_pw(self):
+ """
+ Get keystore password from local container
+
+ Returns:
+ Certificate password
+ """
+ pwd_path = Path(self.MATERIAL_PWD)
+ if not pwd_path.exists():
+ username = os.environ[self.HADOOP_USER_NAME]
+ material_directory = Path(os.environ[self.MATERIAL_DIRECTORY])
+ pwd_path = material_directory.joinpath(username + self.CERT_KEY_SUFFIX)
+
+ with pwd_path.open() as f:
+ return f.read()
+
+ def _get_serving_api_key(self):
+ """Retrieve serving API key from environment variable."""
+ if self.SERVING_API_KEY not in os.environ:
+ raise exceptions.InternalClientError("Serving API key not found")
+ return os.environ[self.SERVING_API_KEY]
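The JKS-to-PEM logic above only ever extracts CA certificates. As a stand-alone illustration of what _convert_jks_to_pem and _write_ca_chain do together (pyjks must be installed; the file names and password are placeholders):

    import base64
    import textwrap

    import jks

    # Load a JKS trust store and dump its CA certificates as a PEM chain.
    ks = jks.KeyStore.load("trust_store.jks", "changeit", try_decrypt_keys=True)
    with open("ca_chain.pem", "w") as f:
        for _alias, cert in ks.certs.items():
            b64 = base64.b64encode(cert.cert).decode("ascii")
            f.write("-----BEGIN CERTIFICATE-----\n")
            f.write("\n".join(textwrap.wrap(b64, 64)) + "\n")
            f.write("-----END CERTIFICATE-----\n")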
diff --git a/hsml/python/hsml/client/istio/utils/__init__.py b/hsml/python/hsml/client/istio/utils/__init__.py
new file mode 100644
index 000000000..ff8055b9b
--- /dev/null
+++ b/hsml/python/hsml/client/istio/utils/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/hsml/python/hsml/client/istio/utils/infer_type.py b/hsml/python/hsml/client/istio/utils/infer_type.py
new file mode 100644
index 000000000..e1dd2ab92
--- /dev/null
+++ b/hsml/python/hsml/client/istio/utils/infer_type.py
@@ -0,0 +1,812 @@
+# Copyright 2023 The KServe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This implementation has been borrowed from kserve/kserve repository
+# https://github.com/kserve/kserve/blob/release-0.11/python/kserve/kserve/protocol/infer_type.py
+
+import struct
+from typing import Dict, List, Optional
+
+import numpy
+import numpy as np
+import pandas as pd
+from hsml.client.istio.grpc.errors import InvalidInput
+from hsml.client.istio.grpc.proto.grpc_predict_v2_pb2 import (
+ InferTensorContents,
+ ModelInferRequest,
+ ModelInferResponse,
+)
+from hsml.client.istio.utils.numpy_codec import from_np_dtype, to_np_dtype
+
+
+GRPC_CONTENT_DATATYPE_MAPPINGS = {
+ "BOOL": "bool_contents",
+ "INT8": "int_contents",
+ "INT16": "int_contents",
+ "INT32": "int_contents",
+ "INT64": "int64_contents",
+ "UINT8": "uint_contents",
+ "UINT16": "uint_contents",
+ "UINT32": "uint_contents",
+ "UINT64": "uint64_contents",
+ "FP32": "fp32_contents",
+ "FP64": "fp64_contents",
+ "BYTES": "bytes_contents",
+}
+
+
+def raise_error(msg):
+ """
+ Raise error with the provided message
+ """
+ raise InferenceServerException(msg=msg) from None
+
+
+def serialize_byte_tensor(input_tensor):
+ """
+    Serializes a bytes tensor into a flat numpy array of length-prepended
+    bytes. The numpy array should use dtype np.object_. np.bytes_ should be
+    avoided, because numpy removes trailing zeros at the end of a byte
+    sequence.
+
+ Parameters
+ ----------
+ input_tensor : np.array
+ The bytes tensor to serialize.
+
+ Returns
+ -------
+ serialized_bytes_tensor : np.array
+ The 1-D numpy array of type uint8 containing the serialized bytes in row-major form.
+
+ Raises
+ ------
+ InferenceServerException
+ If unable to serialize the given tensor.
+ """
+
+ if input_tensor.size == 0:
+ return np.empty([0], dtype=np.object_)
+
+ # If the input is a tensor of string/bytes objects, then must flatten those into
+ # a 1-dimensional array containing the 4-byte byte size followed by the
+ # actual element bytes. All elements are concatenated together in row-major
+ # order.
+
+ if (input_tensor.dtype != np.object_) and (input_tensor.dtype.type != np.bytes_):
+ raise_error("cannot serialize bytes tensor: invalid datatype")
+
+ flattened_ls = []
+ # 'C' order is row-major.
+ for obj in np.nditer(input_tensor, flags=["refs_ok"], order="C"):
+ # If directly passing bytes to BYTES type,
+ # don't convert it to str as Python will encode the
+ # bytes which may distort the meaning
+ if input_tensor.dtype == np.object_:
+ if isinstance(obj.item(), bytes):
+ s = obj.item()
+ else:
+ s = str(obj.item()).encode("utf-8")
+ else:
+ s = obj.item()
+ flattened_ls.append(struct.pack(" np.ndarray:
+ dtype = to_np_dtype(self.datatype)
+ if dtype is None:
+ raise InvalidInput("invalid datatype in the input")
+ if self._raw_data is not None:
+ np_array = np.frombuffer(self._raw_data, dtype=dtype)
+ return np_array.reshape(self._shape)
+ else:
+ np_array = np.array(self._data, dtype=dtype)
+ return np_array.reshape(self._shape)
+
+ def set_data_from_numpy(self, input_tensor, binary_data=True):
+ """Set the tensor data from the specified numpy array for
+ input associated with this object.
+ Parameters
+ ----------
+ input_tensor : numpy array
+ The tensor data in numpy array format
+ binary_data : bool
+ Indicates whether to set data for the input in binary format
+ or explicit tensor within JSON. The default value is True,
+ which means the data will be delivered as binary data in the
+ HTTP body after the JSON object.
+ Raises
+ ------
+ InferenceServerException
+ If failed to set data for the tensor.
+ """
+ if not isinstance(input_tensor, (np.ndarray,)):
+ raise_error("input_tensor must be a numpy array")
+
+ dtype = from_np_dtype(input_tensor.dtype)
+ if self._datatype != dtype:
+ raise_error(
+ "got unexpected datatype {} from numpy array, expected {}".format(
+ dtype, self._datatype
+ )
+ )
+ valid_shape = True
+ if len(self._shape) != len(input_tensor.shape):
+ valid_shape = False
+ else:
+ for i in range(len(self._shape)):
+ if self._shape[i] != input_tensor.shape[i]:
+ valid_shape = False
+ if not valid_shape:
+ raise_error(
+ "got unexpected numpy array shape [{}], expected [{}]".format(
+ str(input_tensor.shape)[1:-1], str(self._shape)[1:-1]
+ )
+ )
+
+ if not binary_data:
+ self._parameters.pop("binary_data_size", None)
+ self._raw_data = None
+ if self._datatype == "BYTES":
+ self._data = []
+ try:
+ if input_tensor.size > 0:
+ for obj in np.nditer(
+ input_tensor, flags=["refs_ok"], order="C"
+ ):
+ # We need to convert the object to string using utf-8,
+ # if we want to use the binary_data=False. JSON requires
+ # the input to be a UTF-8 string.
+ if input_tensor.dtype == np.object_:
+ if isinstance(obj.item(), bytes):
+ self._data.append(str(obj.item(), encoding="utf-8"))
+ else:
+ self._data.append(str(obj.item()))
+ else:
+ self._data.append(str(obj.item(), encoding="utf-8"))
+ except UnicodeDecodeError:
+ raise_error(
+ f'Failed to encode "{obj.item()}" using UTF-8. Please use binary_data=True, if'
+ " you want to pass a byte array."
+ )
+ else:
+ self._data = [val.item() for val in input_tensor.flatten()]
+ else:
+ self._data = None
+ if self._datatype == "BYTES":
+ serialized_output = serialize_byte_tensor(input_tensor)
+ if serialized_output.size > 0:
+ self._raw_data = serialized_output.item()
+ else:
+ self._raw_data = b""
+ else:
+ self._raw_data = input_tensor.tobytes()
+ self._parameters["binary_data_size"] = len(self._raw_data)
+
+
+def get_content(datatype: str, data: InferTensorContents):
+ if datatype == "BOOL":
+ return list(data.bool_contents)
+ elif datatype in ["UINT8", "UINT16", "UINT32"]:
+ return list(data.uint_contents)
+ elif datatype == "UINT64":
+ return list(data.uint64_contents)
+ elif datatype in ["INT8", "INT16", "INT32"]:
+ return list(data.int_contents)
+ elif datatype == "INT64":
+ return list(data.int64_contents)
+ elif datatype == "FP32":
+ return list(data.fp32_contents)
+ elif datatype == "FP64":
+ return list(data.fp64_contents)
+ elif datatype == "BYTES":
+ return list(data.bytes_contents)
+ else:
+ raise InvalidInput("invalid content type")
+
+
+class InferRequest:
+ """InferenceRequest Model
+
+ $inference_request =
+ {
+ "id" : $string #optional,
+ "parameters" : $parameters #optional,
+ "inputs" : [ $request_input, ... ],
+ "outputs" : [ $request_output, ... ] #optional
+ }
+ """
+
+ id: Optional[str]
+ model_name: str
+ parameters: Optional[Dict]
+ inputs: List[InferInput]
+ from_grpc: bool
+
+ def __init__(
+ self,
+ model_name: str,
+ infer_inputs: List[InferInput],
+ request_id=None,
+ raw_inputs=None,
+ from_grpc=False,
+ parameters=None,
+ ):
+ if parameters is None:
+ parameters = {}
+ self.id = request_id
+ self.model_name = model_name
+ self.inputs = infer_inputs
+ self.parameters = parameters
+ self.from_grpc = from_grpc
+ if raw_inputs:
+ for i, raw_input in enumerate(raw_inputs):
+ self.inputs[i]._raw_data = raw_input
+
+ @classmethod
+ def from_grpc(cls, request: ModelInferRequest):
+ infer_inputs = [
+ InferInput(
+ name=input_tensor.name,
+ shape=list(input_tensor.shape),
+ datatype=input_tensor.datatype,
+ data=get_content(input_tensor.datatype, input_tensor.contents),
+ parameters=input_tensor.parameters,
+ )
+ for input_tensor in request.inputs
+ ]
+ return cls(
+ request_id=request.id,
+ model_name=request.model_name,
+ infer_inputs=infer_inputs,
+ raw_inputs=request.raw_input_contents,
+ from_grpc=True,
+ parameters=request.parameters,
+ )
+
+ def to_rest(self) -> Dict:
+ """Converts the InferRequest object to v2 REST InferenceRequest message"""
+ infer_inputs = []
+ for infer_input in self.inputs:
+ infer_input_dict = {
+ "name": infer_input.name,
+ "shape": infer_input.shape,
+ "datatype": infer_input.datatype,
+ }
+ if isinstance(infer_input.data, numpy.ndarray):
+ infer_input.set_data_from_numpy(infer_input.data, binary_data=False)
+ infer_input_dict["data"] = infer_input.data
+ else:
+ infer_input_dict["data"] = infer_input.data
+ infer_inputs.append(infer_input_dict)
+ return {"id": self.id, "inputs": infer_inputs}
+
+ def to_grpc(self) -> ModelInferRequest:
+ """Converts the InferRequest object to gRPC ModelInferRequest message"""
+ infer_inputs = []
+ raw_input_contents = []
+ for infer_input in self.inputs:
+ if isinstance(infer_input.data, numpy.ndarray):
+ infer_input.set_data_from_numpy(infer_input.data, binary_data=True)
+ infer_input_dict = {
+ "name": infer_input.name,
+ "shape": infer_input.shape,
+ "datatype": infer_input.datatype,
+ }
+ if infer_input._raw_data is not None:
+ raw_input_contents.append(infer_input._raw_data)
+ else:
+ if not isinstance(infer_input.data, List):
+ raise InvalidInput("input data is not a List")
+ infer_input_dict["contents"] = {}
+ data_key = GRPC_CONTENT_DATATYPE_MAPPINGS.get(
+ infer_input.datatype, None
+ )
+ if data_key is not None:
+ infer_input._data = [
+ bytes(val, "utf-8") if isinstance(val, str) else val
+ for val in infer_input.data
+ ] # str to byte conversion for grpc proto
+ infer_input_dict["contents"][data_key] = infer_input.data
+ else:
+ raise InvalidInput("invalid input datatype")
+ infer_inputs.append(infer_input_dict)
+
+ return ModelInferRequest(
+ id=self.id,
+ model_name=self.model_name,
+ inputs=infer_inputs,
+ raw_input_contents=raw_input_contents,
+ )
+
+ def as_dataframe(self) -> pd.DataFrame:
+ """
+ Decode the tensor inputs as pandas dataframe
+ """
+ dfs = []
+ for input in self.inputs:
+ input_data = input.data
+ if input.datatype == "BYTES":
+ input_data = [
+ str(val, "utf-8") if isinstance(val, bytes) else val
+ for val in input.data
+ ]
+ dfs.append(pd.DataFrame(input_data, columns=[input.name]))
+ return pd.concat(dfs, axis=1)
+
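A short sketch of how InferRequest is typically built from a numpy array and serialized for either protocol (the model and input names are placeholders):

    import numpy as np

    data = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
    infer_input = InferInput(name="inputs", shape=list(data.shape), datatype="FP32")
    infer_input.set_data_from_numpy(data, binary_data=False)

    request = InferRequest(model_name="my_model", infer_inputs=[infer_input])
    rest_payload = request.to_rest()  # JSON-serializable dict for the v2 REST API
    grpc_message = request.to_grpc()  # ModelInferRequest protobuf for the gRPC API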
+
+class InferOutput:
+ def __init__(self, name, shape, datatype, data=None, parameters=None):
+ """An object of InferOutput class is used to describe
+ input tensor for an inference request.
+ Parameters
+ ----------
+ name : str
+ The name of input whose data will be described by this object
+ shape : list
+ The shape of the associated input.
+ datatype : str
+ The datatype of the associated input.
+ data : Union[List, InferTensorContents]
+ The data of the REST/gRPC input. When data is not set, raw_data is used for gRPC for numpy array bytes.
+ parameters : dict
+ The additional server-specific parameters.
+ """
+ if parameters is None:
+ parameters = {}
+ self._name = name
+ self._shape = shape
+ self._datatype = datatype
+ self._parameters = parameters
+ self._data = data
+ self._raw_data = None
+
+ @property
+ def name(self):
+ """Get the name of input associated with this object.
+ Returns
+ -------
+ str
+ The name of input
+ """
+ return self._name
+
+ @property
+ def datatype(self):
+ """Get the datatype of input associated with this object.
+ Returns
+ -------
+ str
+ The datatype of input
+ """
+ return self._datatype
+
+ @property
+ def data(self):
+ """Get the data of InferOutput"""
+ return self._data
+
+ @property
+ def shape(self):
+ """Get the shape of input associated with this object.
+ Returns
+ -------
+ list
+ The shape of input
+ """
+ return self._shape
+
+ @property
+ def parameters(self):
+ """Get the parameters of input associated with this object.
+ Returns
+ -------
+ dict
+ The key, value pair of string and InferParameter
+ """
+ return self._parameters
+
+ def set_shape(self, shape):
+ """Set the shape of input.
+ Parameters
+ ----------
+ shape : list
+ The shape of the associated input.
+ """
+ self._shape = shape
+
+ def as_numpy(self) -> numpy.ndarray:
+ """
+ Decode the tensor data as numpy array
+ """
+ dtype = to_np_dtype(self.datatype)
+ if dtype is None:
+ raise InvalidInput("invalid datatype in the input")
+ if self._raw_data is not None:
+ np_array = np.frombuffer(self._raw_data, dtype=dtype)
+ return np_array.reshape(self._shape)
+ else:
+ np_array = np.array(self._data, dtype=dtype)
+ return np_array.reshape(self._shape)
+
+ def set_data_from_numpy(self, input_tensor, binary_data=True):
+ """Set the tensor data from the specified numpy array for
+        the output associated with this object.
+ Parameters
+ ----------
+ input_tensor : numpy array
+ The tensor data in numpy array format
+ binary_data : bool
+ Indicates whether to set data for the input in binary format
+ or explicit tensor within JSON. The default value is True,
+ which means the data will be delivered as binary data in the
+ HTTP body after the JSON object.
+ Raises
+ ------
+ InferenceServerException
+ If failed to set data for the tensor.
+ """
+ if not isinstance(input_tensor, (np.ndarray,)):
+ raise_error("input_tensor must be a numpy array")
+
+ dtype = from_np_dtype(input_tensor.dtype)
+ if self._datatype != dtype:
+ raise_error(
+ "got unexpected datatype {} from numpy array, expected {}".format(
+ dtype, self._datatype
+ )
+ )
+ valid_shape = True
+ if len(self._shape) != len(input_tensor.shape):
+ valid_shape = False
+ else:
+ for i in range(len(self._shape)):
+ if self._shape[i] != input_tensor.shape[i]:
+ valid_shape = False
+ if not valid_shape:
+ raise_error(
+ "got unexpected numpy array shape [{}], expected [{}]".format(
+ str(input_tensor.shape)[1:-1], str(self._shape)[1:-1]
+ )
+ )
+
+ if not binary_data:
+ self._parameters.pop("binary_data_size", None)
+ self._raw_data = None
+ if self._datatype == "BYTES":
+ self._data = []
+ try:
+ if input_tensor.size > 0:
+ for obj in np.nditer(
+ input_tensor, flags=["refs_ok"], order="C"
+ ):
+ # We need to convert the object to string using utf-8,
+ # if we want to use the binary_data=False. JSON requires
+ # the input to be a UTF-8 string.
+ if input_tensor.dtype == np.object_:
+ if isinstance(obj.item(), bytes):
+ self._data.append(str(obj.item(), encoding="utf-8"))
+ else:
+ self._data.append(str(obj.item()))
+ else:
+ self._data.append(str(obj.item(), encoding="utf-8"))
+ except UnicodeDecodeError:
+ raise_error(
+ f'Failed to encode "{obj.item()}" using UTF-8. Please use binary_data=True, if'
+ " you want to pass a byte array."
+ )
+ else:
+ self._data = [val.item() for val in input_tensor.flatten()]
+ else:
+ self._data = None
+ if self._datatype == "BYTES":
+ serialized_output = serialize_byte_tensor(input_tensor)
+ if serialized_output.size > 0:
+ self._raw_data = serialized_output.item()
+ else:
+ self._raw_data = b""
+ else:
+ self._raw_data = input_tensor.tobytes()
+ self._parameters["binary_data_size"] = len(self._raw_data)
+
+
+class InferResponse:
+ """InferenceResponse
+
+ $inference_response =
+ {
+ "model_name" : $string,
+ "model_version" : $string #optional,
+ "id" : $string,
+ "parameters" : $parameters #optional,
+ "outputs" : [ $response_output, ... ]
+ }
+ """
+
+ id: str
+ model_name: str
+ parameters: Optional[Dict]
+ outputs: List[InferOutput]
+ from_grpc: bool
+
+ def __init__(
+ self,
+ response_id: str,
+ model_name: str,
+ infer_outputs: List[InferOutput],
+ raw_outputs=None,
+ from_grpc=False,
+ parameters=None,
+ ):
+ if parameters is None:
+ parameters = {}
+ self.id = response_id
+ self.model_name = model_name
+ self.outputs = infer_outputs
+ self.parameters = parameters
+ self.from_grpc = from_grpc
+ if raw_outputs:
+ for i, raw_output in enumerate(raw_outputs):
+ self.outputs[i]._raw_data = raw_output
+
+ @classmethod
+ def from_grpc(cls, response: ModelInferResponse) -> "InferResponse":
+ infer_outputs = [
+ InferOutput(
+ name=output.name,
+ shape=list(output.shape),
+ datatype=output.datatype,
+ data=get_content(output.datatype, output.contents),
+ parameters=output.parameters,
+ )
+ for output in response.outputs
+ ]
+ return cls(
+ model_name=response.model_name,
+ response_id=response.id,
+ parameters=response.parameters,
+ infer_outputs=infer_outputs,
+ raw_outputs=response.raw_output_contents,
+ from_grpc=True,
+ )
+
+ @classmethod
+ def from_rest(cls, model_name: str, response: Dict) -> "InferResponse":
+ infer_outputs = [
+ InferOutput(
+ name=output["name"],
+ shape=list(output["shape"]),
+ datatype=output["datatype"],
+ data=output["data"],
+ parameters=output.get("parameters", {}),
+ )
+ for output in response["outputs"]
+ ]
+ return cls(
+ model_name=model_name,
+ response_id=response.get("id", None),
+ parameters=response.get("parameters", {}),
+ infer_outputs=infer_outputs,
+ )
+
+ def to_rest(self) -> Dict:
+ """Converts the InferResponse object to v2 REST InferenceRequest message"""
+ infer_outputs = []
+ for infer_output in self.outputs:
+ infer_output_dict = {
+ "name": infer_output.name,
+ "shape": infer_output.shape,
+ "datatype": infer_output.datatype,
+ }
+ if isinstance(infer_output.data, numpy.ndarray):
+ infer_output.set_data_from_numpy(infer_output.data, binary_data=False)
+ infer_output_dict["data"] = infer_output.data
+ elif isinstance(infer_output._raw_data, bytes):
+ infer_output_dict["data"] = infer_output.as_numpy().tolist()
+ else:
+ infer_output_dict["data"] = infer_output.data
+ infer_outputs.append(infer_output_dict)
+ res = {"id": self.id, "model_name": self.model_name, "outputs": infer_outputs}
+ return res
+
+ def to_grpc(self) -> ModelInferResponse:
+ """Converts the InferResponse object to gRPC ModelInferRequest message"""
+ infer_outputs = []
+ raw_output_contents = []
+ for infer_output in self.outputs:
+ if isinstance(infer_output.data, numpy.ndarray):
+ infer_output.set_data_from_numpy(infer_output.data, binary_data=True)
+ infer_output_dict = {
+ "name": infer_output.name,
+ "shape": infer_output.shape,
+ "datatype": infer_output.datatype,
+ }
+ if infer_output._raw_data is not None:
+ raw_output_contents.append(infer_output._raw_data)
+ else:
+ if not isinstance(infer_output.data, List):
+ raise InvalidInput("output data is not a List")
+ infer_output_dict["contents"] = {}
+ data_key = GRPC_CONTENT_DATATYPE_MAPPINGS.get(
+ infer_output.datatype, None
+ )
+ if data_key is not None:
+ infer_output._data = [
+ bytes(val, "utf-8") if isinstance(val, str) else val
+ for val in infer_output.data
+ ] # str to byte conversion for grpc proto
+ infer_output_dict["contents"][data_key] = infer_output.data
+ else:
+ raise InvalidInput("to_grpc: invalid output datatype")
+ infer_outputs.append(infer_output_dict)
+
+ return ModelInferResponse(
+ id=self.id,
+ model_name=self.model_name,
+ outputs=infer_outputs,
+ raw_output_contents=raw_output_contents,
+ )
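And the inverse direction, a sketch of parsing a v2 REST response body and decoding an output tensor (the body shown is hypothetical, not produced by this module):

    body = {
        "id": "1",
        "outputs": [
            {"name": "scores", "shape": [1, 2], "datatype": "FP32", "data": [0.1, 0.9]}
        ],
    }
    response = InferResponse.from_rest(model_name="my_model", response=body)
    scores = response.outputs[0].as_numpy()  # array([[0.1, 0.9]], dtype=float32)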
diff --git a/hsml/python/hsml/client/istio/utils/numpy_codec.py b/hsml/python/hsml/client/istio/utils/numpy_codec.py
new file mode 100644
index 000000000..3c6ecb606
--- /dev/null
+++ b/hsml/python/hsml/client/istio/utils/numpy_codec.py
@@ -0,0 +1,67 @@
+# Copyright 2021 The KServe Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This implementation has been borrowed from kserve/kserve repository
+# https://github.com/kserve/kserve/blob/release-0.11/python/kserve/kserve/utils/numpy_codec.py
+
+import numpy as np
+
+
+def to_np_dtype(dtype):
+ dtype_map = {
+ "BOOL": bool,
+ "INT8": np.int8,
+ "INT16": np.int16,
+ "INT32": np.int32,
+ "INT64": np.int64,
+ "UINT8": np.uint8,
+ "UINT16": np.uint16,
+ "UINT32": np.uint32,
+ "UINT64": np.uint64,
+ "FP16": np.float16,
+ "FP32": np.float32,
+ "FP64": np.float64,
+ "BYTES": np.object_,
+ }
+ return dtype_map.get(dtype, None)
+
+
+def from_np_dtype(np_dtype):
+ if np_dtype == bool:
+ return "BOOL"
+ elif np_dtype == np.int8:
+ return "INT8"
+ elif np_dtype == np.int16:
+ return "INT16"
+ elif np_dtype == np.int32:
+ return "INT32"
+ elif np_dtype == np.int64:
+ return "INT64"
+ elif np_dtype == np.uint8:
+ return "UINT8"
+ elif np_dtype == np.uint16:
+ return "UINT16"
+ elif np_dtype == np.uint32:
+ return "UINT32"
+ elif np_dtype == np.uint64:
+ return "UINT64"
+ elif np_dtype == np.float16:
+ return "FP16"
+ elif np_dtype == np.float32:
+ return "FP32"
+ elif np_dtype == np.float64:
+ return "FP64"
+ elif np_dtype == np.object_ or np_dtype.type == np.bytes_:
+ return "BYTES"
+ return None
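The two helpers are inverses for every datatype they know about; a quick round-trip sketch:

    import numpy as np

    arr = np.zeros((2, 2), dtype=np.float32)
    datatype = from_np_dtype(arr.dtype)    # "FP32"
    assert to_np_dtype(datatype) is np.float32
    assert to_np_dtype("UNKNOWN") is None  # unmapped names fall back to None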
diff --git a/hsml/python/hsml/connection.py b/hsml/python/hsml/connection.py
new file mode 100644
index 000000000..d9d61b9e8
--- /dev/null
+++ b/hsml/python/hsml/connection.py
@@ -0,0 +1,294 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+
+from hsml import client
+from hsml.core import model_api, model_registry_api, model_serving_api
+from hsml.decorators import connected, not_connected
+from requests.exceptions import ConnectionError
+
+
+CONNECTION_SAAS_HOSTNAME = "c.app.hopsworks.ai"
+
+HOPSWORKS_PORT_DEFAULT = 443
+HOSTNAME_VERIFICATION_DEFAULT = True
+
+
+class Connection:
+ """A Hopsworks Model Management connection object.
+
+ The connection is project specific, so you can access the project's own Model Registry and Model Serving.
+
+ This class provides convenience classmethods accessible from the `hsml`-module:
+
+ !!! example "Connection factory"
+ For convenience, `hsml` provides a factory method, accessible from the top level
+ module, so you don't have to import the `Connection` class manually:
+
+ ```python
+ import hsml
+ conn = hsml.connection()
+ ```
+
+ !!! hint "Save API Key as File"
+ To get started quickly, you can simply create a file with the previously
+ created Hopsworks API Key and place it on the environment from which you
+ wish to connect to Hopsworks.
+
+ You can then connect by simply passing the path to the key file when
+ instantiating a connection:
+
+ ```python hl_lines="6"
+ import hsml
+ conn = hsml.connection(
+ 'my_instance', # DNS of your Hopsworks instance
+ 443, # Port to reach your Hopsworks instance, defaults to 443
+ 'my_project', # Name of your Hopsworks project
+ api_key_file='modelregistry.key', # The file containing the API key generated above
+            hostname_verification=True  # Disable for self-signed certificates
+        )
+ mr = conn.get_model_registry() # Get the project's default model registry
+ ms = conn.get_model_serving() # Uses the previous model registry
+ ```
+
+ Clients in external clusters need to connect to the Hopsworks Model Registry and Model Serving using an
+ API key. The API key is generated inside the Hopsworks platform, and requires at
+ least the "project", "modelregistry", "dataset.create", "dataset.view", "dataset.delete", "serving" and "kafka" scopes
+ to be able to access a model registry and its model serving.
+ For more information, see the [integration guides](../../integrations/overview.md).
+
+ # Arguments
+ host: The hostname of the Hopsworks instance, defaults to `None`.
+ port: The port on which the Hopsworks instance can be reached,
+ defaults to `443`.
+ project: The name of the project to connect to. When running on Hopsworks, this
+            defaults to the project from which the client is run.
+ Defaults to `None`.
+ hostname_verification: Whether or not to verify Hopsworks certificate, defaults
+ to `True`.
+ trust_store_path: Path on the file system containing the Hopsworks certificates,
+ defaults to `None`.
+ api_key_file: Path to a file containing the API Key.
+        api_key_value: API Key as string. If provided, use it with care,
+            especially if the notebook or job script is accessible by multiple parties. Defaults to `None`.
+
+ # Returns
+ `Connection`. Connection handle to perform operations on a Hopsworks project.
+ """
+
+ def __init__(
+ self,
+ host: str = None,
+ port: int = HOPSWORKS_PORT_DEFAULT,
+ project: str = None,
+ hostname_verification: bool = HOSTNAME_VERIFICATION_DEFAULT,
+ trust_store_path: str = None,
+ api_key_file: str = None,
+ api_key_value: str = None,
+ ):
+ self._host = host
+ self._port = port
+ self._project = project
+ self._hostname_verification = hostname_verification
+ self._trust_store_path = trust_store_path
+ self._api_key_file = api_key_file
+ self._api_key_value = api_key_value
+ self._connected = False
+ self._model_api = model_api.ModelApi()
+ self._model_registry_api = model_registry_api.ModelRegistryApi()
+ self._model_serving_api = model_serving_api.ModelServingApi()
+
+ self.connect()
+
+ @connected
+ def get_model_registry(self, project: str = None):
+ """Get a reference to a model registry to perform operations on, defaulting to the project's default model registry.
+ Shared model registries can be retrieved by passing the `project` argument.
+
+ # Arguments
+ project: The name of the project that owns the shared model registry,
+ the model registry must be shared with the project the connection was established for, defaults to `None`.
+ # Returns
+ `ModelRegistry`. A model registry handle object to perform operations on.
+ """
+ return self._model_registry_api.get(project)
+
+ @connected
+ def get_model_serving(self):
+ """Get a reference to model serving to perform operations on. Model serving operates on top of a model registry, defaulting to the project's default model registry.
+
+ !!! example
+ ```python
+
+ import hopsworks
+
+ project = hopsworks.login()
+
+ ms = project.get_model_serving()
+ ```
+
+ # Returns
+ `ModelServing`. A model serving handle object to perform operations on.
+ """
+ return self._model_serving_api.get()
+
+ @not_connected
+ def connect(self):
+ """Instantiate the connection.
+
+ Creating a `Connection` object implicitly calls this method for you to
+ instantiate the connection. However, it is possible to close the connection
+ gracefully with the `close()` method, in order to clean up materialized
+        certificates. This might be desired when working in external environments.
+        Subsequently, you can call `connect()` again to reopen the connection.
+
+ !!! example
+ ```python
+ import hsml
+ conn = hsml.connection()
+ conn.close()
+ conn.connect()
+ ```
+ """
+ self._connected = True
+ try:
+ # init client
+ if client.hopsworks.base.Client.REST_ENDPOINT not in os.environ:
+ client.init(
+ "external",
+ self._host,
+ self._port,
+ self._project,
+ self._hostname_verification,
+ self._trust_store_path,
+ self._api_key_file,
+ self._api_key_value,
+ )
+ else:
+ client.init("internal")
+
+ self._model_api = model_api.ModelApi()
+ self._model_serving_api.load_default_configuration() # istio client, default resources,...
+ except (TypeError, ConnectionError):
+ self._connected = False
+ raise
+ print("Connected. Call `.close()` to terminate connection gracefully.")
+
+ def close(self):
+ """Close a connection gracefully.
+
+ This will clean up any materialized certificates on the local file system of
+ external environments.
+
+ Usage is recommended but optional.
+ """
+ client.stop()
+ self._model_api = None
+ self._connected = False
+ print("Connection closed.")
+
+ @classmethod
+ def connection(
+ cls,
+ host: str = None,
+ port: int = HOPSWORKS_PORT_DEFAULT,
+ project: str = None,
+ hostname_verification: bool = HOSTNAME_VERIFICATION_DEFAULT,
+ trust_store_path: str = None,
+ api_key_file: str = None,
+ api_key_value: str = None,
+ ):
+ """Connection factory method, accessible through `hsml.connection()`."""
+ return cls(
+ host,
+ port,
+ project,
+ hostname_verification,
+ trust_store_path,
+ api_key_file,
+ api_key_value,
+ )
+
+ @property
+ def host(self):
+ return self._host
+
+ @host.setter
+ @not_connected
+ def host(self, host):
+ self._host = host
+
+ @property
+ def port(self):
+ return self._port
+
+ @port.setter
+ @not_connected
+ def port(self, port):
+ self._port = port
+
+ @property
+ def project(self):
+ return self._project
+
+ @project.setter
+ @not_connected
+ def project(self, project):
+ self._project = project
+
+ @property
+ def hostname_verification(self):
+ return self._hostname_verification
+
+ @hostname_verification.setter
+ @not_connected
+ def hostname_verification(self, hostname_verification):
+ self._hostname_verification = hostname_verification
+
+ @property
+ def trust_store_path(self):
+ return self._trust_store_path
+
+ @trust_store_path.setter
+ @not_connected
+ def trust_store_path(self, trust_store_path):
+ self._trust_store_path = trust_store_path
+
+ @property
+ def api_key_file(self):
+ return self._api_key_file
+
+ @property
+ def api_key_value(self):
+ return self._api_key_value
+
+ @api_key_file.setter
+ @not_connected
+ def api_key_file(self, api_key_file):
+ self._api_key_file = api_key_file
+
+ @api_key_value.setter
+ @not_connected
+ def api_key_value(self, api_key_value):
+ self._api_key_value = api_key_value
+
+ def __enter__(self):
+ self.connect()
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.close()
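A minimal end-to-end sketch of the external-connection flow described in the class docstring, including fetching a registry shared from another project (host, key file and project names are placeholders):

    import hsml

    conn = hsml.connection(
        host="my_instance",
        project="my_project",
        api_key_file="modelregistry.key",
    )
    mr = conn.get_model_registry()                         # this project's registry
    shared_mr = conn.get_model_registry("other_project")   # registry shared with this project
    ms = conn.get_model_serving()
    conn.close()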
diff --git a/hsml/python/hsml/constants.py b/hsml/python/hsml/constants.py
new file mode 100644
index 000000000..1c4bfc4f3
--- /dev/null
+++ b/hsml/python/hsml/constants.py
@@ -0,0 +1,119 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+DEFAULT = dict() # used as default parameter for a class object
+
+
+class MODEL:
+ FRAMEWORK_TENSORFLOW = "TENSORFLOW"
+ FRAMEWORK_TORCH = "TORCH"
+ FRAMEWORK_PYTHON = "PYTHON"
+ FRAMEWORK_SKLEARN = "SKLEARN"
+
+
+class MODEL_REGISTRY:
+ HOPSFS_MOUNT_PREFIX = "/home/yarnapp/hopsfs/"
+
+
+class MODEL_SERVING:
+ MODELS_DATASET = "Models"
+
+
+class ARTIFACT_VERSION:
+ CREATE = "CREATE"
+
+
+class RESOURCES:
+ MIN_NUM_INSTANCES = 1 # disable scale-to-zero by default
+ # default values, not hard limits
+ MIN_CORES = 0.2
+ MIN_MEMORY = 32
+ MIN_GPUS = 0
+ MAX_CORES = 2
+ MAX_MEMORY = 1024
+ MAX_GPUS = 0
+
+
+class KAFKA_TOPIC:
+ NONE = "NONE"
+ CREATE = "CREATE"
+ NUM_REPLICAS = 1
+ NUM_PARTITIONS = 1
+
+
+class INFERENCE_LOGGER:
+ MODE_NONE = "NONE"
+ MODE_ALL = "ALL"
+ MODE_MODEL_INPUTS = "MODEL_INPUTS"
+ MODE_PREDICTIONS = "PREDICTIONS"
+
+
+class INFERENCE_BATCHER:
+ ENABLED = False
+
+
+class DEPLOYMENT:
+ ACTION_START = "START"
+ ACTION_STOP = "STOP"
+
+
+class PREDICTOR:
+ # model server
+ MODEL_SERVER_PYTHON = "PYTHON"
+ MODEL_SERVER_TF_SERVING = "TENSORFLOW_SERVING"
+ # serving tool
+ SERVING_TOOL_DEFAULT = "DEFAULT"
+ SERVING_TOOL_KSERVE = "KSERVE"
+
+
+class PREDICTOR_STATE:
+ # status
+ STATUS_CREATING = "Creating"
+ STATUS_CREATED = "Created"
+ STATUS_STARTING = "Starting"
+ STATUS_FAILED = "Failed"
+ STATUS_RUNNING = "Running"
+ STATUS_IDLE = "Idle"
+ STATUS_UPDATING = "Updating"
+ STATUS_STOPPING = "Stopping"
+ STATUS_STOPPED = "Stopped"
+ # condition type
+ CONDITION_TYPE_STOPPED = "STOPPED"
+ CONDITION_TYPE_SCHEDULED = "SCHEDULED"
+ CONDITION_TYPE_INITIALIZED = "INITIALIZED"
+ CONDITION_TYPE_STARTED = "STARTED"
+ CONDITION_TYPE_READY = "READY"
+
+
+class INFERENCE_ENDPOINTS:
+ # endpoint type
+ ENDPOINT_TYPE_NODE = "NODE"
+ ENDPOINT_TYPE_KUBE_CLUSTER = "KUBE_CLUSTER"
+ ENDPOINT_TYPE_LOAD_BALANCER = "LOAD_BALANCER"
+ # port name
+ PORT_NAME_HTTP = "HTTP"
+ PORT_NAME_HTTPS = "HTTPS"
+ PORT_NAME_STATUS_PORT = "STATUS"
+ PORT_NAME_TLS = "TLS"
+ # protocol
+ API_PROTOCOL_REST = "REST"
+ API_PROTOCOL_GRPC = "GRPC"
+
+
+class DEPLOYABLE_COMPONENT:
+ PREDICTOR = "predictor"
+ TRANSFORMER = "transformer"
diff --git a/hsml/python/hsml/core/__init__.py b/hsml/python/hsml/core/__init__.py
new file mode 100644
index 000000000..ff0a6f046
--- /dev/null
+++ b/hsml/python/hsml/core/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/hsml/python/hsml/core/dataset_api.py b/hsml/python/hsml/core/dataset_api.py
new file mode 100644
index 000000000..dc6301b48
--- /dev/null
+++ b/hsml/python/hsml/core/dataset_api.py
@@ -0,0 +1,582 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import copy
+import json
+import logging
+import math
+import os
+import time
+from concurrent.futures import ThreadPoolExecutor, wait
+
+from hsml import client, tag
+from hsml.client.exceptions import RestAPIError
+from tqdm.auto import tqdm
+
+
+class Chunk:
+ def __init__(self, content, number, status):
+ self.content = content
+ self.number = number
+ self.status = status
+ self.retries = 0
+
+
+class DatasetApi:
+ def __init__(self):
+        # logger used by upload() for progress-bar and upload status messages
+        self._log = logging.getLogger(__name__)
+
+ DEFAULT_UPLOAD_FLOW_CHUNK_SIZE = 10
+ DEFAULT_UPLOAD_SIMULTANEOUS_UPLOADS = 3
+ DEFAULT_UPLOAD_MAX_CHUNK_RETRIES = 1
+
+ DEFAULT_DOWNLOAD_FLOW_CHUNK_SIZE = 1_048_576
+ FLOW_PERMANENT_ERRORS = [404, 413, 415, 500, 501]
+
+ def upload(
+ self,
+ local_path: str,
+ upload_path: str,
+ overwrite: bool = False,
+ chunk_size=DEFAULT_UPLOAD_FLOW_CHUNK_SIZE,
+ simultaneous_uploads=DEFAULT_UPLOAD_SIMULTANEOUS_UPLOADS,
+ max_chunk_retries=DEFAULT_UPLOAD_MAX_CHUNK_RETRIES,
+ chunk_retry_interval=1,
+ ):
+ """Upload a file to the Hopsworks filesystem.
+
+ ```python
+
+ conn = hsml.connection(project="my-project")
+
+ dataset_api = conn.get_dataset_api()
+
+ uploaded_file_path = dataset_api.upload("my_local_file.txt", "Resources")
+
+ ```
+ # Arguments
+ local_path: local path to file to upload
+ upload_path: path to directory where to upload the file in Hopsworks Filesystem
+ overwrite: overwrite file if exists
+ chunk_size: upload chunk size in megabytes. Default 10 MB
+ simultaneous_uploads: number of simultaneous chunks to upload. Default 3
+ max_chunk_retries: maximum retry for a chunk. Default is 1
+ chunk_retry_interval: chunk retry interval in seconds. Default is 1sec
+ # Returns
+ `str`: Path to uploaded file
+ # Raises
+ `RestAPIError`: If unable to upload the file
+ """
+ # local path could be absolute or relative,
+ if not os.path.isabs(local_path) and os.path.exists(
+ os.path.join(os.getcwd(), local_path)
+ ):
+ local_path = os.path.join(os.getcwd(), local_path)
+
+ file_size = os.path.getsize(local_path)
+
+ _, file_name = os.path.split(local_path)
+
+ destination_path = upload_path + "/" + file_name
+ chunk_size_bytes = chunk_size * 1024 * 1024
+
+ if self.path_exists(destination_path):
+ if overwrite:
+ self.rm(destination_path)
+ else:
+ raise Exception(
+ "{} already exists, set overwrite=True to overwrite it".format(
+                        destination_path
+ )
+ )
+
+ num_chunks = math.ceil(file_size / chunk_size_bytes)
+
+ base_params = self._get_flow_base_params(
+ file_name, num_chunks, file_size, chunk_size_bytes
+ )
+
+ chunk_number = 1
+ with open(local_path, "rb") as f:
+ pbar = None
+ try:
+ pbar = tqdm(
+ total=file_size,
+ bar_format="{desc}: {percentage:.3f}%|{bar}| {n_fmt}/{total_fmt} elapsed<{elapsed} remaining<{remaining}",
+ desc="Uploading",
+ )
+ except Exception:
+ self._log.exception("Failed to initialize progress bar.")
+ self._log.info("Starting upload")
+ with ThreadPoolExecutor(simultaneous_uploads) as executor:
+ while True:
+ chunks = []
+ for _ in range(simultaneous_uploads):
+ chunk = f.read(chunk_size_bytes)
+ if not chunk:
+ break
+ chunks.append(Chunk(chunk, chunk_number, "pending"))
+ chunk_number += 1
+
+ if len(chunks) == 0:
+ break
+
+ # upload each chunk and update pbar
+ futures = [
+ executor.submit(
+ self._upload_chunk,
+ base_params,
+ upload_path,
+ file_name,
+ chunk,
+ pbar,
+ max_chunk_retries,
+ chunk_retry_interval,
+ )
+ for chunk in chunks
+ ]
+ # wait for all upload tasks to complete
+ _, _ = wait(futures)
+ try:
+ _ = [future.result() for future in futures]
+ except Exception as e:
+ if pbar is not None:
+ pbar.close()
+ raise e
+
+ if pbar is not None:
+ pbar.close()
+ else:
+ self._log.info("Upload finished")
+
+ return upload_path + "/" + os.path.basename(local_path)
+
+ def _upload_chunk(
+ self,
+ base_params,
+ upload_path,
+ file_name,
+ chunk: Chunk,
+ pbar,
+ max_chunk_retries,
+ chunk_retry_interval,
+ ):
+ query_params = copy.copy(base_params)
+ query_params["flowCurrentChunkSize"] = len(chunk.content)
+ query_params["flowChunkNumber"] = chunk.number
+
+ chunk.status = "uploading"
+ while True:
+ try:
+ self._upload_request(
+ query_params, upload_path, file_name, chunk.content
+ )
+ break
+ except RestAPIError as re:
+ chunk.retries += 1
+ if (
+ re.response.status_code in DatasetApi.FLOW_PERMANENT_ERRORS
+ or chunk.retries > max_chunk_retries
+ ):
+ chunk.status = "failed"
+ raise re
+ time.sleep(chunk_retry_interval)
+ continue
+
+ chunk.status = "uploaded"
+
+ if pbar is not None:
+ pbar.update(query_params["flowCurrentChunkSize"])
+
+ def _get_flow_base_params(self, file_name, num_chunks, size, chunk_size):
+ return {
+ "templateId": -1,
+ "flowChunkSize": chunk_size,
+ "flowTotalSize": size,
+ "flowIdentifier": str(size) + "_" + file_name,
+ "flowFilename": file_name,
+ "flowRelativePath": file_name,
+ "flowTotalChunks": num_chunks,
+ }
+
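For a 25 MB file uploaded with the default 10 MB chunk size, the same base parameters accompany every chunk so the server can reassemble the pieces; an illustrative call (the file name is a placeholder):

    params = DatasetApi()._get_flow_base_params(
        "my_file.bin", num_chunks=3, size=25 * 1024 * 1024, chunk_size=10 * 1024 * 1024
    )
    # params["flowTotalChunks"] == 3
    # params["flowIdentifier"] == "26214400_my_file.bin"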
+ def _upload_request(self, params, path, file_name, chunk):
+ _client = client.get_instance()
+ path_params = ["project", _client._project_id, "dataset", "upload", path]
+
+ # Flow configuration params are sent as form data
+ _client._send_request(
+ "POST", path_params, data=params, files={"file": (file_name, chunk)}
+ )
+
+ def download(self, path, local_path):
+ """Download file/directory on a path in datasets.
+ :param path: path to download
+ :type path: str
+ :param local_path: path to download in datasets
+ :type local_path: str
+ """
+
+ _client = client.get_instance()
+ path_params = [
+ "project",
+ _client._project_id,
+ "dataset",
+ "download",
+ "with_auth",
+ path,
+ ]
+ query_params = {"type": "DATASET"}
+
+ with _client._send_request(
+ "GET", path_params, query_params=query_params, stream=True
+ ) as response:
+ with open(local_path, "wb") as f:
+ downloaded = 0
+ # if not response.headers.get("Content-Length"), file is still downloading
+ for chunk in response.iter_content(
+ chunk_size=self.DEFAULT_DOWNLOAD_FLOW_CHUNK_SIZE
+ ):
+ f.write(chunk)
+ downloaded += len(chunk)
+
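A usage sketch for download, assuming a connection has already been established so client.get_instance() returns a configured client (both paths are placeholders):

    dataset_api = DatasetApi()
    dataset_api.download("Resources/my_file.txt", "/tmp/my_file.txt")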
+ def get(self, remote_path):
+ """Get metadata about a path in datasets.
+
+ :param remote_path: path to check
+ :type remote_path: str
+ :return: dataset metadata
+ :rtype: dict
+ """
+ _client = client.get_instance()
+ path_params = ["project", _client._project_id, "dataset", remote_path]
+ headers = {"content-type": "application/json"}
+ return _client._send_request("GET", path_params, headers=headers)
+
+ def path_exists(self, remote_path):
+ """Check if a path exists in datasets.
+
+ :param remote_path: path to check
+ :type remote_path: str
+ :return: boolean whether path exists
+ :rtype: bool
+ """
+ try:
+ self.get(remote_path)
+ return True
+ except RestAPIError:
+ return False
+
+ def list(self, remote_path, sort_by=None, limit=1000):
+ """List all files in a directory in datasets.
+
+ :param remote_path: path to list
+ :type remote_path: str
+ :param sort_by: sort string
+ :type sort_by: str
+ :param limit: max number of returned files
+ :type limit: int
+ """
+ _client = client.get_instance()
+ path_params = ["project", _client._project_id, "dataset", remote_path]
+ query_params = {"action": "listing", "sort_by": sort_by, "limit": limit}
+ headers = {"content-type": "application/json"}
+ return _client._send_request(
+ "GET", path_params, headers=headers, query_params=query_params
+ )
+
+ def chmod(self, remote_path, permissions):
+ """Chmod operation on file or directory in datasets.
+
+ :param remote_path: path to chmod
+ :type remote_path: str
+ :param permissions: permissions string, for example u+x
+ :type permissions: str
+ """
+ _client = client.get_instance()
+ path_params = ["project", _client._project_id, "dataset", remote_path]
+ headers = {"content-type": "application/json"}
+ query_params = {"action": "PERMISSION", "permissions": permissions}
+ return _client._send_request(
+ "PUT", path_params, headers=headers, query_params=query_params
+ )
+
+ def mkdir(self, remote_path):
+ """Path to create in datasets.
+
+ :param remote_path: path to create
+ :type remote_path: str
+ """
+ _client = client.get_instance()
+ path_params = ["project", _client._project_id, "dataset", remote_path]
+ query_params = {
+ "action": "create",
+ "searchable": "true",
+ "generate_readme": "false",
+ "type": "DATASET",
+ }
+ headers = {"content-type": "application/json"}
+ return _client._send_request(
+ "POST", path_params, headers=headers, query_params=query_params
+ )
+
+ def rm(self, remote_path):
+ """Remove a path in datasets.
+
+ :param remote_path: path to remove
+ :type remote_path: str
+ """
+ _client = client.get_instance()
+ path_params = ["project", _client._project_id, "dataset", remote_path]
+ return _client._send_request("DELETE", path_params)
+
+ def _archive(
+ self,
+ remote_path,
+ destination_path=None,
+ block=False,
+ timeout=120,
+ action="unzip",
+ ):
+ """Internal (de)compression logic.
+
+ :param remote_path: path to the file or directory to zip or unzip
+ :type remote_path: str
+ :param destination_path: path where to place the resulting zip (zip action only)
+ :type destination_path: str
+ :param block: if the operation should be blocking until complete
+ :type block: bool
+ :param timeout: timeout if the operation is blocking
+ :type timeout: int
+ :param action: zip or unzip
+ :type action: str
+ """
+
+ _client = client.get_instance()
+ path_params = ["project", _client._project_id, "dataset", remote_path]
+
+ query_params = {"action": action}
+
+ if destination_path is not None:
+ query_params["destination_path"] = destination_path
+ query_params["destination_type"] = "DATASET"
+
+ headers = {"content-type": "application/json"}
+
+ _client._send_request(
+ "POST", path_params, headers=headers, query_params=query_params
+ )
+
+ if block is True:
+ # Wait for zip file to appear. When it does, check that parent dir zipState is not set to CHOWNING
+ count = 0
+ while count < timeout:
+ if action == "zip":
+ zip_path = remote_path + ".zip"
+ # Get the status of the zipped file
+ if destination_path is None:
+ zip_exists = self.path_exists(zip_path)
+ else:
+ zip_exists = self.path_exists(
+ destination_path + "/" + os.path.split(zip_path)[1]
+ )
+ # Get the zipState of the directory being zipped
+ dir_status = self.get(remote_path)
+ zip_state = (
+ dir_status["zipState"] if "zipState" in dir_status else None
+ )
+ if zip_exists and zip_state == "NONE":
+ return
+ else:
+ time.sleep(1)
+ elif action == "unzip":
+ # Get the status of the unzipped dir
+ unzipped_dir_exists = self.path_exists(
+ remote_path[: remote_path.index(".")]
+ )
+ # Get the zipState of the zip being extracted
+ dir_status = self.get(remote_path)
+ zip_state = (
+ dir_status["zipState"] if "zipState" in dir_status else None
+ )
+ if unzipped_dir_exists and zip_state == "NONE":
+ return
+ else:
+ time.sleep(1)
+ count += 1
+ raise Exception(
+ "Timeout of {} seconds exceeded while {} {}.".format(
+ timeout, action, remote_path
+ )
+ )
+
+ def unzip(self, remote_path, block=False, timeout=120):
+ """Unzip an archive in the dataset.
+
+ :param remote_path: path to file or directory to unzip
+ :type remote_path: str
+ :param block: if the operation should be blocking until complete
+ :type block: bool
+ :param timeout: timeout if the operation is blocking
+ :type timeout: int
+ """
+ self._archive(remote_path, block=block, timeout=timeout, action="unzip")
+
+ def zip(self, remote_path, destination_path=None, block=False, timeout=120):
+ """Zip a file or directory in the dataset.
+
+ :param remote_path: path to file or directory to zip
+ :type remote_path: str
+ :param destination_path: path to upload the zip
+ :type destination_path: str
+ :param block: if the operation should be blocking until complete
+ :type block: bool
+ :param timeout: timeout if the operation is blocking
+ :type timeout: int
+ """
+ self._archive(
+ remote_path,
+ destination_path=destination_path,
+ block=block,
+ timeout=timeout,
+ action="zip",
+ )
+
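+ # Hypothetical usage sketch of the blocking (de)compression helpers above,
+ # assuming a connected client (paths are illustrative):
+ #
+ # dataset_api = DatasetApi()
+ # dataset_api.zip("Resources/my_dir", block=True, timeout=300)
+ # dataset_api.unzip("Resources/my_dir.zip", block=True, timeout=300)
+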
+ def move(self, source_path, destination_path):
+ """Move a file or directory in the dataset.
+
+ :param source_path: path to file or directory to move
+ :type source_path: str
+ :param destination_path: destination path
+ :type destination_path: str
+ """
+
+ _client = client.get_instance()
+ path_params = ["project", _client._project_id, "dataset", source_path]
+
+ query_params = {"action": "move", "destination_path": destination_path}
+ headers = {"content-type": "application/json"}
+
+ _client._send_request(
+ "POST", path_params, headers=headers, query_params=query_params
+ )
+
+ def copy(self, source_path, destination_path):
+ """Copy a file or directory in the dataset.
+
+ :param source_path: path to file or directory to copy
+ :type source_path: str
+ :param destination_path: destination path
+ :type destination_path: str
+ """
+
+ _client = client.get_instance()
+ path_params = ["project", _client._project_id, "dataset", source_path]
+
+ query_params = {"action": "copy", "destination_path": destination_path}
+ headers = {"content-type": "application/json"}
+
+ _client._send_request(
+ "POST", path_params, headers=headers, query_params=query_params
+ )
+
+ def add(self, path, name, value):
+ """Attach a name/value tag to a model.
+
+ A tag consists of a name/value pair. Tag names are unique identifiers.
+ The value of a tag can be any valid json - primitives, arrays or json objects.
+
+ :param path: path to add the tag
+ :type path: str
+ :param name: name of the tag to be added
+ :type name: str
+ :param value: value of the tag to be added
+ :type value: str or dict
+ """
+ _client = client.get_instance()
+ path_params = [
+ "project",
+ _client._project_id,
+ "dataset",
+ "tags",
+ "schema",
+ name,
+ path,
+ ]
+ headers = {"content-type": "application/json"}
+ json_value = json.dumps(value)
+ _client._send_request("PUT", path_params, headers=headers, data=json_value)
+
+ def delete(self, path, name):
+ """Delete a tag.
+
+ Tag names are unique identifiers.
+
+ :param path: path to delete the tags
+ :type path: str
+ :param name: name of the tag to be removed
+ :type name: str
+ """
+ _client = client.get_instance()
+ path_params = [
+ "project",
+ _client._project_id,
+ "dataset",
+ "tags",
+ "schema",
+ name,
+ path,
+ ]
+ _client._send_request("DELETE", path_params)
+
+ def get_tags(self, path, name: str = None):
+ """Get the tags.
+
+ Gets all tags if no tag name is specified.
+
+ :param path: path to get the tags
+ :type path: str
+ :param name: tag name
+ :type name: str
+ :return: dict of tag name/values
+ :rtype: dict
+ """
+ _client = client.get_instance()
+ path_params = [
+ "project",
+ _client._project_id,
+ "dataset",
+ "tags",
+ ]
+
+ if name is not None:
+ path_params.append("schema")
+ path_params.append(name)
+ else:
+ path_params.append("all")
+
+ path_params.append(path)
+
+ return {
+ tag._name: json.loads(tag._value)
+ for tag in tag.Tag.from_response_json(
+ _client._send_request("GET", path_params)
+ )
+ }
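+
+ # Hypothetical usage sketch of the tag helpers above, assuming a connected client
+ # and an existing tag schema named "owner" (names and values are illustrative):
+ #
+ # dataset_api = DatasetApi()
+ # dataset_api.add("Models/mnist/1", "owner", {"team": "ml-platform"})
+ # dataset_api.get_tags("Models/mnist/1")  # -> {"owner": {"team": "ml-platform"}}
+ # dataset_api.delete("Models/mnist/1", "owner")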
diff --git a/hsml/python/hsml/core/explicit_provenance.py b/hsml/python/hsml/core/explicit_provenance.py
new file mode 100644
index 000000000..ea6ce9bd8
--- /dev/null
+++ b/hsml/python/hsml/core/explicit_provenance.py
@@ -0,0 +1,368 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import logging
+from enum import Enum
+from typing import Set
+
+import humps
+
+
+_logger = logging.getLogger(__name__)
+
+
+class Artifact:
+ class MetaType(Enum):
+ DELETED = 1
+ INACCESSIBLE = 2
+ FAULTY = 3
+ NOT_SUPPORTED = 4
+
+ def __init__(
+ self,
+ model_registry_id,
+ name,
+ version,
+ type,
+ meta_type,
+ href=None,
+ exception_cause=None,
+ **kwargs,
+ ):
+ self._model_registry_id = model_registry_id
+ self._name = name
+ self._version = version
+ self._type = type
+ self._meta_type = meta_type
+ self._href = href
+ self._exception_cause = exception_cause
+
+ @property
+ def model_registry_id(self):
+ """Id of the model registry in which the artifact is located."""
+ return self._model_registry_id
+
+ @property
+ def name(self):
+ """Name of the artifact."""
+ return self._name
+
+ @property
+ def version(self):
+ """Version of the artifact"""
+ return self._version
+
+ def __str__(self):
+ # __str__ must return a string, so render the dict representation as one
+ return str(
+ {
+ "model_registry_id": self._model_registry_id,
+ "name": self._name,
+ "version": self._version,
+ }
+ )
+
+ def __repr__(self):
+ return (
+ f"Artifact({self._model_registry_id!r}, {self._name!r}, "
+ f"{self._version!r}, {self._type!r}, {self._meta_type!r}, "
+ f"{self._href!r}, {self._exception_cause!r})"
+ )
+
+ @staticmethod
+ def from_response_json(json_dict: dict):
+ link_json = humps.decamelize(json_dict)
+ href = None
+ exception_cause = None
+ if link_json.get("exception_cause") is not None:
+ meta_type = Artifact.MetaType.FAULTY
+ exception_cause = link_json.get("exception_cause")
+ elif bool(link_json["deleted"]):
+ meta_type = Artifact.MetaType.DELETED
+ elif not bool(link_json["accessible"]):
+ meta_type = Artifact.MetaType.INACCESSIBLE
+ href = link_json["artifact"]["href"]
+ else:
+ meta_type = Artifact.MetaType.NOT_SUPPORTED
+ href = link_json["artifact"]["href"]
+ return Artifact(
+ link_json["artifact"]["project"],
+ link_json["artifact"]["name"],
+ link_json["artifact"]["version"],
+ link_json["artifact_type"],
+ meta_type,
+ href=href,
+ exception_cause=exception_cause,
+ )
+
+
+class Links:
+ def __init__(self, accessible=None, deleted=None, inaccessible=None, faulty=None):
+ if accessible is None:
+ self._accessible = []
+ else:
+ self._accessible = accessible
+ if deleted is None:
+ self._deleted = []
+ else:
+ self._deleted = deleted
+ if inaccessible is None:
+ self._inaccessible = []
+ else:
+ self._inaccessible = inaccessible
+ if faulty is None:
+ self._faulty = []
+ else:
+ self._faulty = faulty
+
+ @property
+ def deleted(self):
+ """List of [Artifact objects] which contains
+ minimal information (name, version) about the entities
+ (feature views, training datasets) they represent.
+ These entities have been removed from the feature store.
+ """
+ return self._deleted
+
+ @property
+ def inaccessible(self):
+ """List of [Artifact objects] which contains
+ minimal information (name, version) about the entities
+ (feature views, training datasets) they represent.
+ These entities exist in the feature store, however the user
+ does not have access to them anymore.
+ """
+ return self._inaccessible
+
+ @property
+ def accessible(self):
+ """List of [FeatureView|TrainingDataset objects] objects
+ which are part of the provenance graph requested. These entities
+ exist in the feature store and the user has access to them.
+ """
+ return self._accessible
+
+ @property
+ def faulty(self):
+ """List of [Artifact objects] which contains
+ minimal information (name, version) about the entities
+ (feature views, training datasets) they represent.
+ These entities exist in the feature store, however they are corrupted.
+ """
+ return self._faulty
+
+ class Direction(Enum):
+ UPSTREAM = 1
+ DOWNSTREAM = 2
+
+ class Type(Enum):
+ FEATURE_VIEW = 1
+ TRAINING_DATASET = 2
+
+ def __str__(self, indent=None):
+ return json.dumps(self, cls=ProvenanceEncoder, indent=indent)
+
+ def __repr__(self):
+ return (
+ f"Links({self._accessible!r}, {self._deleted!r}"
+ f", {self._inaccessible!r}, {self._faulty!r})"
+ )
+
+ @staticmethod
+ def get_one_accessible_parent(links):
+ if links is None:
+ _logger.info("There is no parent information")
+ return
+ elif links.inaccessible or links.deleted:
+ _logger.info(
+ "The parent is deleted or inaccessible. For more details get the full provenance from `_provenance` method"
+ )
+ return None
+ elif links.accessible:
+ if len(links.accessible) > 1:
+ msg = "Backend inconsistency - provenance returned more than one parent"
+ raise Exception(msg)
+ parent = links.accessible[0]
+ if isinstance(parent, Artifact):
+ msg = "The returned object is not a valid object. For more details get the full provenance from `_provenance` method"
+ raise Exception(msg)
+ return parent
+ else:
+ _logger.info("There is no parent information")
+ return None
+
+ @staticmethod
+ def __parse_feature_views(links_json: list, artifacts: Set[str]):
+ from hsfs import feature_view
+ from hsfs.core import explicit_provenance as hsfs_explicit_provenance
+
+ links = Links()
+ for link_json in links_json:
+ if link_json["node"]["artifact_type"] in artifacts:
+ if link_json["node"].get("exception_cause") is not None:
+ links._faulty.append(
+ hsfs_explicit_provenance.Artifact.from_response_json(
+ link_json["node"]
+ )
+ )
+ elif bool(link_json["node"]["accessible"]):
+ fv = feature_view.FeatureView.from_response_json(
+ link_json["node"]["artifact"]
+ )
+ links.accessible.append(fv)
+ elif bool(link_json["node"]["deleted"]):
+ links.deleted.append(
+ hsfs_explicit_provenance.Artifact.from_response_json(
+ link_json["node"]
+ )
+ )
+ else:
+ links.inaccessible.append(
+ hsfs_explicit_provenance.Artifact.from_response_json(
+ link_json["node"]
+ )
+ )
+ else:
+ new_links = Links.__parse_feature_views(
+ link_json["upstream"], artifacts
+ )
+ links.faulty.extend(new_links.faulty)
+ links.accessible.extend(new_links.accessible)
+ links.inaccessible.extend(new_links.inaccessible)
+ links.deleted.extend(new_links.deleted)
+ return links
+
+ @staticmethod
+ def __parse_training_datasets(links_json: list, artifacts: Set[str]):
+ from hsfs import training_dataset
+ from hsfs.core import explicit_provenance as hsfs_explicit_provenance
+
+ links = Links()
+ for link_json in links_json:
+ if link_json["node"]["artifact_type"] in artifacts:
+ if link_json["node"].get("exception_cause") is not None:
+ links._faulty.append(
+ hsfs_explicit_provenance.Artifact.from_response_json(
+ link_json["node"]
+ )
+ )
+ elif bool(link_json["node"]["accessible"]):
+ td = training_dataset.TrainingDataset.from_response_json_single(
+ link_json["node"]["artifact"]
+ )
+ links.accessible.append(td)
+ elif bool(link_json["node"]["deleted"]):
+ links.deleted.append(
+ hsfs_explicit_provenance.Artifact.from_response_json(
+ link_json["node"]
+ )
+ )
+ else:
+ links.inaccessible.append(
+ hsfs_explicit_provenance.Artifact.from_response_json(
+ link_json["node"]
+ )
+ )
+ return links
+
+ @staticmethod
+ def from_response_json(json_dict: dict, direction: Direction, artifact: Type):
+ """Parse explicit links from json response. There are three types of
+ Links: UpstreamFeatureGroups, DownstreamFeatureGroups, DownstreamFeatureViews
+
+ # Arguments
+ links_json: json response from the explicit provenance endpoint
+ direction: subset of links to parse - UPSTREAM/DOWNSTREAM
+ type: subset of links to parse - FEATURE_VIEW/TRAINING_DATASET/MODEL
+
+ # Returns
+ A ProvenanceLink object for the selected parse type.
+ """
+
+ import importlib.util
+
+ if not importlib.util.find_spec("hsfs"):
+ raise ValueError(
+ "hsfs is not installed in the environment - cannot parse feature store artifacts"
+ )
+ if not importlib.util.find_spec("hopsworks"):
+ raise ValueError(
+ "hopsworks is not installed in the environment - cannot switch from hsml connection to hsfs connection"
+ )
+
+ # make sure the hsfs connection is initialized so that the feature view/training dataset can actually be used after being returned
+ import hopsworks
+
+ if not hopsworks._connected_project:
+ raise Exception(
+ "hopsworks connection is not initialized - use hopsworks.login to connect if you want the ability to use provenance with connections between hsfs and hsml"
+ )
+
+ hopsworks._connected_project.get_feature_store()
+
+ links = Links.__from_response_json_feature_store_artifacts(
+ json_dict, direction, artifact
+ )
+ return links
+
+ @staticmethod
+ def __from_response_json_feature_store_artifacts(
+ json_dict: dict, direction: Direction, artifact: Type
+ ):
+ links_json = humps.decamelize(json_dict)
+ if direction == Links.Direction.UPSTREAM:
+ if artifact == Links.Type.FEATURE_VIEW:
+ return Links.__parse_feature_views(
+ links_json["upstream"],
+ {
+ "FEATURE_VIEW",
+ },
+ )
+ elif artifact == Links.Type.TRAINING_DATASET:
+ return Links.__parse_training_datasets(
+ links_json["upstream"], {"TRAINING_DATASET"}
+ )
+ else:
+ return Links()
+
+
+class ProvenanceEncoder(json.JSONEncoder):
+ def default(self, obj):
+ if isinstance(obj, Links):
+ return {
+ "accessible": obj.accessible,
+ "inaccessible": obj.inaccessible,
+ "deleted": obj.deleted,
+ "faulty": obj.faulty,
+ }
+ else:
+ import importlib.util
+
+ if importlib.util.find_spec("hsfs"):
+ from hsfs import feature_view
+ from hsfs.core import explicit_provenance as hsfs_explicit_provenance
+
+ if isinstance(
+ obj,
+ (
+ feature_view.FeatureView,
+ hsfs_explicit_provenance.Artifact,
+ ),
+ ):
+ return {
+ "feature_store_name": obj.feature_store_name,
+ "name": obj.name,
+ "version": obj.version,
+ }
+ return json.JSONEncoder.default(self, obj)
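+
+ # Hypothetical usage sketch: this encoder is what backs `str(links)`, since
+ # Links.__str__ delegates to json.dumps with ProvenanceEncoder (the `links`
+ # value below is illustrative, e.g. obtained from a provenance API call):
+ #
+ # print(json.dumps(links, cls=ProvenanceEncoder, indent=2))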
diff --git a/hsml/python/hsml/core/model_api.py b/hsml/python/hsml/core/model_api.py
new file mode 100644
index 000000000..190a0aca8
--- /dev/null
+++ b/hsml/python/hsml/core/model_api.py
@@ -0,0 +1,301 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+from typing import Union
+
+from hsml import client, model, tag
+from hsml.core import explicit_provenance
+
+
+class ModelApi:
+ def __init__(self):
+ pass
+
+ def put(self, model_instance, query_params):
+ """Save model metadata to the model registry.
+
+ :param model_instance: metadata object of model to be saved
+ :type model_instance: Model
+ :return: updated metadata object of the model
+ :rtype: Model
+ """
+ _client = client.get_instance()
+ path_params = [
+ "project",
+ _client._project_id,
+ "modelregistries",
+ str(model_instance.model_registry_id),
+ "models",
+ model_instance.name + "_" + str(model_instance.version),
+ ]
+ headers = {"content-type": "application/json"}
+ return model_instance.update_from_response_json(
+ _client._send_request(
+ "PUT",
+ path_params,
+ headers=headers,
+ query_params=query_params,
+ data=model_instance.json(),
+ )
+ )
+
+ def get(self, name, version, model_registry_id, shared_registry_project_name=None):
+ """Get the metadata of a model with a certain name and version.
+
+ :param name: name of the model
+ :type name: str
+ :param version: version of the model
+ :type version: int
+ :return: model metadata object
+ :rtype: Model
+ """
+ _client = client.get_instance()
+ path_params = [
+ "project",
+ _client._project_id,
+ "modelregistries",
+ model_registry_id,
+ "models",
+ name + "_" + str(version),
+ ]
+ query_params = {"expand": "trainingdatasets"}
+
+ model_json = _client._send_request("GET", path_params, query_params)
+ model_meta = model.Model.from_response_json(model_json)
+
+ model_meta.shared_registry_project_name = shared_registry_project_name
+
+ return model_meta
+
+ def get_models(
+ self,
+ name,
+ model_registry_id,
+ shared_registry_project_name=None,
+ metric=None,
+ direction=None,
+ ):
+ """Get the metadata of models based on the name or optionally the best model given a metric and direction.
+
+ :param name: name of the model
+ :type name: str
+ :param metric: Name of the metric to maximize or minimize
+ :type metric: str
+ :param direction: Whether to maximize or minimize the metric, allowed values are 'max' or 'min'
+ :type direction: str
+ :return: model metadata object
+ :rtype: Model
+ """
+
+ _client = client.get_instance()
+ path_params = [
+ "project",
+ _client._project_id,
+ "modelregistries",
+ model_registry_id,
+ "models",
+ ]
+ query_params = {
+ "expand": "trainingdatasets",
+ "filter_by": ["name_eq:" + name],
+ }
+
+ if metric is not None and direction is not None:
+ if direction.lower() == "max":
+ direction = "desc"
+ elif direction.lower() == "min":
+ direction = "asc"
+
+ query_params["sort_by"] = metric + ":" + direction
+ query_params["limit"] = "1"
+
+ model_json = _client._send_request("GET", path_params, query_params)
+ models_meta = model.Model.from_response_json(model_json)
+
+ for model_meta in models_meta:
+ model_meta.shared_registry_project_name = shared_registry_project_name
+
+ return models_meta
+
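+ # Hypothetical usage sketch, assuming a connected client (names and ids are
+ # illustrative):
+ #
+ # model_api = ModelApi()
+ # # all versions of "mnist" in the given registry
+ # models = model_api.get_models("mnist", model_registry_id=119)
+ # # only the best version according to the "accuracy" metric
+ # best = model_api.get_models(
+ #     "mnist", model_registry_id=119, metric="accuracy", direction="max"
+ # )
+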
+ def delete(self, model_instance):
+ """Delete the model and metadata.
+
+ :param model_instance: metadata object of model to delete
+ :type model_instance: Model
+ """
+ _client = client.get_instance()
+ path_params = [
+ "project",
+ _client._project_id,
+ "modelregistries",
+ str(model_instance.model_registry_id),
+ "models",
+ model_instance.id,
+ ]
+ _client._send_request("DELETE", path_params)
+
+ def set_tag(self, model_instance, name, value: Union[str, dict]):
+ """Attach a name/value tag to a model.
+
+ A tag consists of a name/value pair. Tag names are unique identifiers.
+ The value of a tag can be any valid json - primitives, arrays or json objects.
+
+ :param model_instance: model instance to attach tag
+ :type model_instance: Model
+ :param name: name of the tag to be added
+ :type name: str
+ :param value: value of the tag to be added
+ :type value: str or dict
+ """
+ _client = client.get_instance()
+ path_params = [
+ "project",
+ _client._project_id,
+ "modelregistries",
+ str(model_instance.model_registry_id),
+ "models",
+ model_instance.id,
+ "tags",
+ name,
+ ]
+ headers = {"content-type": "application/json"}
+ json_value = json.dumps(value)
+ _client._send_request("PUT", path_params, headers=headers, data=json_value)
+
+ def delete_tag(self, model_instance, name):
+ """Delete a tag.
+
+ Tag names are unique identifiers.
+
+ :param model_instance: model instance to delete tag from
+ :type model_instance: Model
+ :param name: name of the tag to be removed
+ :type name: str
+ """
+ _client = client.get_instance()
+ path_params = [
+ "project",
+ _client._project_id,
+ "modelregistries",
+ str(model_instance.model_registry_id),
+ "models",
+ model_instance.id,
+ "tags",
+ name,
+ ]
+ _client._send_request("DELETE", path_params)
+
+ def get_tags(self, model_instance, name: str = None):
+ """Get the tags.
+
+ Gets all tags if no tag name is specified.
+
+ :param model_instance: model instance to get the tags from
+ :type model_instance: Model
+ :param name: tag name
+ :type name: str
+ :return: dict of tag name/values
+ :rtype: dict
+ """
+ _client = client.get_instance()
+ path_params = [
+ "project",
+ _client._project_id,
+ "modelregistries",
+ str(model_instance.model_registry_id),
+ "models",
+ model_instance.id,
+ "tags",
+ ]
+
+ if name is not None:
+ path_params.append(name)
+
+ return {
+ tag._name: json.loads(tag._value)
+ for tag in tag.Tag.from_response_json(
+ _client._send_request("GET", path_params)
+ )
+ }
+
+ def get_feature_view_provenance(self, model_instance):
+ """Get the parent feature view of this model, based on explicit provenance.
+ These feature views can be accessible, deleted or inaccessible.
+ For deleted and inaccessible feature views, only minimal information is returned.
+
+ # Arguments
+ model_instance: Metadata object of model.
+
+ # Returns
+ `ExplicitProvenance.Links`: the feature view used to generate this model
+ """
+ _client = client.get_instance()
+ path_params = [
+ "project",
+ _client._project_id,
+ "modelregistries",
+ str(model_instance.model_registry_id),
+ "models",
+ model_instance.id,
+ "provenance",
+ "links",
+ ]
+ query_params = {
+ "expand": "provenance_artifacts",
+ "upstreamLvls": 2,
+ "downstreamLvls": 0,
+ }
+ links_json = _client._send_request("GET", path_params, query_params)
+ return explicit_provenance.Links.from_response_json(
+ links_json,
+ explicit_provenance.Links.Direction.UPSTREAM,
+ explicit_provenance.Links.Type.FEATURE_VIEW,
+ )
+
+ def get_training_dataset_provenance(self, model_instance):
+ """Get the parent training dataset of this model, based on explicit provenance.
+ These training datasets can be accessible, deleted or inaccessible.
+ For deleted and inaccessible training datasets, only minimal information is returned.
+
+ # Arguments
+ model_instance: Metadata object of model.
+
+ # Returns
+ `ExplicitProvenance.Links`: the training dataset used to generate this model
+ """
+ _client = client.get_instance()
+ path_params = [
+ "project",
+ _client._project_id,
+ "modelregistries",
+ str(model_instance.model_registry_id),
+ "models",
+ model_instance.id,
+ "provenance",
+ "links",
+ ]
+ query_params = {
+ "expand": "provenance_artifacts",
+ "upstreamLvls": 1,
+ "downstreamLvls": 0,
+ }
+ links_json = _client._send_request("GET", path_params, query_params)
+ return explicit_provenance.Links.from_response_json(
+ links_json,
+ explicit_provenance.Links.Direction.UPSTREAM,
+ explicit_provenance.Links.Type.TRAINING_DATASET,
+ )
diff --git a/hsml/python/hsml/core/model_registry_api.py b/hsml/python/hsml/core/model_registry_api.py
new file mode 100644
index 000000000..693136e36
--- /dev/null
+++ b/hsml/python/hsml/core/model_registry_api.py
@@ -0,0 +1,67 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from hsml import client
+from hsml.client.exceptions import ModelRegistryException
+from hsml.core import dataset_api
+from hsml.model_registry import ModelRegistry
+
+
+class ModelRegistryApi:
+ def __init__(self):
+ self._dataset_api = dataset_api.DatasetApi()
+
+ def get(self, project=None):
+ """Get model registry for specific project.
+ :param project: project of the model registry
+ :type project: str
+ :return: the model registry metadata
+ :rtype: ModelRegistry
+ """
+ _client = client.get_instance()
+
+ model_registry_id = _client._project_id
+ shared_registry_project_name = None
+
+ # In the case of a shared model registry, validate that the given project shares its model registry with the connected project
+ if project is not None:
+ path_params = ["project", _client._project_id, "modelregistries"]
+ model_registries = _client._send_request("GET", path_params)
+ for registry in model_registries["items"]:
+ if registry["name"] == project:
+ model_registry_id = registry["id"]
+ shared_registry_project_name = project
+
+ if shared_registry_project_name is None:
+ raise ModelRegistryException(
+ "No model registry shared with current project {}, from project {}".format(
+ _client._project_name, project
+ )
+ )
+ # In the case of default model registry, validate that there is a Models dataset in the connected project
+ elif project is None and not self._dataset_api.path_exists("Models"):
+ raise ModelRegistryException(
+ "No Models dataset exists in project {}, Please enable the Serving service or create the dataset manually.".format(
+ _client._project_name
+ )
+ )
+
+ return ModelRegistry(
+ _client._project_name,
+ _client._project_id,
+ model_registry_id,
+ shared_registry_project_name=shared_registry_project_name,
+ )
diff --git a/hsml/python/hsml/core/model_serving_api.py b/hsml/python/hsml/core/model_serving_api.py
new file mode 100644
index 000000000..437327742
--- /dev/null
+++ b/hsml/python/hsml/core/model_serving_api.py
@@ -0,0 +1,148 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import socket
+
+from hsml import client
+from hsml.client.exceptions import ModelRegistryException
+from hsml.constants import INFERENCE_ENDPOINTS
+from hsml.core import dataset_api, serving_api
+from hsml.inference_endpoint import get_endpoint_by_type
+from hsml.model_serving import ModelServing
+
+
+class ModelServingApi:
+ def __init__(self):
+ self._dataset_api = dataset_api.DatasetApi()
+ self._serving_api = serving_api.ServingApi()
+
+ def get(self):
+ """Get model serving for specific project.
+ :param project: project of the model registry
+ :type project: str
+ :return: the model serving metadata
+ :rtype: ModelServing
+ """
+
+ _client = client.get_instance()
+
+ # Validate that there is a Models dataset in the connected project
+ if not self._dataset_api.path_exists("Models"):
+ raise ModelRegistryException(
+ "No Models dataset exists in project {}, Please enable the Serving service or create the dataset manually.".format(
+ _client._project_name
+ )
+ )
+
+ return ModelServing(_client._project_name, _client._project_id)
+
+ def load_default_configuration(self):
+ """Load default configuration and set istio client for model serving"""
+
+ # kserve installed
+ is_kserve_installed = self._serving_api.is_kserve_installed()
+ client.set_kserve_installed(is_kserve_installed)
+
+ # istio client
+ self._set_istio_client_if_available()
+
+ # resource limits
+ max_resources = self._serving_api.get_resource_limits()
+ client.set_serving_resource_limits(max_resources)
+
+ # num instances limits
+ num_instances_range = self._serving_api.get_num_instances_limits()
+ client.set_serving_num_instances_limits(num_instances_range)
+
+ # Knative domain
+ knative_domain = self._serving_api.get_knative_domain()
+ client.set_knative_domain(knative_domain)
+
+ def _set_istio_client_if_available(self):
+ """Set istio client if available"""
+
+ if client.is_kserve_installed():
+ # check existing istio client
+ try:
+ if client.get_istio_instance() is not None:
+ return # istio client already set
+ except Exception:
+ pass
+
+ # setup istio client
+ inference_endpoints = self._serving_api.get_inference_endpoints()
+ if client.get_client_type() == "internal":
+ # if internal, get node port
+ endpoint = get_endpoint_by_type(
+ inference_endpoints, INFERENCE_ENDPOINTS.ENDPOINT_TYPE_NODE
+ )
+ if endpoint is not None:
+ client.set_istio_client(
+ endpoint.get_any_host(),
+ endpoint.get_port(INFERENCE_ENDPOINTS.PORT_NAME_HTTP).number,
+ )
+ else:
+ raise ValueError(
+ "Istio ingress endpoint of type '"
+ + INFERENCE_ENDPOINTS.ENDPOINT_TYPE_NODE
+ + "' not found"
+ )
+ else: # if external
+ endpoint = get_endpoint_by_type(
+ inference_endpoints, INFERENCE_ENDPOINTS.ENDPOINT_TYPE_LOAD_BALANCER
+ )
+ if endpoint is not None:
+ # if load balancer (external ip) available
+ _client = client.get_instance()
+ client.set_istio_client(
+ endpoint.get_any_host(),
+ endpoint.get_port(INFERENCE_ENDPOINTS.PORT_NAME_HTTP).number,
+ _client._project_name,
+ _client._auth._token, # reuse hopsworks client token
+ )
+ return
+ # in case there's no load balancer, check if the node port is open
+ endpoint = get_endpoint_by_type(
+ inference_endpoints, INFERENCE_ENDPOINTS.ENDPOINT_TYPE_NODE
+ )
+ if endpoint is not None:
+ # if node port available
+ _client = client.get_instance()
+ host = _client.host
+ port = endpoint.get_port(INFERENCE_ENDPOINTS.PORT_NAME_HTTP).number
+ if self._is_host_port_open(host, port):
+ # and it is open
+ client.set_istio_client(
+ host,
+ port,
+ _client._project_name,
+ _client._auth._token, # reuse hopsworks client token
+ )
+ return
+ # otherwise, fallback to hopsworks client
+ print(
+ "External IP not configured for the Istio ingress gateway, the Hopsworks client will be used for model inference instead"
+ )
+
+ def _is_host_port_open(self, host, port):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.settimeout(1)
+ try:
+ result = sock.connect_ex((host, port))
+ finally:
+ sock.settimeout(None)
+ sock.close()
+ return result == 0
diff --git a/hsml/python/hsml/core/native_hdfs_api.py b/hsml/python/hsml/core/native_hdfs_api.py
new file mode 100644
index 000000000..fadd856ea
--- /dev/null
+++ b/hsml/python/hsml/core/native_hdfs_api.py
@@ -0,0 +1,59 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+try:
+ import pydoop.hdfs as hdfs
+except ImportError:
+ pass
+
+from hsml import client
+
+
+class NativeHdfsApi:
+ def __init__(self):
+ pass
+
+ def exists(self, hdfs_path):
+ return hdfs.path.exists(hdfs_path)
+
+ def project_path(self):
+ _client = client.get_instance()
+ return hdfs.path.abspath("/Projects/" + _client._project_name + "/")
+
+ def chmod(self, hdfs_path, mode):
+ return hdfs.chmod(hdfs_path, mode)
+
+ def mkdir(self, path):
+ return hdfs.mkdir(path)
+
+ def rm(self, path, recursive=True):
+ hdfs.rm(path, recursive=recursive)
+
+ def upload(self, local_path: str, remote_path: str):
+ # copy from local fs to hdfs
+ hdfs.put(local_path, remote_path)
+
+ def download(self, remote_path: str, local_path: str):
+ # copy from hdfs to local fs
+ hdfs.get(remote_path, local_path)
+
+ def copy(self, source_path: str, destination_path: str):
+ # both paths are hdfs paths
+ hdfs.cp(source_path, destination_path)
+
+ def move(self, source_path: str, destination_path: str):
+ # both paths are hdfs paths
+ hdfs.rename(source_path, destination_path)
diff --git a/hsml/python/hsml/core/serving_api.py b/hsml/python/hsml/core/serving_api.py
new file mode 100644
index 000000000..45f4e7fcc
--- /dev/null
+++ b/hsml/python/hsml/core/serving_api.py
@@ -0,0 +1,417 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+from typing import Dict, List, Union
+
+from hsml import (
+ client,
+ deployable_component_logs,
+ deployment,
+ inference_endpoint,
+ predictor_state,
+)
+from hsml.client.istio.utils.infer_type import (
+ InferInput,
+ InferOutput,
+ InferRequest,
+)
+from hsml.constants import ARTIFACT_VERSION
+from hsml.constants import INFERENCE_ENDPOINTS as IE
+
+
+class ServingApi:
+ def __init__(self):
+ pass
+
+ def get_by_id(self, id: int):
+ """Get the metadata of a deployment with a certain id.
+
+ :param id: id of the deployment
+ :type id: int
+ :return: deployment metadata object
+ :rtype: Deployment
+ """
+
+ _client = client.get_instance()
+ path_params = [
+ "project",
+ _client._project_id,
+ "serving",
+ str(id),
+ ]
+ deployment_json = _client._send_request("GET", path_params)
+ deployment_instance = deployment.Deployment.from_response_json(deployment_json)
+ deployment_instance.model_registry_id = _client._project_id
+ return deployment_instance
+
+ def get(self, name: str):
+ """Get the metadata of a deployment with a certain name.
+
+ :param name: name of the deployment
+ :type name: str
+ :return: deployment metadata object
+ :rtype: Deployment
+ """
+
+ _client = client.get_instance()
+ path_params = ["project", _client._project_id, "serving"]
+ query_params = {"name": name}
+ deployment_json = _client._send_request(
+ "GET", path_params, query_params=query_params
+ )
+ deployment_instance = deployment.Deployment.from_response_json(deployment_json)
+ deployment_instance.model_registry_id = _client._project_id
+ return deployment_instance
+
+ def get_all(self, model_name: str = None, status: str = None):
+ """Get the metadata of all deployments.
+
+ :param model_name: name of the model to filter deployments by (optional)
+ :type model_name: str
+ :param status: deployment status to filter by (optional)
+ :type status: str
+ :return: deployment metadata objects
+ :rtype: List[Deployment]
+ """
+
+ _client = client.get_instance()
+ path_params = ["project", _client._project_id, "serving"]
+ query_params = {
+ "model": model_name,
+ "status": status.capitalize() if status is not None else None,
+ }
+ deployments_json = _client._send_request(
+ "GET", path_params, query_params=query_params
+ )
+ deployment_instances = deployment.Deployment.from_response_json(
+ deployments_json
+ )
+ for deployment_instance in deployment_instances:
+ deployment_instance.model_registry_id = _client._project_id
+ return deployment_instances
+
+ def get_inference_endpoints(self):
+ """Get inference endpoints.
+
+ :return: inference endpoints for the current project.
+ :rtype: List[InferenceEndpoint]
+ """
+
+ _client = client.get_instance()
+ path_params = ["project", _client._project_id, "inference", "endpoints"]
+ endpoints_json = _client._send_request("GET", path_params)
+ return inference_endpoint.InferenceEndpoint.from_response_json(endpoints_json)
+
+ def put(self, deployment_instance):
+ """Save deployment metadata to model serving.
+
+ :param deployment_instance: metadata object of deployment to be saved
+ :type deployment_instance: Deployment
+ :return: updated metadata object of the deployment
+ :rtype: Deployment
+ """
+
+ _client = client.get_instance()
+ path_params = ["project", _client._project_id, "serving"]
+ headers = {"content-type": "application/json"}
+
+ if deployment_instance.artifact_version == ARTIFACT_VERSION.CREATE:
+ deployment_instance.artifact_version = -1
+
+ deployment_instance = deployment_instance.update_from_response_json(
+ _client._send_request(
+ "PUT",
+ path_params,
+ headers=headers,
+ data=deployment_instance.json(),
+ )
+ )
+ deployment_instance.model_registry_id = _client._project_id
+ return deployment_instance
+
+ def post(self, deployment_instance, action: str):
+ """Perform an action on the deployment
+
+ :param action: action to perform on the deployment (i.e., START or STOP)
+ :type action: str
+ """
+
+ _client = client.get_instance()
+ path_params = [
+ "project",
+ _client._project_id,
+ "serving",
+ deployment_instance.id,
+ ]
+ query_params = {"action": action}
+ return _client._send_request("POST", path_params, query_params=query_params)
+
+ def delete(self, deployment_instance):
+ """Delete the deployment and metadata.
+
+ :param deployment_instance: metadata object of the deployment to delete
+ :type deployment_instance: Deployment
+ """
+
+ _client = client.get_instance()
+ path_params = [
+ "project",
+ _client._project_id,
+ "serving",
+ deployment_instance.id,
+ ]
+ return _client._send_request("DELETE", path_params)
+
+ def get_state(self, deployment_instance):
+ """Get the state of a given deployment
+
+ :param deployment_instance: metadata object of the deployment to get state of
+ :type deployment_instance: Deployment
+ :return: predictor state
+ :rtype: PredictorState
+ """
+
+ _client = client.get_instance()
+ path_params = [
+ "project",
+ _client._project_id,
+ "serving",
+ str(deployment_instance.id),
+ ]
+ deployment_json = _client._send_request("GET", path_params)
+ return predictor_state.PredictorState.from_response_json(deployment_json)
+
+ def reset_changes(self, deployment_instance):
+ """Reset a given deployment to the original values in the Hopsworks instance
+
+ :param deployment_instance: metadata object of the deployment to reset
+ :type deployment_instance: Deployment
+ :return: deployment with reset values
+ :rtype: Deployment
+ """
+
+ _client = client.get_instance()
+ path_params = ["project", _client._project_id, "serving"]
+ query_params = {"name": deployment_instance.name}
+ deployment_json = _client._send_request(
+ "GET", path_params, query_params=query_params
+ )
+ deployment_aux = deployment_instance.update_from_response_json(deployment_json)
+ # TODO: remove when model_registry_id is added properly to deployments in backend
+ deployment_aux.model_registry_id = _client._project_id
+ return deployment_aux
+
+ def send_inference_request(
+ self,
+ deployment_instance,
+ data: Union[Dict, List[InferInput]],
+ through_hopsworks: bool = False,
+ ) -> Union[Dict, List[InferOutput]]:
+ """Send inference requests to a deployment with a certain id
+
+ :param deployment_instance: metadata object of the deployment to be used for the prediction
+ :type deployment_instance: Deployment
+ :param data: payload of the inference request
+ :type data: Union[Dict, List[InferInput]]
+ :param through_hopsworks: whether to send the inference request through the Hopsworks REST API or not
+ :type through_hopsworks: bool
+ :return: inference response
+ :rtype: Union[Dict, List[InferOutput]]
+ """
+ if deployment_instance.api_protocol == IE.API_PROTOCOL_REST:
+ # REST protocol, use hopsworks or istio client
+ return self._send_inference_request_via_rest_protocol(
+ deployment_instance, data, through_hopsworks
+ )
+ else:
+ # gRPC protocol, use the deployment grpc channel
+ return self._send_inference_request_via_grpc_protocol(
+ deployment_instance, data
+ )
+
+ def _send_inference_request_via_rest_protocol(
+ self,
+ deployment_instance,
+ data: Dict,
+ through_hopsworks: bool = False,
+ ) -> Dict:
+ headers = {"content-type": "application/json"}
+ if through_hopsworks:
+ # use Hopsworks client
+ _client = client.get_instance()
+ path_params = self._get_hopsworks_inference_path(
+ _client._project_id, deployment_instance
+ )
+ else:
+ _client = client.get_istio_instance()
+ if _client is not None:
+ # use istio client
+ path_params = self._get_istio_inference_path(deployment_instance)
+ # - add host header
+ headers["host"] = self._get_inference_request_host_header(
+ _client._project_name,
+ deployment_instance.name,
+ client.get_knative_domain(),
+ )
+ else:
+ # fallback to Hopsworks client
+ _client = client.get_instance()
+ path_params = self._get_hopsworks_inference_path(
+ _client._project_id, deployment_instance
+ )
+
+ # send inference request
+ return _client._send_request(
+ "POST", path_params, headers=headers, data=json.dumps(data)
+ )
+
+ def _send_inference_request_via_grpc_protocol(
+ self, deployment_instance, data: List[InferInput]
+ ) -> List[InferOutput]:
+ # get grpc channel
+ if deployment_instance._grpc_channel is None:
+ # The gRPC channel is lazily initialized. The first call to deployment.predict() will initialize
+ # the channel, which will be reused in all following calls on the same deployment object.
+ # The gRPC channel is freed when calling deployment.stop()
+ print("Initializing gRPC channel...")
+ deployment_instance._grpc_channel = self._create_grpc_channel(
+ deployment_instance.name
+ )
+ # build an infer request
+ request = InferRequest(
+ infer_inputs=data,
+ model_name=deployment_instance.name,
+ )
+
+ # send infer request
+ infer_response = deployment_instance._grpc_channel.infer(
+ infer_request=request, headers=None
+ )
+
+ # extract infer outputs
+ return infer_response.outputs
+
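+ # Hypothetical usage sketch of the gRPC path, assuming InferInput accepts
+ # name/shape/datatype/data (values are illustrative, not a verified signature):
+ #
+ # inputs = [InferInput(name="input-0", shape=[1, 4], datatype="FP32",
+ #                      data=[[1.0, 2.0, 3.0, 4.0]])]
+ # outputs = serving_api.send_inference_request(my_deployment, inputs)
+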
+ def _create_grpc_channel(self, deployment_name: str):
+ _client = client.get_istio_instance()
+ service_hostname = self._get_inference_request_host_header(
+ _client._project_name,
+ deployment_name,
+ client.get_knative_domain(),
+ )
+ return _client._create_grpc_channel(service_hostname)
+
+ def is_kserve_installed(self):
+ """Check if kserve is installed
+
+ :return: whether kserve is installed
+ :rtype: bool
+ """
+
+ _client = client.get_instance()
+ path_params = ["variables", "kube_kserve_installed"]
+ kserve_installed = _client._send_request("GET", path_params)
+ return (
+ "successMessage" in kserve_installed
+ and kserve_installed["successMessage"] == "true"
+ )
+
+ def get_resource_limits(self):
+ """Get resource limits for model serving"""
+
+ _client = client.get_instance()
+
+ path_params = ["variables", "kube_serving_max_cores_allocation"]
+ max_cores = _client._send_request("GET", path_params)
+
+ path_params = ["variables", "kube_serving_max_memory_allocation"]
+ max_memory = _client._send_request("GET", path_params)
+
+ path_params = ["variables", "kube_serving_max_gpus_allocation"]
+ max_gpus = _client._send_request("GET", path_params)
+
+ return {
+ "cores": float(max_cores["successMessage"]),
+ "memory": int(max_memory["successMessage"]),
+ "gpus": int(max_gpus["successMessage"]),
+ }
+
+ def get_num_instances_limits(self):
+ """Get number of instances limits for model serving"""
+
+ _client = client.get_instance()
+
+ path_params = ["variables", "kube_serving_min_num_instances"]
+ min_instances = _client._send_request("GET", path_params)
+
+ path_params = ["variables", "kube_serving_max_num_instances"]
+ max_instances = _client._send_request("GET", path_params)
+
+ return [
+ int(min_instances["successMessage"]),
+ int(max_instances["successMessage"]),
+ ]
+
+ def get_knative_domain(self):
+ """Get the domain used by knative"""
+
+ _client = client.get_instance()
+
+ path_params = ["variables", "kube_knative_domain_name"]
+ domain = _client._send_request("GET", path_params)
+
+ return domain["successMessage"]
+
+ def get_logs(self, deployment_instance, component, tail):
+ """Get the logs of a deployment
+
+ :param deployment_instance: metadata object of the deployment to get logs from
+ :type deployment_instance: Deployment
+ :param component: deployment component (e.g., predictor or transformer)
+ :type component: str
+ :param tail: number of tailing lines to retrieve
+ :type tail: int
+ :return: deployment logs
+ :rtype: DeployableComponentLogs
+ """
+
+ _client = client.get_instance()
+ path_params = [
+ "project",
+ _client._project_id,
+ "serving",
+ deployment_instance.id,
+ "logs",
+ ]
+ query_params = {"component": component, "tail": tail}
+ return deployable_component_logs.DeployableComponentLogs.from_response_json(
+ _client._send_request("GET", path_params, query_params=query_params)
+ )
+
+ def _get_inference_request_host_header(
+ self, project_name: str, deployment_name: str, domain: str
+ ):
+ return "{}.{}.{}".format(
+ deployment_name, project_name.replace("_", "-"), domain
+ ).lower()
+
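+ # For illustration: a project "my_project", a deployment "fraudmodel" and a
+ # knative domain "example.com" would yield the host header
+ # "fraudmodel.my-project.example.com" (underscores replaced, lower-cased).
+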
+ def _get_hopsworks_inference_path(self, project_id: int, deployment_instance):
+ return [
+ "project",
+ project_id,
+ "inference",
+ "models",
+ deployment_instance.name + ":predict",
+ ]
+
+ def _get_istio_inference_path(self, deployment_instance):
+ return ["v1", "models", deployment_instance.name + ":predict"]
diff --git a/hsml/python/hsml/decorators.py b/hsml/python/hsml/decorators.py
new file mode 100644
index 000000000..826fd5aa2
--- /dev/null
+++ b/hsml/python/hsml/decorators.py
@@ -0,0 +1,55 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import functools
+
+
+def not_connected(fn):
+ @functools.wraps(fn)
+ def if_not_connected(inst, *args, **kwargs):
+ if inst._connected:
+ raise HopsworksConnectionError
+ return fn(inst, *args, **kwargs)
+
+ return if_not_connected
+
+
+def connected(fn):
+ @functools.wraps(fn)
+ def if_connected(inst, *args, **kwargs):
+ if not inst._connected:
+ raise NoHopsworksConnectionError
+ return fn(inst, *args, **kwargs)
+
+ return if_connected
+
+
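+# A minimal sketch of how these decorators are intended to be used on a
+# connection-like class (the class below is illustrative, not part of this module):
+#
+# class ExampleConnection:
+#     def __init__(self):
+#         self._connected = False
+#
+#     @not_connected
+#     def connect(self):
+#         self._connected = True
+#
+#     @connected
+#     def get_model_registry(self):
+#         return "registry"
+#
+# Calling get_model_registry() before connect() raises NoHopsworksConnectionError,
+# while calling connect() a second time raises HopsworksConnectionError.
+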
+class HopsworksConnectionError(Exception):
+ """Thrown when attempted to change connection attributes while connected."""
+
+ def __init__(self):
+ super().__init__(
+ "Connection is currently in use. Needs to be closed for modification."
+ )
+
+
+class NoHopsworksConnectionError(Exception):
+ """Thrown when attempted to perform operation on connection while not connected."""
+
+ def __init__(self):
+ super().__init__(
+ "Connection is not active. Needs to be connected for model registry operations."
+ )
diff --git a/hsml/python/hsml/deployable_component.py b/hsml/python/hsml/deployable_component.py
new file mode 100644
index 000000000..adabff5b8
--- /dev/null
+++ b/hsml/python/hsml/deployable_component.py
@@ -0,0 +1,92 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from abc import ABC, abstractmethod
+from typing import Optional, Union
+
+import humps
+from hsml import util
+from hsml.inference_batcher import InferenceBatcher
+from hsml.resources import Resources
+
+
+class DeployableComponent(ABC):
+ """Configuration of a deployable component (predictor or transformer)."""
+
+ def __init__(
+ self,
+ script_file: Optional[str] = None,
+ resources: Optional[Resources] = None,
+ inference_batcher: Optional[Union[InferenceBatcher, dict]] = None,
+ **kwargs,
+ ):
+ self._script_file = script_file
+ self._resources = resources
+ self._inference_batcher = (
+ util.get_obj_from_json(inference_batcher, InferenceBatcher)
+ or InferenceBatcher()
+ )
+
+ @classmethod
+ @abstractmethod
+ def from_json(cls, json_decamelized):
+ "To be implemented by the component type"
+ pass
+
+ @classmethod
+ def from_response_json(cls, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ return cls.from_json(json_decamelized)
+
+ def json(self):
+ return json.dumps(self, cls=util.MLEncoder)
+
+ @abstractmethod
+ def update_from_response_json(self, json_dict):
+ "To be implemented by the component type"
+ pass
+
+ @abstractmethod
+ def to_dict(self):
+ "To be implemented by the component type"
+ pass
+
+ @property
+ def script_file(self):
+ """Script file ran by the deployment component (i.e., predictor or transformer)."""
+ return self._script_file
+
+ @script_file.setter
+ def script_file(self, script_file: str):
+ self._script_file = script_file
+
+ @property
+ def resources(self):
+ """Resource configuration for the deployment component (i.e., predictor or transformer)."""
+ return self._resources
+
+ @resources.setter
+ def resources(self, resources: Resources):
+ self._resources = resources
+
+ @property
+ def inference_batcher(self):
+ """Configuration of the inference batcher attached to the deployment component (i.e., predictor or transformer)."""
+ return self._inference_batcher
+
+ @inference_batcher.setter
+ def inference_batcher(self, inference_batcher: InferenceBatcher):
+ self._inference_batcher = inference_batcher
diff --git a/hsml/python/hsml/deployable_component_logs.py b/hsml/python/hsml/deployable_component_logs.py
new file mode 100644
index 000000000..7035030a4
--- /dev/null
+++ b/hsml/python/hsml/deployable_component_logs.py
@@ -0,0 +1,91 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from datetime import datetime
+
+import humps
+from hsml import util
+
+
+class DeployableComponentLogs:
+ """Server logs of a deployable component (predictor or transformer).
+
+ # Arguments
+ instance_name: name of the deployment instance the logs belong to.
+ content: actual content of the server logs.
+
+ # Returns
+ `DeployableComponentLogs`. Server logs of a deployable component.
+ """
+
+ def __init__(self, instance_name: str, content: str, **kwargs):
+ self._instance_name = instance_name
+ self._content = content
+ self._created_at = datetime.now()
+
+ @classmethod
+ def from_response_json(cls, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ if len(json_decamelized) == 0:
+ return []
+ return [cls.from_json(logs) for logs in json_decamelized]
+
+ @classmethod
+ def from_json(cls, json_decamelized):
+ return DeployableComponentLogs(*cls.extract_fields_from_json(json_decamelized))
+
+ @classmethod
+ def extract_fields_from_json(cls, json_decamelized):
+ instance_name = util.extract_field_from_json(json_decamelized, "instance_name")
+ content = util.extract_field_from_json(json_decamelized, "content")
+ return instance_name, content
+
+ def to_dict(self):
+ return {"instance_name": self._instance_name, "content": self._content}
+
+ @property
+ def instance_name(self):
+ """Name of the deployment instance containing these server logs."""
+ return self._instance_name
+
+ @property
+ def content(self):
+ """Content of the server logs of the current deployment instance."""
+ return self._content
+
+ @property
+ def created_at(self):
+ """Datetime when the current server logs chunk was retrieved."""
+ return self._created_at
+
+ @property
+ def component(self):
+ """Component of the deployment containing these server logs."""
+ return self._component
+
+ @component.setter
+ def component(self, component: str):
+ self._component = component
+
+ @property
+ def tail(self):
+ """Number of lines of server logs."""
+ return self._tail
+
+ @tail.setter
+ def tail(self, tail: int):
+ self._tail = tail
+
+ def __repr__(self):
+ return f"DeployableComponentLogs(instance_name: {self._instance_name!r}, date: {self._created_at!r}) \n{self._content!s}"
diff --git a/hsml/python/hsml/deployment.py b/hsml/python/hsml/deployment.py
new file mode 100644
index 000000000..473c3b9c2
--- /dev/null
+++ b/hsml/python/hsml/deployment.py
@@ -0,0 +1,470 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, List, Optional, Union
+
+from hsml import client, util
+from hsml import predictor as predictor_mod
+from hsml.client.exceptions import ModelServingException
+from hsml.client.istio.utils.infer_type import InferInput
+from hsml.constants import DEPLOYABLE_COMPONENT, PREDICTOR_STATE
+from hsml.core import model_api, serving_api
+from hsml.engine import serving_engine
+from hsml.inference_batcher import InferenceBatcher
+from hsml.inference_logger import InferenceLogger
+from hsml.predictor_state import PredictorState
+from hsml.resources import Resources
+from hsml.transformer import Transformer
+
+
+class Deployment:
+ """Metadata object representing a deployment in Model Serving."""
+
+ def __init__(
+ self,
+ predictor,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ **kwargs,
+ ):
+ self._predictor = predictor
+ self._description = description
+
+ if self._predictor is None:
+ raise ModelServingException("A predictor is required")
+ elif not isinstance(self._predictor, predictor_mod.Predictor):
+ raise ValueError(
+ "The predictor provided is not an instance of the Predictor class"
+ )
+
+ if name is not None:
+ self._predictor.name = name
+
+ if self._description is None:
+ self._description = self._predictor.description
+ else:
+ self._description = self._predictor.description = description
+
+ self._serving_api = serving_api.ServingApi()
+ self._serving_engine = serving_engine.ServingEngine()
+ self._model_api = model_api.ModelApi()
+ self._grpc_channel = None
+ self._model_registry_id = None
+
+ def save(self, await_update: Optional[int] = 60):
+ """Persist this deployment including the predictor and metadata to Model Serving.
+
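+ A minimal sketch of a typical call (the project handle and the deployment name "my_deployment" are illustrative):
+
+ !!! example
+ ```python
+ # login into Hopsworks using hopsworks.login()
+
+ # get Hopsworks Model Serving handle
+ ms = project.get_model_serving()
+
+ # retrieve the deployment, change its metadata and persist it
+ my_deployment = ms.get_deployment("my_deployment")
+ my_deployment.description = "updated description"
+ my_deployment.save()
+ ```
+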
+ # Arguments
+ await_update: If the deployment is running, awaiting time (seconds) for the running instances to be updated.
+ If the running instances are not updated within this timespan, the call to this method returns while
+ the update continues in the background.
+ """
+
+ self._serving_engine.save(self, await_update)
+
+ def start(self, await_running: Optional[int] = 60):
+ """Start the deployment
+
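+ A minimal sketch (assumes `my_deployment` is an existing deployment handle, e.g., retrieved with `ms.get_deployment()`):
+
+ !!! example
+ ```python
+ # start the deployment and wait up to 2 minutes for it to be running
+ my_deployment.start(await_running=120)
+ ```
+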
+ # Arguments
+ await_running: Awaiting time (seconds) for the deployment to start.
+ If the deployment has not started within this timespan, the call to this method returns while
+ it deploys in the background.
+ """
+
+ self._serving_engine.start(self, await_status=await_running)
+
+ def stop(self, await_stopped: Optional[int] = 60):
+ """Stop the deployment
+
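+ A minimal sketch (assumes `my_deployment` is an existing deployment handle):
+
+ !!! example
+ ```python
+ # stop the deployment and wait up to 3 minutes for the instances to terminate
+ my_deployment.stop(await_stopped=180)
+ ```
+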
+ # Arguments
+ await_stopped: Awaiting time (seconds) for the deployment to stop.
+ If the deployment has not stopped within this timespan, the call to this method returns while
+ it continues stopping in the background.
+ """
+
+ self._serving_engine.stop(self, await_status=await_stopped)
+
+ def delete(self, force=False):
+ """Delete the deployment
+
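+ A minimal sketch (assumes `my_deployment` is an existing deployment handle):
+
+ !!! example
+ ```python
+ # stop the deployment (if running) and delete it in one call
+ my_deployment.delete(force=True)
+ ```
+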
+ # Arguments
+ force: Force the deletion of the deployment.
+ If the deployment is running, it will be stopped and deleted automatically.
+ !!! warn A call to this method does not ask for a second confirmation.
+ """
+
+ self._serving_engine.delete(self, force)
+
+ def get_state(self) -> PredictorState:
+ """Get the current state of the deployment
+
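+ A minimal sketch (assumes `my_deployment` is an existing deployment handle):
+
+ !!! example
+ ```python
+ state = my_deployment.get_state()
+ print(state.status) # e.g., running, stopped or failed
+ ```
+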
+ # Returns
+ `PredictorState`. The state of the deployment.
+ """
+
+ return self._serving_engine.get_state(self)
+
+ def is_created(self) -> bool:
+ """Check whether the deployment is created.
+
+ # Returns
+ `bool`. Whether the deployment is created or not.
+ """
+
+ return (
+ self._serving_engine.get_state(self).status
+ != PREDICTOR_STATE.STATUS_CREATING
+ )
+
+ def is_running(self, or_idle=True, or_updating=True) -> bool:
+ """Check whether the deployment is ready to handle inference requests
+
+ # Arguments
+ or_idle: Whether the idle state is considered as running (default is True)
+ or_updating: Whether the updating state is considered as running (default is True)
+
+ # Returns
+ `bool`. Whether the deployment is ready or not.
+ """
+
+ status = self._serving_engine.get_state(self).status
+ return (
+ status == PREDICTOR_STATE.STATUS_RUNNING
+ or (or_idle and status == PREDICTOR_STATE.STATUS_IDLE)
+ or (or_updating and status == PREDICTOR_STATE.STATUS_UPDATING)
+ )
+
+ def is_stopped(self, or_created=True) -> bool:
+ """Check whether the deployment is stopped
+
+ # Arguments
+ or_created: Whether the creating and created states are considered as stopped (default is True)
+
+ # Returns
+ `bool`. Whether the deployment is stopped or not.
+ """
+
+ status = self._serving_engine.get_state(self).status
+ return status == PREDICTOR_STATE.STATUS_STOPPED or (
+ or_created
+ and (
+ status == PREDICTOR_STATE.STATUS_CREATING
+ or status == PREDICTOR_STATE.STATUS_CREATED
+ )
+ )
+
+ def predict(
+ self,
+ data: Union[Dict, InferInput] = None,
+ inputs: Union[List, Dict] = None,
+ ):
+ """Send inference requests to the deployment.
+ Either the `data` or the `inputs` parameter must be set. If both are provided, an exception is raised.
+
+ !!! example
+ ```python
+ # login into Hopsworks using hopsworks.login()
+
+ # get Hopsworks Model Serving handle
+ ms = project.get_model_serving()
+
+ # retrieve deployment by name
+ my_deployment = ms.get_deployment("my_deployment")
+
+ # (optional) retrieve model input example
+ my_model = project.get_model_registry() \
+ .get_model(my_deployment.model_name, my_deployment.model_version)
+
+ # make predictions using model inputs (single or batch)
+ predictions = my_deployment.predict(inputs=my_model.input_example)
+
+ # or using more sophisticated inference request payloads
+ data = { "instances": [ my_model.input_example ], "key2": "value2" }
+ predictions = my_deployment.predict(data)
+ ```
+
+ # Arguments
+ data: Payload dictionary for the inference request including the model input(s)
+ inputs: Model inputs used in the inference requests
+
+ # Returns
+ `dict`. Inference response.
+ """
+
+ return self._serving_engine.predict(self, data, inputs)
+
+ def get_model(self):
+ """Retrieve the metadata object for the model being used by this deployment"""
+ return self._model_api.get(
+ self.model_name, self.model_version, self.model_registry_id
+ )
+
+ def download_artifact(self):
+ """Download the model artifact served by the deployment"""
+
+ return self._serving_engine.download_artifact(self)
+
+ def get_logs(self, component="predictor", tail=10):
+ """Prints the deployment logs of the predictor or transformer.
+
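+ A minimal sketch (assumes `my_deployment` is an existing deployment handle):
+
+ !!! example
+ ```python
+ # print the last 50 lines of the transformer logs
+ my_deployment.get_logs(component="transformer", tail=50)
+ ```
+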
+ # Arguments
+ component: Deployment component to get the logs from (e.g., predictor or transformer)
+ tail: Number of most recent lines to retrieve from the logs.
+ """
+
+ # validate component
+ components = list(util.get_members(DEPLOYABLE_COMPONENT))
+ if component not in components:
+ raise ValueError(
+ "Component '{}' is not valid. Possible values are '{}'".format(
+ component, ", ".join(components)
+ )
+ )
+
+ logs = self._serving_engine.get_logs(self, component, tail)
+ if logs is not None:
+ for log in logs:
+ print(log, end="\n\n")
+
+ def get_url(self):
+ """Get url to the deployment in Hopsworks"""
+
+ path = (
+ "/p/"
+ + str(client.get_instance()._project_id)
+ + "/deployments/"
+ + str(self.id)
+ )
+ return util.get_hostname_replaced_url(path)
+
+ def describe(self):
+ """Print a description of the deployment"""
+
+ util.pretty_print(self)
+
+ @classmethod
+ def from_response_json(cls, json_dict):
+ predictors = predictor_mod.Predictor.from_response_json(json_dict)
+ if isinstance(predictors, list):
+ return [
+ cls.from_predictor(predictor_instance)
+ for predictor_instance in predictors
+ ]
+ else:
+ return cls.from_predictor(predictors)
+
+ @classmethod
+ def from_predictor(cls, predictor_instance):
+ return Deployment(
+ predictor=predictor_instance,
+ name=predictor_instance._name,
+ description=predictor_instance._description,
+ )
+
+ def update_from_response_json(self, json_dict):
+ self._predictor.update_from_response_json(json_dict)
+ self.__init__(
+ predictor=self._predictor,
+ name=self._predictor._name,
+ description=self._predictor._description,
+ )
+ return self
+
+ def json(self):
+ return self._predictor.json()
+
+ def to_dict(self):
+ return self._predictor.to_dict()
+
+ # Deployment
+
+ @property
+ def id(self):
+ """Id of the deployment."""
+ return self._predictor.id
+
+ @property
+ def name(self):
+ """Name of the deployment."""
+ return self._predictor.name
+
+ @name.setter
+ def name(self, name: str):
+ self._predictor.name = name
+
+ @property
+ def description(self):
+ """Description of the deployment."""
+ return self._description
+
+ @description.setter
+ def description(self, description: str):
+ self._description = description
+
+ @property
+ def predictor(self):
+ """Predictor used in the deployment."""
+ return self._predictor
+
+ @predictor.setter
+ def predictor(self, predictor):
+ self._predictor = predictor
+
+ @property
+ def requested_instances(self):
+ """Total number of requested instances in the deployment."""
+ return self._predictor.requested_instances
+
+ # Single predictor
+
+ @property
+ def model_name(self):
+ """Name of the model deployed by the predictor"""
+ return self._predictor.model_name
+
+ @model_name.setter
+ def model_name(self, model_name: str):
+ self._predictor.model_name = model_name
+
+ @property
+ def model_path(self):
+ """Model path deployed by the predictor."""
+ return self._predictor.model_path
+
+ @model_path.setter
+ def model_path(self, model_path: str):
+ self._predictor.model_path = model_path
+
+ @property
+ def model_version(self):
+ """Model version deployed by the predictor."""
+ return self._predictor.model_version
+
+ @model_version.setter
+ def model_version(self, model_version: int):
+ self._predictor.model_version = model_version
+
+ @property
+ def artifact_version(self):
+ """Artifact version deployed by the predictor."""
+ return self._predictor.artifact_version
+
+ @artifact_version.setter
+ def artifact_version(self, artifact_version: Union[int, str]):
+ self._predictor.artifact_version = artifact_version
+
+ @property
+ def artifact_path(self):
+ """Path of the model artifact deployed by the predictor."""
+ return self._predictor.artifact_path
+
+ @property
+ def model_server(self):
+ """Model server ran by the predictor."""
+ return self._predictor.model_server
+
+ @model_server.setter
+ def model_server(self, model_server: str):
+ self._predictor.model_server = model_server
+
+ @property
+ def serving_tool(self):
+ """Serving tool used to run the model server."""
+ return self._predictor.serving_tool
+
+ @serving_tool.setter
+ def serving_tool(self, serving_tool: str):
+ self._predictor.serving_tool = serving_tool
+
+ @property
+ def script_file(self):
+ """Script file used by the predictor."""
+ return self._predictor.script_file
+
+ @script_file.setter
+ def script_file(self, script_file: str):
+ self._predictor.script_file = script_file
+
+ @property
+ def resources(self):
+ """Resource configuration for the predictor."""
+ return self._predictor.resources
+
+ @resources.setter
+ def resources(self, resources: Resources):
+ self._predictor.resources = resources
+
+ @property
+ def inference_logger(self):
+ """Configuration of the inference logger attached to this predictor."""
+ return self._predictor.inference_logger
+
+ @inference_logger.setter
+ def inference_logger(self, inference_logger: InferenceLogger):
+ self._predictor.inference_logger = inference_logger
+
+ @property
+ def inference_batcher(self):
+ """Configuration of the inference batcher attached to this predictor."""
+ return self._predictor.inference_batcher
+
+ @inference_batcher.setter
+ def inference_batcher(self, inference_batcher: InferenceBatcher):
+ self._predictor.inference_batcher = inference_batcher
+
+ @property
+ def transformer(self):
+ """Transformer configured in the predictor."""
+ return self._predictor.transformer
+
+ @transformer.setter
+ def transformer(self, transformer: Transformer):
+ self._predictor.transformer = transformer
+
+ @property
+ def model_registry_id(self):
+ """Model Registry Id of the deployment."""
+ return self._model_registry_id
+
+ @model_registry_id.setter
+ def model_registry_id(self, model_registry_id: int):
+ self._model_registry_id = model_registry_id
+
+ @property
+ def created_at(self):
+ """Created at date of the predictor."""
+ return self._predictor.created_at
+
+ @property
+ def creator(self):
+ """Creator of the predictor."""
+ return self._predictor.creator
+
+ @property
+ def api_protocol(self):
+ """API protocol enabled in the deployment (e.g., HTTP or GRPC)."""
+ return self._predictor.api_protocol
+
+ @api_protocol.setter
+ def api_protocol(self, api_protocol: str):
+ self._predictor.api_protocol = api_protocol
+
+ def __repr__(self):
+ desc = (
+ f", description: {self._description!r}"
+ if self._description is not None
+ else ""
+ )
+ return f"Deployment(name: {self._predictor._name!r}" + desc + ")"
diff --git a/hsml/python/hsml/engine/__init__.py b/hsml/python/hsml/engine/__init__.py
new file mode 100644
index 000000000..ff0a6f046
--- /dev/null
+++ b/hsml/python/hsml/engine/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/hsml/python/hsml/engine/hopsworks_engine.py b/hsml/python/hsml/engine/hopsworks_engine.py
new file mode 100644
index 000000000..79537fa48
--- /dev/null
+++ b/hsml/python/hsml/engine/hopsworks_engine.py
@@ -0,0 +1,65 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+
+from hsml import client
+from hsml.core import model_api, native_hdfs_api
+
+
+class HopsworksEngine:
+ def __init__(self):
+ self._native_hdfs_api = native_hdfs_api.NativeHdfsApi()
+ self._model_api = model_api.ModelApi()
+
+ def mkdir(self, remote_path: str):
+ remote_path = self._prepend_project_path(remote_path)
+ self._native_hdfs_api.mkdir(remote_path)
+ self._native_hdfs_api.chmod(remote_path, "ug+rwx")
+
+ def delete(self, model_instance):
+ self._model_api.delete(model_instance)
+
+ def upload(self, local_path: str, remote_path: str, upload_configuration=None):
+ local_path = self._get_abs_path(local_path)
+ remote_path = self._prepend_project_path(remote_path)
+ self._native_hdfs_api.upload(local_path, remote_path)
+ self._native_hdfs_api.chmod(remote_path, "ug+rwx")
+
+ def download(self, remote_path: str, local_path: str):
+ local_path = self._get_abs_path(local_path)
+ remote_path = self._prepend_project_path(remote_path)
+ self._native_hdfs_api.download(remote_path, local_path)
+
+ def copy(self, source_path: str, destination_path: str):
+ # both paths are hdfs paths
+ source_path = self._prepend_project_path(source_path)
+ destination_path = self._prepend_project_path(destination_path)
+ self._native_hdfs_api.copy(source_path, destination_path)
+
+ def move(self, source_path: str, destination_path: str):
+ source_path = self._prepend_project_path(source_path)
+ destination_path = self._prepend_project_path(destination_path)
+ self._native_hdfs_api.move(source_path, destination_path)
+
+ def _get_abs_path(self, local_path: str):
+ return local_path if os.path.isabs(local_path) else os.path.abspath(local_path)
+
+ def _prepend_project_path(self, remote_path: str):
+ if not remote_path.startswith("/Projects/"):
+ _client = client.get_instance()
+ remote_path = "/Projects/{}/{}".format(_client._project_name, remote_path)
+ return remote_path
diff --git a/hsml/python/hsml/engine/local_engine.py b/hsml/python/hsml/engine/local_engine.py
new file mode 100644
index 000000000..7b669a249
--- /dev/null
+++ b/hsml/python/hsml/engine/local_engine.py
@@ -0,0 +1,79 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+
+from hsml import client
+from hsml.core import dataset_api, model_api
+
+
+class LocalEngine:
+ def __init__(self):
+ self._dataset_api = dataset_api.DatasetApi()
+ self._model_api = model_api.ModelApi()
+
+ def mkdir(self, remote_path: str):
+ remote_path = self._prepend_project_path(remote_path)
+ self._dataset_api.mkdir(remote_path)
+
+ def delete(self, model_instance):
+ self._model_api.delete(model_instance)
+
+ def upload(self, local_path: str, remote_path: str, upload_configuration=None):
+ local_path = self._get_abs_path(local_path)
+ remote_path = self._prepend_project_path(remote_path)
+
+ # Initialize the upload configuration to an empty dictionary if it is None
+ upload_configuration = upload_configuration if upload_configuration else {}
+ self._dataset_api.upload(
+ local_path,
+ remote_path,
+ chunk_size=upload_configuration.get(
+ "chunk_size", self._dataset_api.DEFAULT_UPLOAD_FLOW_CHUNK_SIZE
+ ),
+ simultaneous_uploads=upload_configuration.get(
+ "simultaneous_uploads",
+ self._dataset_api.DEFAULT_UPLOAD_SIMULTANEOUS_UPLOADS,
+ ),
+ max_chunk_retries=upload_configuration.get(
+ "max_chunk_retries",
+ self._dataset_api.DEFAULT_UPLOAD_MAX_CHUNK_RETRIES,
+ ),
+ )
+
+ def download(self, remote_path: str, local_path: str):
+ local_path = self._get_abs_path(local_path)
+ remote_path = self._prepend_project_path(remote_path)
+ self._dataset_api.download(remote_path, local_path)
+
+ def copy(self, source_path, destination_path):
+ source_path = self._prepend_project_path(source_path)
+ destination_path = self._prepend_project_path(destination_path)
+ self._dataset_api.copy(source_path, destination_path)
+
+ def move(self, source_path, destination_path):
+ source_path = self._prepend_project_path(source_path)
+ destination_path = self._prepend_project_path(destination_path)
+ self._dataset_api.move(source_path, destination_path)
+
+ def _get_abs_path(self, local_path: str):
+ return local_path if os.path.isabs(local_path) else os.path.abspath(local_path)
+
+ def _prepend_project_path(self, remote_path: str):
+ if not remote_path.startswith("/Projects/"):
+ _client = client.get_instance()
+ remote_path = "/Projects/{}/{}".format(_client._project_name, remote_path)
+ return remote_path
diff --git a/hsml/python/hsml/engine/model_engine.py b/hsml/python/hsml/engine/model_engine.py
new file mode 100644
index 000000000..29acd269f
--- /dev/null
+++ b/hsml/python/hsml/engine/model_engine.py
@@ -0,0 +1,554 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import importlib
+import json
+import os
+import tempfile
+import time
+import uuid
+
+from hsml import client, constants, util
+from hsml.client.exceptions import ModelRegistryException, RestAPIError
+from hsml.core import dataset_api, model_api
+from hsml.engine import hopsworks_engine, local_engine
+from tqdm.auto import tqdm
+
+
+class ModelEngine:
+ def __init__(self):
+ self._model_api = model_api.ModelApi()
+ self._dataset_api = dataset_api.DatasetApi()
+
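+ # If pydoop is available, we are most likely running inside a Hopsworks cluster
+ # with native HDFS access; otherwise, fall back to the REST-based local engine.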
+ pydoop_spec = importlib.util.find_spec("pydoop")
+ if pydoop_spec is None:
+ self._engine = local_engine.LocalEngine()
+ else:
+ self._engine = hopsworks_engine.HopsworksEngine()
+
+ def _poll_model_available(self, model_instance, await_registration):
+ if await_registration > 0:
+ model_registry_id = model_instance.model_registry_id
+ sleep_seconds = 5
+ for _ in range(int(await_registration / sleep_seconds)):
+ try:
+ time.sleep(sleep_seconds)
+ model_meta = self._model_api.get(
+ model_instance.name,
+ model_instance.version,
+ model_registry_id,
+ model_instance.shared_registry_project_name,
+ )
+ if model_meta is not None:
+ return model_meta
+ except RestAPIError as e:
+ if e.response.status_code != 404:
+ raise e
+ print(
+ "Model not available during polling, set a higher value for await_registration to wait longer."
+ )
+
+ def _upload_additional_resources(self, model_instance):
+ if model_instance._input_example is not None:
+ input_example_path = os.path.join(os.getcwd(), "input_example.json")
+ input_example = util.input_example_to_json(model_instance._input_example)
+
+ with open(input_example_path, "w+") as out:
+ json.dump(input_example, out, cls=util.NumpyEncoder)
+
+ self._engine.upload(input_example_path, model_instance.version_path)
+ os.remove(input_example_path)
+ model_instance.input_example = None
+ if model_instance._model_schema is not None:
+ model_schema_path = os.path.join(os.getcwd(), "model_schema.json")
+ model_schema = model_instance._model_schema
+
+ with open(model_schema_path, "w+") as out:
+ out.write(model_schema.json())
+
+ self._engine.upload(model_schema_path, model_instance.version_path)
+ os.remove(model_schema_path)
+ model_instance.model_schema = None
+ return model_instance
+
+ def _copy_or_move_hopsfs_model_item(
+ self, item_attr, to_model_version_path, keep_original_files
+ ):
+ """Copy or move model item from a hdfs path to the model version folder in the Models dataset. It works with files and folders."""
+ path = item_attr["path"]
+ to_hdfs_path = os.path.join(to_model_version_path, os.path.basename(path))
+ if keep_original_files:
+ self._engine.copy(path, to_hdfs_path)
+ else:
+ self._engine.move(path, to_hdfs_path)
+
+ def _copy_or_move_hopsfs_model(
+ self,
+ from_hdfs_model_path,
+ to_model_version_path,
+ keep_original_files,
+ update_upload_progress,
+ ):
+ """Copy or move model files from a hdfs path to the model version folder in the Models dataset."""
+ # Strip hdfs prefix
+ if from_hdfs_model_path.startswith("hdfs:/"):
+ projects_index = from_hdfs_model_path.find("/Projects", 0)
+ from_hdfs_model_path = from_hdfs_model_path[projects_index:]
+
+ n_dirs, n_files = 0, 0
+
+ model_path_resp = self._dataset_api.get(from_hdfs_model_path)
+ model_path_attr = model_path_resp["attributes"]
+ if (
+ "datasetType" in model_path_resp
+ and model_path_resp["datasetType"] == "DATASET"
+ ): # prevent exporting a root dataset (e.g., "Resources"), which would wipe the whole dataset
+ raise AssertionError(
+ "It is disallowed to export a root dataset path."
+ " Move the model to a sub-folder and try again."
+ )
+ elif model_path_attr.get("dir", False):
+ # if path is a directory, iterate over the directory content
+ for entry in self._dataset_api.list(
+ from_hdfs_model_path, sort_by="NAME:desc"
+ )["items"]:
+ path_attr = entry["attributes"]
+ self._copy_or_move_hopsfs_model_item(
+ path_attr, to_model_version_path, keep_original_files
+ )
+ if path_attr.get("dir", False):
+ n_dirs += 1
+ else:
+ n_files += 1
+ update_upload_progress(n_dirs=n_dirs, n_files=n_files)
+ else:
+ # if path is a file, copy/move it
+ self._copy_or_move_hopsfs_model_item(
+ model_path_attr, to_model_version_path, keep_original_files
+ )
+ n_files += 1
+ update_upload_progress(n_dirs=n_dirs, n_files=n_files)
+
+ def _download_model_from_hopsfs_recursive(
+ self,
+ from_hdfs_model_path: str,
+ to_local_path: str,
+ update_download_progress,
+ n_dirs,
+ n_files,
+ ):
+ """Download model files from a model path in hdfs, recursively"""
+
+ for entry in self._dataset_api.list(from_hdfs_model_path, sort_by="NAME:desc")[
+ "items"
+ ]:
+ path_attr = entry["attributes"]
+ path = path_attr["path"]
+ basename = os.path.basename(path)
+
+ if path_attr.get("dir", False):
+ # if it's a folder, make a recursive call for its content
+ if basename == "Artifacts":
+ continue # skip Artifacts subfolder
+ local_folder_path = os.path.join(to_local_path, basename)
+ os.mkdir(local_folder_path)
+ n_dirs, n_files = self._download_model_from_hopsfs_recursive(
+ from_hdfs_model_path=path,
+ to_local_path=local_folder_path,
+ update_download_progress=update_download_progress,
+ n_dirs=n_dirs,
+ n_files=n_files,
+ )
+ n_dirs += 1
+ update_download_progress(n_dirs=n_dirs, n_files=n_files)
+ else:
+ # if it's a file, download it
+ local_file_path = os.path.join(to_local_path, basename)
+ self._engine.download(path, local_file_path)
+ n_files += 1
+ update_download_progress(n_dirs=n_dirs, n_files=n_files)
+
+ return n_dirs, n_files
+
+ def _download_model_from_hopsfs(
+ self, from_hdfs_model_path: str, to_local_path: str, update_download_progress
+ ):
+ """Download model files from a model path in hdfs."""
+
+ n_dirs, n_files = self._download_model_from_hopsfs_recursive(
+ from_hdfs_model_path=from_hdfs_model_path,
+ to_local_path=to_local_path,
+ update_download_progress=update_download_progress,
+ n_dirs=0,
+ n_files=0,
+ )
+ update_download_progress(n_dirs=n_dirs, n_files=n_files, done=True)
+
+ def _upload_local_model(
+ self,
+ from_local_model_path,
+ to_model_version_path,
+ update_upload_progress,
+ upload_configuration=None,
+ ):
+ """Copy or upload model files from a local path to the model version folder in the Models dataset."""
+ n_dirs, n_files = 0, 0
+ if os.path.isdir(from_local_model_path):
+ # if path is a dir, upload files and folders iteratively
+ for root, dirs, files in os.walk(from_local_model_path):
+ # os.walk(local_model_path), where local_model_path is expected to be an absolute path
+ # - root is the absolute path of the directory being walked
+ # - dirs is the list of directory names present in the root dir
+ # - files is the list of file names present in the root dir
+ # we need to replace the local path prefix with the hdfs path prefix (i.e., /srv/hops/....../root with /Projects/.../)
+ remote_base_path = root.replace(
+ from_local_model_path, to_model_version_path
+ )
+ for d_name in dirs:
+ self._engine.mkdir(remote_base_path + "/" + d_name)
+ n_dirs += 1
+ update_upload_progress(n_dirs, n_files)
+ for f_name in files:
+ self._engine.upload(
+ root + "/" + f_name,
+ remote_base_path,
+ upload_configuration=upload_configuration,
+ )
+ n_files += 1
+ update_upload_progress(n_dirs, n_files)
+ else:
+ # if path is a file, upload file
+ self._engine.upload(
+ from_local_model_path,
+ to_model_version_path,
+ upload_configuration=upload_configuration,
+ )
+ n_files += 1
+ update_upload_progress(n_dirs, n_files)
+
+ def _save_model_from_local_or_hopsfs_mount(
+ self,
+ model_instance,
+ model_path,
+ keep_original_files,
+ update_upload_progress,
+ upload_configuration=None,
+ ):
+ """Save model files from a local path. The local path can be on hopsfs mount"""
+ # check hopsfs mount
+ if model_path.startswith(constants.MODEL_REGISTRY.HOPSFS_MOUNT_PREFIX):
+ self._copy_or_move_hopsfs_model(
+ from_hdfs_model_path=model_path.replace(
+ constants.MODEL_REGISTRY.HOPSFS_MOUNT_PREFIX, ""
+ ),
+ to_model_version_path=model_instance.version_path,
+ keep_original_files=keep_original_files,
+ update_upload_progress=update_upload_progress,
+ )
+ else:
+ self._upload_local_model(
+ from_local_model_path=model_path,
+ to_model_version_path=model_instance.version_path,
+ update_upload_progress=update_upload_progress,
+ upload_configuration=upload_configuration,
+ )
+
+ def _set_model_version(
+ self, model_instance, dataset_models_root_path, dataset_model_path
+ ):
+ # Set model version if not defined
+ if model_instance._version is None:
+ current_highest_version = 0
+ for item in self._dataset_api.list(dataset_model_path, sort_by="NAME:desc")[
+ "items"
+ ]:
+ _, file_name = os.path.split(item["attributes"]["path"])
+ try:
+ try:
+ current_version = int(file_name)
+ except ValueError:
+ continue
+ if current_version > current_highest_version:
+ current_highest_version = current_version
+ except RestAPIError:
+ pass
+ model_instance._version = current_highest_version + 1
+
+ elif self._dataset_api.path_exists(
+ dataset_models_root_path
+ + "/"
+ + model_instance._name
+ + "/"
+ + str(model_instance._version)
+ ):
+ raise ModelRegistryException(
+ "Model with name {} and version {} already exists".format(
+ model_instance._name, model_instance._version
+ )
+ )
+ return model_instance
+
+ def _build_resource_path(self, model_instance, artifact):
+ artifact_path = "{}/{}".format(model_instance.version_path, artifact)
+ return artifact_path
+
+ def save(
+ self,
+ model_instance,
+ model_path,
+ await_registration=480,
+ keep_original_files=False,
+ upload_configuration=None,
+ ):
+ _client = client.get_instance()
+
+ is_shared_registry = model_instance.shared_registry_project_name is not None
+
+ if is_shared_registry:
+ dataset_models_root_path = "{}::{}".format(
+ model_instance.shared_registry_project_name,
+ constants.MODEL_SERVING.MODELS_DATASET,
+ )
+ model_instance._project_name = model_instance.shared_registry_project_name
+ else:
+ dataset_models_root_path = constants.MODEL_SERVING.MODELS_DATASET
+ model_instance._project_name = _client._project_name
+
+ util.validate_metrics(model_instance.training_metrics)
+
+ if not self._dataset_api.path_exists(dataset_models_root_path):
+ raise AssertionError(
+ "{} dataset does not exist in this project. Please enable the Serving service or create it manually.".format(
+ dataset_models_root_path
+ )
+ )
+
+ # Create /Models/{model_instance._name} folder
+ dataset_model_name_path = dataset_models_root_path + "/" + model_instance._name
+ if not self._dataset_api.path_exists(dataset_model_name_path):
+ self._engine.mkdir(dataset_model_name_path)
+
+ model_instance = self._set_model_version(
+ model_instance, dataset_models_root_path, dataset_model_name_path
+ )
+
+ # Attach model summary xattr to /Models/{model_instance._name}/{model_instance._version}
+ model_query_params = {}
+
+ if "ML_ID" in os.environ:
+ model_instance._experiment_id = os.environ["ML_ID"]
+
+ model_instance._experiment_project_name = _client._project_name
+
+ if "HOPSWORKS_JOB_NAME" in os.environ:
+ model_query_params["jobName"] = os.environ["HOPSWORKS_JOB_NAME"]
+ elif "HOPSWORKS_KERNEL_ID" in os.environ:
+ model_query_params["kernelId"] = os.environ["HOPSWORKS_KERNEL_ID"]
+
+ pbar = tqdm(
+ [
+ {"id": 0, "desc": "Creating model folder"},
+ {"id": 1, "desc": "Uploading model files"},
+ {"id": 2, "desc": "Uploading input_example and model_schema"},
+ {"id": 3, "desc": "Registering model"},
+ {"id": 4, "desc": "Waiting for model registration"},
+ {"id": 5, "desc": "Model export complete"},
+ ]
+ )
+
+ for step in pbar:
+ try:
+ pbar.set_description("%s" % step["desc"])
+ if step["id"] == 0:
+ # Create folders
+ self._engine.mkdir(model_instance.version_path)
+ if step["id"] == 1:
+
+ def update_upload_progress(n_dirs=0, n_files=0, step=step):
+ pbar.set_description(
+ "%s (%s dirs, %s files)" % (step["desc"], n_dirs, n_files)
+ )
+
+ update_upload_progress(n_dirs=0, n_files=0)
+
+ # Upload Model files from local path to /Models/{model_instance._name}/{model_instance._version}
+ # check local absolute
+ if os.path.isabs(model_path) and os.path.exists(model_path):
+ self._save_model_from_local_or_hopsfs_mount(
+ model_instance=model_instance,
+ model_path=model_path,
+ keep_original_files=keep_original_files,
+ update_upload_progress=update_upload_progress,
+ upload_configuration=upload_configuration,
+ )
+ # check local relative
+ elif os.path.exists(
+ os.path.join(os.getcwd(), model_path)
+ ): # check local relative
+ self._save_model_from_local_or_hopsfs_mount(
+ model_instance=model_instance,
+ model_path=os.path.join(os.getcwd(), model_path),
+ keep_original_files=keep_original_files,
+ update_upload_progress=update_upload_progress,
+ upload_configuration=upload_configuration,
+ )
+ # check project relative
+ elif self._dataset_api.path_exists(
+ model_path
+ ): # check hdfs relative and absolute
+ self._copy_or_move_hopsfs_model(
+ from_hdfs_model_path=model_path,
+ to_model_version_path=model_instance.version_path,
+ keep_original_files=keep_original_files,
+ update_upload_progress=update_upload_progress,
+ )
+ else:
+ raise IOError(
+ "Could not find path {} in the local filesystem or in Hopsworks File System".format(
+ model_path
+ )
+ )
+ if step["id"] == 2:
+ model_instance = self._upload_additional_resources(model_instance)
+ if step["id"] == 3:
+ model_instance = self._model_api.put(
+ model_instance, model_query_params
+ )
+ if step["id"] == 4:
+ model_instance = self._poll_model_available(
+ model_instance, await_registration
+ )
+ if step["id"] == 5:
+ pass
+ except BaseException as be:
+ self._dataset_api.rm(model_instance.version_path)
+ raise be
+
+ print("Model created, explore it at " + model_instance.get_url())
+
+ return model_instance
+
+ def download(self, model_instance):
+ model_name_path = os.path.join(
+ tempfile.gettempdir(), str(uuid.uuid4()), model_instance._name
+ )
+ model_version_path = model_name_path + "/" + str(model_instance._version)
+ os.makedirs(model_version_path)
+
+ def update_download_progress(n_dirs, n_files, done=False):
+ print(
+ "Downloading model artifact (%s dirs, %s files)... %s"
+ % (n_dirs, n_files, "DONE" if done else ""),
+ end="\r",
+ )
+
+ try:
+ from_hdfs_model_path = model_instance.version_path
+ if from_hdfs_model_path.startswith("hdfs:/"):
+ projects_index = from_hdfs_model_path.find("/Projects", 0)
+ from_hdfs_model_path = from_hdfs_model_path[projects_index:]
+
+ self._download_model_from_hopsfs(
+ from_hdfs_model_path=from_hdfs_model_path,
+ to_local_path=model_version_path,
+ update_download_progress=update_download_progress,
+ )
+ except BaseException as be:
+ raise be
+
+ return model_version_path
+
+ def read_file(self, model_instance, resource):
+ hdfs_resource_path = self._build_resource_path(
+ model_instance, os.path.basename(resource)
+ )
+ if self._dataset_api.path_exists(hdfs_resource_path):
+ tmp_dir = None
+ try:
+ resource = os.path.basename(resource)
+ tmp_dir = tempfile.TemporaryDirectory(dir=os.getcwd())
+ local_resource_path = os.path.join(tmp_dir.name, resource)
+ self._engine.download(
+ hdfs_resource_path,
+ local_resource_path,
+ )
+ with open(local_resource_path, "r") as f:
+ return f.read()
+ finally:
+ if tmp_dir is not None and os.path.exists(tmp_dir.name):
+ tmp_dir.cleanup()
+
+ def read_json(self, model_instance, resource):
+ hdfs_resource_path = self._build_resource_path(model_instance, resource)
+ if self._dataset_api.path_exists(hdfs_resource_path):
+ tmp_dir = None
+ try:
+ tmp_dir = tempfile.TemporaryDirectory(dir=os.getcwd())
+ local_resource_path = os.path.join(tmp_dir.name, resource)
+ self._engine.download(
+ hdfs_resource_path,
+ local_resource_path,
+ )
+ with open(local_resource_path, "rb") as f:
+ return json.loads(f.read())
+ finally:
+ if tmp_dir is not None and os.path.exists(tmp_dir.name):
+ tmp_dir.cleanup()
+
+ def delete(self, model_instance):
+ self._engine.delete(model_instance)
+
+ def set_tag(self, model_instance, name, value):
+ """Attach a name/value tag to a model."""
+ self._model_api.set_tag(model_instance, name, value)
+
+ def delete_tag(self, model_instance, name):
+ """Remove a tag from a model."""
+ self._model_api.delete_tag(model_instance, name)
+
+ def get_tag(self, model_instance, name):
+ """Get tag with a certain name."""
+ return self._model_api.get_tags(model_instance, name)[name]
+
+ def get_tags(self, model_instance):
+ """Get all tags for a model."""
+ return self._model_api.get_tags(model_instance)
+
+ def get_feature_view_provenance(self, model_instance):
+ """Get the parent feature view of this model, based on explicit provenance.
+ These feature views can be accessible, deleted or inaccessible.
+ For deleted and inaccessible feature views, only minimal information is
+ returned.
+
+ # Arguments
+ model_instance: Metadata object of model.
+
+ # Returns
+ `ProvenanceLinks`: the feature view used to generate this model
+ """
+ return self._model_api.get_feature_view_provenance(model_instance)
+
+ def get_training_dataset_provenance(self, model_instance):
+ """Get the parent training dataset of this model, based on explicit provenance.
+ These training datasets can be accessible, deleted or inaccessible.
+ For deleted and inaccessible training datasets, only minimal information is
+ returned.
+
+ # Arguments
+ model_instance: Metadata object of model.
+
+ # Returns
+ `ProvenanceLinks`: the training dataset used to generate this model
+ """
+ return self._model_api.get_training_dataset_provenance(model_instance)
diff --git a/hsml/python/hsml/engine/serving_engine.py b/hsml/python/hsml/engine/serving_engine.py
new file mode 100644
index 000000000..15e2b3fa6
--- /dev/null
+++ b/hsml/python/hsml/engine/serving_engine.py
@@ -0,0 +1,690 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+import time
+import uuid
+from typing import Dict, List, Union
+
+from hsml import util
+from hsml.client.exceptions import ModelServingException, RestAPIError
+from hsml.client.istio.utils.infer_type import InferInput
+from hsml.constants import (
+ DEPLOYMENT,
+ PREDICTOR,
+ PREDICTOR_STATE,
+)
+from hsml.constants import (
+ INFERENCE_ENDPOINTS as IE,
+)
+from hsml.core import dataset_api, serving_api
+from tqdm.auto import tqdm
+
+
+class ServingEngine:
+ START_STEPS = [
+ PREDICTOR_STATE.CONDITION_TYPE_STOPPED,
+ PREDICTOR_STATE.CONDITION_TYPE_SCHEDULED,
+ PREDICTOR_STATE.CONDITION_TYPE_INITIALIZED,
+ PREDICTOR_STATE.CONDITION_TYPE_STARTED,
+ PREDICTOR_STATE.CONDITION_TYPE_READY,
+ ]
+ STOP_STEPS = [
+ PREDICTOR_STATE.CONDITION_TYPE_SCHEDULED,
+ PREDICTOR_STATE.CONDITION_TYPE_STOPPED,
+ ]
+
+ def __init__(self):
+ self._serving_api = serving_api.ServingApi()
+ self._dataset_api = dataset_api.DatasetApi()
+
+ def _poll_deployment_status(
+ self, deployment_instance, status: str, await_status: int, update_progress=None
+ ):
+ if await_status > 0:
+ sleep_seconds = 5
+ for _ in range(int(await_status / sleep_seconds)):
+ time.sleep(sleep_seconds)
+ state = deployment_instance.get_state()
+ num_instances = self._get_available_instances(state)
+ if update_progress is not None:
+ update_progress(state, num_instances)
+ if state.status == status:
+ return state # deployment reached desired status
+ elif (
+ status == PREDICTOR_STATE.STATUS_RUNNING
+ and state.status == PREDICTOR_STATE.STATUS_FAILED
+ ):
+ error_msg = state.condition.reason
+ if (
+ state.condition.type
+ == PREDICTOR_STATE.CONDITION_TYPE_INITIALIZED
+ or state.condition.type
+ == PREDICTOR_STATE.CONDITION_TYPE_STARTED
+ ):
+ component = (
+ "transformer"
+ if "transformer" in state.condition.reason
+ else "predictor"
+ )
+ error_msg += (
+ ". Please, check the server logs using `.get_logs(component='"
+ + component
+ + "')`"
+ )
+ raise ModelServingException(error_msg)
+ raise ModelServingException(
+ "Deployment has not reached the desired status within the expected awaiting time. Check the current status by using `.get_state()`, "
+ + "explore the server logs using `.get_logs()` or set a higher value for await_"
+ + status.lower()
+ )
+
+ def start(self, deployment_instance, await_status: int) -> bool:
+ (done, state) = self._check_status(
+ deployment_instance, PREDICTOR_STATE.STATUS_RUNNING
+ )
+
+ if not done:
+ min_instances = self._get_min_starting_instances(deployment_instance)
+ num_steps = (len(self.START_STEPS) - 1) + min_instances
+ if deployment_instance._predictor._state.condition is None:
+ num_steps = min_instances # backward compatibility
+ pbar = tqdm(total=num_steps)
+ pbar.set_description("Creating deployment")
+
+ # set progress function
+ def update_progress(state, num_instances):
+ (progress, desc) = self._get_starting_progress(
+ pbar.n, state, num_instances
+ )
+ pbar.update(progress)
+ if desc is not None:
+ pbar.set_description(desc)
+
+ try:
+ update_progress(state, num_instances=0)
+
+ if state.status == PREDICTOR_STATE.STATUS_CREATING:
+ state = self._poll_deployment_status( # wait for preparation
+ deployment_instance,
+ PREDICTOR_STATE.STATUS_CREATED,
+ await_status,
+ update_progress,
+ )
+
+ self._serving_api.post(
+ deployment_instance, DEPLOYMENT.ACTION_START
+ ) # start deployment
+
+ state = self._poll_deployment_status( # wait for status
+ deployment_instance,
+ PREDICTOR_STATE.STATUS_RUNNING,
+ await_status,
+ update_progress,
+ )
+ except RestAPIError as re:
+ self.stop(deployment_instance, await_status=0)
+ raise re
+
+ if state.status == PREDICTOR_STATE.STATUS_RUNNING:
+ print("Start making predictions by using `.predict()`")
+
+ def stop(self, deployment_instance, await_status: int) -> bool:
+ (done, state) = self._check_status(
+ deployment_instance, PREDICTOR_STATE.STATUS_STOPPED
+ )
+ if not done:
+ num_instances = self._get_available_instances(state)
+ num_steps = len(self.STOP_STEPS) + (
+ deployment_instance.requested_instances
+ if deployment_instance.requested_instances >= num_instances
+ else num_instances
+ )
+ if deployment_instance._predictor._state.condition is None:
+ # backward compatibility
+ num_steps = self._get_min_starting_instances(deployment_instance)
+ pbar = tqdm(total=num_steps)
+ pbar.set_description("Preparing to stop deployment")
+
+ # set progress function
+ def update_progress(state, num_instances):
+ (progress, desc) = self._get_stopping_progress(
+ pbar.total, pbar.n, state, num_instances
+ )
+ pbar.update(progress)
+ if desc is not None:
+ pbar.set_description(desc)
+
+ update_progress(state, num_instances)
+ self._serving_api.post(
+ deployment_instance, DEPLOYMENT.ACTION_STOP
+ ) # stop deployment
+
+ _ = self._poll_deployment_status( # wait for status
+ deployment_instance,
+ PREDICTOR_STATE.STATUS_STOPPED,
+ await_status,
+ update_progress,
+ )
+
+ # free grpc channel
+ deployment_instance._grpc_channel = None
+
+ def _check_status(self, deployment_instance, desired_status):
+ state = deployment_instance.get_state()
+ if state is None:
+ return (True, None)
+
+ # desired status: running
+ if desired_status == PREDICTOR_STATE.STATUS_RUNNING:
+ if (
+ state.status == PREDICTOR_STATE.STATUS_RUNNING
+ or state.status == PREDICTOR_STATE.STATUS_IDLE
+ ):
+ print("Deployment is already running")
+ return (True, state)
+ if state.status == PREDICTOR_STATE.STATUS_STARTING:
+ print("Deployment is already starting")
+ return (True, state)
+ if state.status == PREDICTOR_STATE.STATUS_UPDATING:
+ print("Deployments is already running and updating")
+ return (True, state)
+ if state.status == PREDICTOR_STATE.STATUS_FAILED:
+ print("Deployment is in failed state. " + state.condition.reason)
+ return (True, state)
+ if state.status == PREDICTOR_STATE.STATUS_STOPPING:
+ raise ModelServingException(
+ "Deployment is stopping, please wait until it completely stops"
+ )
+
+ # desired status: stopped
+ if desired_status == PREDICTOR_STATE.STATUS_STOPPED:
+ if (
+ state.status == PREDICTOR_STATE.STATUS_CREATING
+ or state.status == PREDICTOR_STATE.STATUS_CREATED
+ or state.status == PREDICTOR_STATE.STATUS_STOPPED
+ ):
+ print("Deployment is already stopped")
+ return (True, state)
+ if state.status == PREDICTOR_STATE.STATUS_STOPPING:
+ print("Deployment is already stopping")
+ return (True, state)
+
+ return (False, state)
+
+ def _get_starting_progress(self, current_step, state, num_instances):
+ if state.condition is None: # backward compatibility
+ progress = num_instances - current_step
+ if state.status == PREDICTOR_STATE.STATUS_RUNNING:
+ return (progress, "Deployment is ready")
+ return (progress, None if current_step == 0 else "Deployment is starting")
+
+ step = self.START_STEPS.index(state.condition.type)
+ if (
+ state.condition.type == PREDICTOR_STATE.CONDITION_TYPE_STARTED
+ or state.condition.type == PREDICTOR_STATE.CONDITION_TYPE_READY
+ ):
+ step += num_instances
+ progress = step - current_step
+ desc = None
+ if state.condition.type != PREDICTOR_STATE.CONDITION_TYPE_STOPPED:
+ desc = (
+ state.condition.reason
+ if state.status != PREDICTOR_STATE.STATUS_FAILED
+ else "Deployment failed to start"
+ )
+ return (progress, desc)
+
+ def _get_stopping_progress(self, total_steps, current_step, state, num_instances):
+ if state.condition is None: # backward compatibility
+ progress = (total_steps - num_instances) - current_step
+ if state.status == PREDICTOR_STATE.STATUS_STOPPED:
+ return (progress, "Deployment is stopped")
+ return (
+ progress,
+ None if total_steps == current_step else "Deployment is stopping",
+ )
+
+ step = 0
+ if state.condition.type == PREDICTOR_STATE.CONDITION_TYPE_SCHEDULED:
+ step = 1 if state.condition.status is None else 0
+ elif state.condition.type == PREDICTOR_STATE.CONDITION_TYPE_STOPPED:
+ num_instances = (total_steps - 2) - num_instances # num stopped instances
+ step = (
+ (2 + num_instances)
+ if (state.condition.status is None or state.condition.status)
+ else 0
+ )
+ progress = step - current_step
+ desc = None
+ if (
+ state.condition.type != PREDICTOR_STATE.CONDITION_TYPE_READY
+ and state.status != PREDICTOR_STATE.STATUS_FAILED
+ ):
+ desc = (
+ "Deployment is stopped"
+ if state.status == PREDICTOR_STATE.STATUS_STOPPED
+ else state.condition.reason
+ )
+
+ return (progress, desc)
+
+ def _get_min_starting_instances(self, deployment_instance):
+ min_start_instances = 1 # predictor
+ if deployment_instance.transformer is not None:
+ min_start_instances += 1 # transformer
+ return (
+ deployment_instance.requested_instances
+ if deployment_instance.requested_instances >= min_start_instances
+ else min_start_instances
+ )
+
+ def _get_available_instances(self, state):
+ if state.status == PREDICTOR_STATE.STATUS_CREATING:
+ return 0
+ num_instances = state.available_predictor_instances
+ if state.available_transformer_instances is not None:
+ num_instances += state.available_transformer_instances
+ return num_instances
+
+ def _get_stopped_instances(self, available_instances, requested_instances):
+ num_instances = requested_instances - available_instances
+ return num_instances if num_instances >= 0 else 0
+
+ def download_artifact(self, deployment_instance):
+ if deployment_instance.id is None:
+ raise ModelServingException(
+ "Deployment is not created yet. To create the deployment use `.save()`"
+ )
+ if deployment_instance.artifact_version is None:
+ # model artifacts are not created in non-k8s installations
+ raise ModelServingException(
+ "Model artifacts not supported in non-k8s installations. \
+ Download the model files by using `model.download()`"
+ )
+
+ from_artifact_zip_path = deployment_instance.artifact_path
+ to_artifacts_path = os.path.join(
+ os.getcwd(),
+ str(uuid.uuid4()),
+ deployment_instance.model_name,
+ str(deployment_instance.model_version),
+ "Artifacts",
+ )
+ to_artifact_version_path = (
+ to_artifacts_path + "/" + str(deployment_instance.artifact_version)
+ )
+ to_artifact_zip_path = to_artifact_version_path + ".zip"
+
+ os.makedirs(to_artifacts_path)
+
+ try:
+ self._dataset_api.download(from_artifact_zip_path, to_artifact_zip_path)
+ util.decompress(to_artifact_zip_path, extract_dir=to_artifacts_path)
+ os.remove(to_artifact_zip_path)
+ finally:
+ if os.path.exists(to_artifact_zip_path):
+ os.remove(to_artifact_zip_path)
+
+ return to_artifact_version_path
+
+ def create(self, deployment_instance):
+ try:
+ self._serving_api.put(deployment_instance)
+ print("Deployment created, explore it at " + deployment_instance.get_url())
+ except RestAPIError as re:
+ raise_err = True
+ if re.error_code == ModelServingException.ERROR_CODE_DUPLICATED_ENTRY:
+ msg = "Deployment with the same name already exists"
+ existing_deployment = self._serving_api.get(deployment_instance.name)
+ if (
+ existing_deployment.model_name == deployment_instance.model_name
+ and existing_deployment.model_version
+ == deployment_instance.model_version
+ ): # if same name and model version, retrieve existing deployment
+ print(msg + ". Getting existing deployment...")
+ print("To create a new deployment choose a different name.")
+ deployment_instance.update_from_response_json(
+ existing_deployment.to_dict()
+ )
+ raise_err = False
+ else: # otherwise, raise an exception
+ print(", but it is serving a different model version.")
+ print("Please, choose a different name.")
+
+ if raise_err:
+ raise re
+
+ if deployment_instance.is_stopped():
+ print("Before making predictions, start the deployment by using `.start()`")
+
+ def update(self, deployment_instance, await_update):
+ state = deployment_instance.get_state()
+ if state is None:
+ return
+
+ if state.status == PREDICTOR_STATE.STATUS_STARTING:
+ # if starting, it cannot be updated yet
+ raise ModelServingException(
+ "Deployment is starting, please wait until it is running before applying changes. \n"
+ + "Check the current status by using `.get_state()` or explore the server logs using `.get_logs()`"
+ )
+ if (
+ state.status == PREDICTOR_STATE.STATUS_RUNNING
+ or state.status == PREDICTOR_STATE.STATUS_IDLE
+ or state.status == PREDICTOR_STATE.STATUS_FAILED
+ ):
+ # if running, it's fine
+ self._serving_api.put(deployment_instance)
+ print("Deployment updated, applying changes to running instances...")
+ state = self._poll_deployment_status( # wait for status
+ deployment_instance, PREDICTOR_STATE.STATUS_RUNNING, await_update
+ )
+ if state is not None:
+ if state.status == PREDICTOR_STATE.STATUS_RUNNING:
+ print("Running instances updated successfully")
+ return
+ if state.status == PREDICTOR_STATE.STATUS_UPDATING:
+ # if updating, it cannot be updated yet
+ raise ModelServingException(
+ "Deployment is updating, please wait until it is running before applying changes. \n"
+ + "Check the current status by using `.get_state()` or explore the server logs using `.get_logs()`"
+ )
+ if state.status == PREDICTOR_STATE.STATUS_STOPPING:
+ # if stopping, it cannot be updated yet
+ raise ModelServingException(
+ "Deployment is stopping, please wait until it is stopped before applying changes"
+ )
+ if (
+ state.status == PREDICTOR_STATE.STATUS_CREATING
+ or state.status == PREDICTOR_STATE.STATUS_CREATED
+ or state.status == PREDICTOR_STATE.STATUS_STOPPED
+ ):
+ # if stopped, it's fine
+ self._serving_api.put(deployment_instance)
+ print("Deployment updated, explore it at " + deployment_instance.get_url())
+ return
+
+ raise ValueError("Unknown deployment status: " + state.status)
+
+ def save(self, deployment_instance, await_update: int):
+ if deployment_instance.id is None:
+ # if new deployment
+ self.create(deployment_instance)
+ return
+
+ # if existing deployment
+ self.update(deployment_instance, await_update)
+
+ def delete(self, deployment_instance, force=False):
+ state = deployment_instance.get_state()
+ if state is None:
+ return
+
+ if (
+ not force
+ and state.status != PREDICTOR_STATE.STATUS_STOPPED
+ and state.status != PREDICTOR_STATE.STATUS_CREATED
+ ):
+ raise ModelServingException(
+ "Deployment not stopped, please stop it first by using `.stop()` or check its status with .get_state()"
+ )
+
+ self._serving_api.delete(deployment_instance)
+ print("Deployment deleted successfully")
+
+ def get_state(self, deployment_instance):
+ try:
+ state = self._serving_api.get_state(deployment_instance)
+ except RestAPIError as re:
+ if re.error_code == ModelServingException.ERROR_CODE_SERVING_NOT_FOUND:
+ raise ModelServingException("Deployment not found") from re
+ raise re
+ deployment_instance._predictor._set_state(state)
+ return state
+
+ def get_logs(self, deployment_instance, component, tail):
+ state = self.get_state(deployment_instance)
+ if state is None:
+ return
+
+ if state.status == PREDICTOR_STATE.STATUS_STOPPING:
+ print(
+ "Deployment is stopping, explore historical logs at "
+ + deployment_instance.get_url()
+ )
+ return
+ if state.status == PREDICTOR_STATE.STATUS_STOPPED:
+ print(
+ "Deployment not running, explore historical logs at "
+ + deployment_instance.get_url()
+ )
+ return
+ if state.status == PREDICTOR_STATE.STATUS_STARTING:
+ print("Deployment is starting, server logs might not be ready yet")
+
+ print(
+ "Explore all the logs and filters in the Kibana logs at "
+ + deployment_instance.get_url(),
+ end="\n\n",
+ )
+
+ return self._serving_api.get_logs(deployment_instance, component, tail)
+
+ # Model inference
+
+ def predict(
+ self,
+ deployment_instance,
+ data: Union[Dict, List[InferInput]],
+ inputs: Union[Dict, List[Dict]],
+ ):
+ # validate user-provided payload
+ self._validate_inference_payload(deployment_instance.api_protocol, data, inputs)
+
+ # build inference payload based on API protocol
+ payload = self._build_inference_payload(
+ deployment_instance.api_protocol, data, inputs
+ )
+
+ # if not KServe, send request through Hopsworks
+ serving_tool = deployment_instance.predictor.serving_tool
+ through_hopsworks = serving_tool != PREDICTOR.SERVING_TOOL_KSERVE
+ try:
+ return self._serving_api.send_inference_request(
+ deployment_instance, payload, through_hopsworks
+ )
+ except RestAPIError as re:
+ if (
+ re.response.status_code == RestAPIError.STATUS_CODE_NOT_FOUND
+ or re.error_code
+ == ModelServingException.ERROR_CODE_DEPLOYMENT_NOT_RUNNING
+ ):
+ raise ModelServingException(
+ "Deployment not created or running. If it is already created, start it by using `.start()` or check its status with .get_state()"
+ ) from re
+
+ re.args = (
+ re.args[0] + "\n\n Check the model server logs by using `.get_logs()`",
+ )
+ raise re
+
+ def _validate_inference_payload(
+ self,
+ api_protocol,
+ data: Union[Dict, List[InferInput]],
+ inputs: Union[Dict, List[Dict]],
+ ):
+ """Validates the user-provided inference payload. Either data or inputs parameter is expected, but both cannot be provided together."""
+ # check null inputs
+ if data is not None and inputs is not None:
+ raise ModelServingException(
+ "Inference data and inputs parameters cannot be provided together."
+ )
+ # check data or inputs
+ if data is not None:
+ self._validate_inference_data(api_protocol, data)
+ else:
+ self._validate_inference_inputs(api_protocol, inputs)
+
+ def _validate_inference_data(
+ self, api_protocol, data: Union[Dict, List[InferInput]]
+ ):
+ """Validates the inference payload when provided through the `data` parameter. The data parameter contains the raw payload to be sent
+ in the inference request and should have the corresponding type and format depending on the API protocol.
+ For the REST protocol, data should be a dictionary. For the gRPC protocol, one or more `InferInput` objects are expected.
+ """
+ if api_protocol == IE.API_PROTOCOL_REST: # REST protocol
+ if isinstance(data, Dict):
+ if "instances" not in data and "inputs" not in data:
+ raise ModelServingException(
+ "Inference data is missing 'instances' key."
+ )
+
+ payload = data["instances"] if "instances" in data else data["inputs"]
+ if not isinstance(payload, List):
+ raise ModelServingException(
+ "Instances field should contain a 2-dim list."
+ )
+ elif len(payload) == 0:
+ raise ModelServingException(
+ "Inference data cannot contain an empty list."
+ )
+ elif not isinstance(payload[0], List):
+ raise ModelServingException(
+ "Instances field should contain a 2-dim list."
+ )
+ elif len(payload[0]) == 0:
+ raise ModelServingException(
+ "Inference data cannot contain an empty list."
+ )
+ else: # not Dict
+ if isinstance(data, InferInput) or (
+ isinstance(data, List) and isinstance(data[0], InferInput)
+ ):
+ raise ModelServingException(
+ "Inference data cannot contain `InferInput` for deployments with gRPC protocol disabled. Use a dictionary instead."
+ )
+ raise ModelServingException(
+ "Inference data must be a dictionary. Otherwise, use the `inputs` parameter."
+ )
+
+ else: # gRPC protocol
+ if isinstance(data, Dict):
+ raise ModelServingException(
+ "Inference data cannot be a dictionary for deployments with gRPC protocol enabled. "
+ "Create a `InferInput` object or use the `inputs` parameter instead."
+ )
+ elif isinstance(data, List):
+ if len(data) == 0:
+ raise ModelServingException(
+ "Inference data cannot contain an empty list."
+ )
+ if not isinstance(data[0], InferInput):
+ raise ModelServingException(
+ "Inference data must contain a list of `InferInput` objects. Otherwise, use the `inputs` parameter."
+ )
+ else:
+ raise ModelServingException(
+ "Inference data must contain a list of `InferInput` objects for deployments with gRPC protocol enabled."
+ )
+
+ def _validate_inference_inputs(
+ self, api_protocol, inputs: Union[Dict, List[Dict]], recursive_call=False
+ ):
+ """Validates the inference payload when provided through the `inputs` parameter. The inputs parameter contains only the payload values,
+ which will be parsed when building the request payload. It can be either a dictionary or a list.
+ """
+ if isinstance(inputs, List):
+ if len(inputs) == 0:
+ raise ModelServingException("Inference inputs cannot be an empty list.")
+ else:
+ self._validate_inference_inputs(
+ api_protocol, inputs[0], recursive_call=True
+ )
+ elif isinstance(inputs, InferInput):
+ raise ModelServingException(
+ "Inference inputs cannot be of type `InferInput`. Use the `data` parameter instead."
+ )
+ elif isinstance(inputs, Dict):
+ required_keys = ("name", "shape", "datatype", "data")
+ if api_protocol == IE.API_PROTOCOL_GRPC and not all(
+ k in inputs for k in required_keys
+ ):
+ raise ModelServingException(
+ f"Inference inputs is missing one or more keys. Required keys are [{', '.join(required_keys)}]."
+ )
+ elif not recursive_call or (api_protocol == IE.API_PROTOCOL_GRPC):
+ # if it is the first call to this method, inputs have an invalid type/format
+ # if GRPC protocol is used, only Dict type is valid for the input values
+ raise ModelServingException(
+ "Inference inputs type is not valid. Supported types are dictionary and list."
+ )
+
+ def _build_inference_payload(
+ self,
+ api_protocol,
+ data: Union[Dict, List[InferInput]],
+ inputs: Union[Dict, List[Dict]],
+ ):
+ """Build the inference payload for an inference request. If the 'data' parameter is provided, this method ensures
+ it has the correct format depending on the API protocol. Otherwise, if the 'inputs' parameter is provided, this method
+ builds the correct request payload depending on the API protocol.
+ """
+ if data is not None:
+ # data contains the raw payload (dict or InferInput), nothing needs to be changed
+ return data
+ else: # parse inputs
+ return self._parse_inference_inputs(api_protocol, inputs)
+
+ def _parse_inference_inputs(
+ self, api_protocol, inputs: Union[Dict, List[Dict]], recursive_call=False
+ ):
+ if api_protocol == IE.API_PROTOCOL_REST: # REST protocol
+ if not isinstance(inputs, List):
+ data = {"instances": [[inputs]]} # wrap inputs in a 2-dim list
+ else:
+ data = {"instances": inputs} # use given inputs list by default
+ # check depth of the list: at least two levels are required for batch inference
+ # if the content is neither a list nor a dict, wrap it in an additional list
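+ # e.g., inputs=[0.5, 1.5] becomes {"instances": [[0.5, 1.5]]}, while inputs=[[0.5, 1.5]] is sent as-is (illustrative values)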
+ for i in inputs:
+ if not isinstance(i, List) and not isinstance(i, Dict):
+ # if there are no two levels, wrap inputs in a list
+ data = {"instances": [inputs]}
+ break
+ else: # gRPC protocol
+ if isinstance(inputs, Dict): # Dict
+ data = InferInput(
+ name=inputs["name"],
+ shape=inputs["shape"],
+ datatype=inputs["datatype"],
+ data=inputs["data"],
+ parameters=(
+ inputs["parameters"] if "parameters" in inputs else None
+ ),
+ )
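+ # e.g., a dict like {"name": "input-0", "shape": [1, 3], "datatype": "FP32", "data": [[1.0, 2.0, 3.0]]}
+ # maps to a single InferInput object (field names from the required keys above; values are illustrative)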
+ if not recursive_call:
+ # if inputs is of type Dict, return a singleton
+ data = [data]
+
+ else: # List[Dict]
+ data = inputs
+ for index, inputs_item in enumerate(inputs):
+ data[index] = self._parse_inference_inputs(
+ api_protocol, inputs_item, recursive_call=True
+ )
+
+ return data
diff --git a/hsml/python/hsml/inference_batcher.py b/hsml/python/hsml/inference_batcher.py
new file mode 100644
index 000000000..265615c56
--- /dev/null
+++ b/hsml/python/hsml/inference_batcher.py
@@ -0,0 +1,136 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import Optional
+
+import humps
+from hsml import util
+from hsml.constants import INFERENCE_BATCHER
+
+
+class InferenceBatcher:
+ """Configuration of an inference batcher for a predictor.
+
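+ !!! example
+ ```python
+ from hsml.inference_batcher import InferenceBatcher
+
+ # a minimal configuration sketch; the numeric values below are illustrative
+ my_batcher = InferenceBatcher(enabled=True, max_batch_size=32, max_latency=1000, timeout=5000)
+ ```
+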
+ # Arguments
+ enabled: Whether the inference batcher is enabled or not. The default value is `false`.
+ max_batch_size: Maximum requests batch size.
+ max_latency: Maximum latency for request batching.
+ timeout: Maximum waiting time for request batching.
+ # Returns
+ `InferenceBatcher`. Configuration of an inference batcher.
+ """
+
+ def __init__(
+ self,
+ enabled: Optional[bool] = None,
+ max_batch_size: Optional[int] = None,
+ max_latency: Optional[int] = None,
+ timeout: Optional[int] = None,
+ **kwargs,
+ ):
+ self._enabled = enabled if enabled is not None else INFERENCE_BATCHER.ENABLED
+ self._max_batch_size = max_batch_size
+ self._max_latency = max_latency
+ self._timeout = timeout
+
+ def describe(self):
+ """Print a description of the inference batcher"""
+ util.pretty_print(self)
+
+ @classmethod
+ def from_response_json(cls, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ return cls.from_json(json_decamelized)
+
+ @classmethod
+ def from_json(cls, json_decamelized):
+ return InferenceBatcher(**cls.extract_fields_from_json(json_decamelized))
+
+ @classmethod
+ def extract_fields_from_json(cls, json_decamelized):
+ config = (
+ json_decamelized.pop("batching_configuration")
+ if "batching_configuration" in json_decamelized
+ else json_decamelized
+ )
+ kwargs = {}
+ kwargs["enabled"] = util.extract_field_from_json(
+ config, ["batching_enabled", "enabled"]
+ )
+ kwargs["max_batch_size"] = util.extract_field_from_json(
+ config, "max_batch_size"
+ )
+ kwargs["max_latency"] = util.extract_field_from_json(config, "max_latency")
+ kwargs["timeout"] = util.extract_field_from_json(config, "timeout")
+
+ return kwargs
+
+ def update_from_response_json(self, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ self.__init__(**self.extract_fields_from_json(json_decamelized))
+ return self
+
+ def json(self):
+ return json.dumps(self, cls=util.MLEncoder)
+
+ def to_dict(self):
+ json = {"batchingEnabled": self._enabled}
+ if self._max_batch_size is not None:
+ json["maxBatchSize"] = self._max_batch_size
+ if self._max_latency is not None:
+ json["maxLatency"] = self._max_latency
+ if self._timeout is not None:
+ json["timeout"] = self._timeout
+ return {"batchingConfiguration": json}
+
+ @property
+ def enabled(self):
+ """Whether the inference batcher is enabled or not."""
+ return self._enabled
+
+ @enabled.setter
+ def enabled(self, enabled: bool):
+ self._enabled = enabled
+
+ @property
+ def max_batch_size(self):
+ """Maximum requests batch size."""
+ return self._max_batch_size
+
+ @max_batch_size.setter
+ def max_batch_size(self, max_batch_size: int):
+ self._max_batch_size = max_batch_size
+
+ @property
+ def max_latency(self):
+ """Maximum latency."""
+ return self._max_latency
+
+ @max_latency.setter
+ def max_latency(self, max_latency: int):
+ self._max_latency = max_latency
+
+ @property
+ def timeout(self):
+ """Maximum timeout."""
+ return self._timeout
+
+ @timeout.setter
+ def timeout(self, timeout: int):
+ self._timeout = timeout
+
+ def __repr__(self):
+ return f"InferenceBatcher(enabled: {self._enabled!r})"
diff --git a/hsml/python/hsml/inference_endpoint.py b/hsml/python/hsml/inference_endpoint.py
new file mode 100644
index 000000000..af031dbf5
--- /dev/null
+++ b/hsml/python/hsml/inference_endpoint.py
@@ -0,0 +1,163 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+from typing import List, Optional
+
+import humps
+from hsml import util
+
+
+class InferenceEndpointPort:
+ """Port of an inference endpoint.
+
+ # Arguments
+ name: Name of the port. It typically defines the purpose of the port (e.g., HTTP, HTTPS, STATUS-PORT, TLS)
+ number: Port number.
+ # Returns
+ `InferenceEndpointPort`. Port of an inference endpoint.
+ """
+
+ def __init__(self, name: str, number: int, **kwargs):
+ self._name = name
+ self._number = number
+
+ @classmethod
+ def from_response_json(cls, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ return cls.from_json(json_decamelized)
+
+ @classmethod
+ def from_json(cls, json_decamelized):
+ return InferenceEndpointPort(**cls.extract_fields_from_json(json_decamelized))
+
+ @classmethod
+ def extract_fields_from_json(cls, json_decamelized):
+ kwargs = {}
+ kwargs["name"] = util.extract_field_from_json(json_decamelized, "name")
+ kwargs["number"] = util.extract_field_from_json(json_decamelized, "number")
+ return kwargs
+
+ def to_dict(self):
+ return {"name": self._name, "number": self._number}
+
+ @property
+ def name(self):
+ """Name of the inference endpoint port."""
+ return self._name
+
+ @property
+ def number(self):
+ """Port number of the inference endpoint port."""
+ return self._number
+
+ def __repr__(self):
+ return f"InferenceEndpointPort(name: {self._name!r})"
+
+
+class InferenceEndpoint:
+ """Inference endpoint available in the current project for model inference.
+
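+ !!! example
+ ```python
+ # usage sketch: `ms` is assumed to be a Model Serving handle and "HTTP" an available port name
+ endpoint = ms.get_inference_endpoints()[0]
+ host = endpoint.get_any_host()
+ port = endpoint.get_port("HTTP")
+ ```
+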
+ # Arguments
+ type: Type of inference endpoint (e.g., NODE, KUBE_CLUSTER, LOAD_BALANCER).
+ hosts: List of hosts of the inference endpoint.
+ ports: List of ports of the inference endpoint.
+ # Returns
+ `InferenceEndpoint`. Inference endpoint.
+ """
+
+ def __init__(
+ self,
+ type: str,
+ hosts: List[str],
+ ports: Optional[List[InferenceEndpointPort]],
+ ):
+ self._type = type
+ self._hosts = hosts
+ self._ports = ports
+
+ def get_any_host(self):
+ """Get any host available"""
+ return random.choice(self._hosts) if self._hosts is not None else None
+
+ def get_port(self, name):
+ """Get port by name"""
+ if self._ports is not None:
+ for port in self._ports:
+ if port.name == name:
+ return port
+ return None
+
+ @classmethod
+ def from_response_json(cls, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ if isinstance(json_decamelized, list):
+ if len(json_decamelized) == 0:
+ return []
+ return [cls.from_json(endpoint) for endpoint in json_decamelized]
+ else:
+ if "count" in json_decamelized:
+ if json_decamelized["count"] == 0:
+ return []
+ return [
+ cls.from_json(endpoint) for endpoint in json_decamelized["items"]
+ ]
+ return cls.from_json(json_decamelized)
+
+ @classmethod
+ def from_json(cls, json_decamelized):
+ return InferenceEndpoint(**cls.extract_fields_from_json(json_decamelized))
+
+ @classmethod
+ def extract_fields_from_json(cls, json_decamelized):
+ kwargs = {}
+ kwargs["type"] = util.extract_field_from_json(json_decamelized, "type")
+ kwargs["hosts"] = util.extract_field_from_json(json_decamelized, "hosts")
+ kwargs["ports"] = util.extract_field_from_json(
+ obj=json_decamelized, fields="ports", as_instance_of=InferenceEndpointPort
+ )
+ return kwargs
+
+ def to_dict(self):
+ return {
+ "type": self._type,
+ "hosts": self._hosts,
+ "ports": [port.to_dict() for port in self._ports],
+ }
+
+ @property
+ def type(self):
+ """Type of inference endpoint."""
+ return self._type
+
+ @property
+ def hosts(self):
+ """Hosts of the inference endpoint."""
+ return self._hosts
+
+ @property
+ def ports(self):
+ """Ports of the inference endpoint."""
+ return self._ports
+
+ def __repr__(self):
+ return f"InferenceEndpoint(type: {self._type!r})"
+
+
+def get_endpoint_by_type(endpoints, type) -> InferenceEndpoint:
+ for endpoint in endpoints:
+ if endpoint.type == type:
+ return endpoint
+ return None
diff --git a/hsml/python/hsml/inference_logger.py b/hsml/python/hsml/inference_logger.py
new file mode 100644
index 000000000..ef2f5c9ab
--- /dev/null
+++ b/hsml/python/hsml/inference_logger.py
@@ -0,0 +1,124 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import Optional, Union
+
+import humps
+from hsml import util
+from hsml.constants import DEFAULT, INFERENCE_LOGGER
+from hsml.kafka_topic import KafkaTopic
+
+
+class InferenceLogger:
+ """Configuration of an inference logger for a predictor.
+
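+ !!! example
+ ```python
+ from hsml.inference_logger import InferenceLogger
+
+ # sketch: log only predictions to a newly created Kafka topic (mode value taken from the list below)
+ my_logger = InferenceLogger(mode="PREDICTIONS")
+ ```
+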
+ # Arguments
+ kafka_topic: Kafka topic to send the inference logs to. By default, a new Kafka topic is configured.
+ mode: Inference logging mode (e.g., `NONE`, `ALL`, `PREDICTIONS`, or `MODEL_INPUTS`). By default, `ALL` inference logs are sent.
+ # Returns
+ `InferenceLogger`. Configuration of an inference logger.
+ """
+
+ def __init__(
+ self,
+ kafka_topic: Optional[Union[KafkaTopic, dict]] = DEFAULT,
+ mode: Optional[str] = INFERENCE_LOGGER.MODE_ALL,
+ **kwargs,
+ ):
+ self._kafka_topic = util.get_obj_from_json(kafka_topic, KafkaTopic)
+ self._mode = self._validate_mode(mode, self._kafka_topic) or (
+ INFERENCE_LOGGER.MODE_ALL
+ if self._kafka_topic is not None
+ else INFERENCE_LOGGER.MODE_NONE
+ )
+
+ def describe(self):
+ """Print a description of the inference logger"""
+ util.pretty_print(self)
+
+ @classmethod
+ def _validate_mode(cls, mode, kafka_topic):
+ if mode is not None:
+ modes = list(util.get_members(INFERENCE_LOGGER))
+ if mode not in modes:
+ raise ValueError(
+ "Inference logging mode '{}' is not valid. Possible values are '{}'".format(
+ mode, ", ".join(modes)
+ )
+ )
+
+ if kafka_topic is None and mode is not None:
+ mode = None
+ elif kafka_topic is not None and mode is None:
+ mode = INFERENCE_LOGGER.MODE_NONE
+
+ return mode
+
+ @classmethod
+ def from_response_json(cls, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ return cls.from_json(json_decamelized)
+
+ @classmethod
+ def from_json(cls, json_decamelized):
+ return InferenceLogger(**cls.extract_fields_from_json(json_decamelized))
+
+ @classmethod
+ def extract_fields_from_json(cls, json_decamelized):
+ kwargs = {}
+ kwargs["kafka_topic"] = util.extract_field_from_json(
+ json_decamelized,
+ ["kafka_topic_dto", "kafka_topic"],
+ )
+ kwargs["mode"] = util.extract_field_from_json(
+ json_decamelized, ["inference_logging", "mode"]
+ )
+ return kwargs
+
+ def update_from_response_json(self, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ self.__init__(**self.extract_fields_from_json(json_decamelized))
+ return self
+
+ def json(self):
+ return json.dumps(self, cls=util.MLEncoder)
+
+ def to_dict(self):
+ json = {"inferenceLogging": self._mode}
+ if self._kafka_topic is not None:
+ return {**json, **self._kafka_topic.to_dict()}
+ return json
+
+ @property
+ def kafka_topic(self):
+ """Kafka topic to send the inference logs to."""
+ return self._kafka_topic
+
+ @kafka_topic.setter
+ def kafka_topic(self, kafka_topic: KafkaTopic):
+ self._kafka_topic = kafka_topic
+
+ @property
+ def mode(self):
+ """Inference logging mode ("NONE", "ALL", "PREDICTIONS", or "MODEL_INPUTS")."""
+ return self._mode
+
+ @mode.setter
+ def mode(self, mode: str):
+ self._mode = mode
+
+ def __repr__(self):
+ return f"InferenceLogger(mode: {self._mode!r})"
diff --git a/hsml/python/hsml/kafka_topic.py b/hsml/python/hsml/kafka_topic.py
new file mode 100644
index 000000000..9dce0bb56
--- /dev/null
+++ b/hsml/python/hsml/kafka_topic.py
@@ -0,0 +1,137 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import Optional
+
+import humps
+from hsml import util
+from hsml.constants import KAFKA_TOPIC
+
+
+class KafkaTopic:
+ """Configuration for a Kafka topic."""
+
+ def __init__(
+ self,
+ name: str = KAFKA_TOPIC.CREATE,
+ num_replicas: Optional[int] = None,
+ num_partitions: Optional[int] = None,
+ **kwargs,
+ ):
+ self._name = name
+ self._num_replicas, self._num_partitions = self._validate_topic_config(
+ self._name, num_replicas, num_partitions
+ )
+
+ def describe(self):
+ util.pretty_print(self)
+
+ @classmethod
+ def _validate_topic_config(cls, name, num_replicas, num_partitions):
+ if name is not None and name != KAFKA_TOPIC.NONE:
+ if name == KAFKA_TOPIC.CREATE:
+ if num_replicas is None:
+ print(
+ "Setting number of replicas to default value '{}'".format(
+ KAFKA_TOPIC.NUM_REPLICAS
+ )
+ )
+ num_replicas = KAFKA_TOPIC.NUM_REPLICAS
+ if num_partitions is None:
+ print(
+ "Setting number of partitions to default value '{}'".format(
+ KAFKA_TOPIC.NUM_PARTITIONS
+ )
+ )
+ num_partitions = KAFKA_TOPIC.NUM_PARTITIONS
+ else:
+ if num_replicas is not None or num_partitions is not None:
+ raise ValueError(
+ "Number of replicas or partitions cannot be changed in existing kafka topics."
+ )
+ elif name is None or name == KAFKA_TOPIC.NONE:
+ num_replicas = None
+ num_partitions = None
+
+ return num_replicas, num_partitions
+
+ @classmethod
+ def from_response_json(cls, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ return cls.from_json(json_decamelized)
+
+ @classmethod
+ def from_json(cls, json_decamelized):
+ return KafkaTopic(**cls.extract_fields_from_json(json_decamelized))
+
+ @classmethod
+ def extract_fields_from_json(cls, json_decamelized):
+ kwargs = {}
+ kwargs["name"] = json_decamelized.pop("name") # required
+ kwargs["num_replicas"] = util.extract_field_from_json(
+ json_decamelized, ["num_of_replicas", "num_replicas"]
+ )
+ kwargs["num_partitions"] = util.extract_field_from_json(
+ json_decamelized, ["num_of_partitions", "num_partitions"]
+ )
+ return kwargs
+
+ def update_from_response_json(self, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ self.__init__(**self.extract_fields_from_json(json_decamelized))
+ return self
+
+ def json(self):
+ return json.dumps(self, cls=util.MLEncoder)
+
+ def to_dict(self):
+ return {
+ "kafkaTopicDTO": {
+ "name": self._name,
+ "numOfReplicas": self._num_replicas,
+ "numOfPartitions": self._num_partitions,
+ }
+ }
+
+ @property
+ def name(self):
+ """Name of the Kafka topic."""
+ return self._name
+
+ @name.setter
+ def name(self, name: str):
+ self._name = name
+
+ @property
+ def num_replicas(self):
+ """Number of replicas of the Kafka topic."""
+ return self._num_replicas
+
+ @num_replicas.setter
+ def num_replicas(self, num_replicas: int):
+ self._num_replicas = num_replicas
+
+ @property
+ def num_partitions(self):
+ """Number of partitions of the Kafka topic."""
+ return self._num_partitions
+
+ @num_partitions.setter
+ def num_partitions(self, num_partitions: int):
+ self._num_partitions = num_partitions
+
+ def __repr__(self):
+ return f"KafkaTopic({self._name!r})"
diff --git a/hsml/python/hsml/model.py b/hsml/python/hsml/model.py
new file mode 100644
index 000000000..2d63a7eef
--- /dev/null
+++ b/hsml/python/hsml/model.py
@@ -0,0 +1,572 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import logging
+import os
+import warnings
+from typing import Any, Dict, Optional, Union
+
+import humps
+from hsml import client, util
+from hsml.constants import ARTIFACT_VERSION
+from hsml.constants import INFERENCE_ENDPOINTS as IE
+from hsml.core import explicit_provenance
+from hsml.engine import model_engine
+from hsml.inference_batcher import InferenceBatcher
+from hsml.inference_logger import InferenceLogger
+from hsml.predictor import Predictor
+from hsml.resources import PredictorResources
+from hsml.transformer import Transformer
+
+
+_logger = logging.getLogger(__name__)
+
+
+class Model:
+ """Metadata object representing a model in the Model Registry."""
+
+ def __init__(
+ self,
+ id,
+ name,
+ version=None,
+ created=None,
+ creator=None,
+ environment=None,
+ description=None,
+ experiment_id=None,
+ project_name=None,
+ experiment_project_name=None,
+ metrics=None,
+ program=None,
+ user_full_name=None,
+ model_schema=None,
+ training_dataset=None,
+ input_example=None,
+ framework=None,
+ model_registry_id=None,
+ # unused, but needed since they come in the backend response
+ tags=None,
+ href=None,
+ feature_view=None,
+ training_dataset_version=None,
+ **kwargs,
+ ):
+ self._id = id
+ self._name = name
+ self._version = version
+
+ if description is None:
+ self._description = "A collection of models for " + name
+ else:
+ self._description = description
+
+ self._created = created
+ self._creator = creator
+ self._environment = environment
+ self._experiment_id = experiment_id
+ self._project_name = project_name
+ self._experiment_project_name = experiment_project_name
+ self._training_metrics = metrics
+ self._program = program
+ self._user_full_name = user_full_name
+ self._input_example = input_example
+ self._framework = framework
+ self._model_schema = model_schema
+ self._training_dataset = training_dataset
+
+ # This is needed for update_from_response_json function to not overwrite name of the shared registry this model originates from
+ if not hasattr(self, "_shared_registry_project_name"):
+ self._shared_registry_project_name = None
+
+ self._model_registry_id = model_registry_id
+
+ self._model_engine = model_engine.ModelEngine()
+ self._feature_view = feature_view
+ self._training_dataset_version = training_dataset_version
+ if training_dataset_version is None and feature_view is not None:
+ if feature_view.get_last_accessed_training_dataset() is not None:
+ self._training_dataset_version = (
+ feature_view.get_last_accessed_training_dataset()
+ )
+ else:
+ warnings.warn(
+ "Provenance cached data - feature view provided, but training dataset version is missing",
+ util.ProvenanceWarning,
+ stacklevel=1,
+ )
+
+ def save(
+ self,
+ model_path,
+ await_registration=480,
+ keep_original_files=False,
+ upload_configuration: Optional[Dict[str, Any]] = None,
+ ):
+ """Persist this model including model files and metadata to the model registry.
+
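+ !!! example
+ ```python
+ # sketch: export a trained model from a local folder; the path and chunk size are illustrative
+ model.save("/tmp/my_model_dir", upload_configuration={"chunk_size": 50})
+ ```
+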
+ # Arguments
+ model_path: Local or remote (Hopsworks file system) path to the folder where the model files are located, or path to a specific model file.
+ await_registration: Awaiting time for the model to be registered in Hopsworks.
+ keep_original_files: If the model files are located in HopsFS, whether to move or copy those files into the Models dataset. Default is False (i.e., model files will be moved).
+ upload_configuration: When saving a model from outside Hopsworks, the model is uploaded to the model registry using the REST APIs. Each model artifact is divided into
+ chunks and each chunk uploaded independently. This parameter can be used to control the upload chunk size, the parallelism and the number of retries.
+ `upload_configuration` can contain the following keys:
+ * key `chunk_size`: size of each chunk in megabytes. Default 10.
+ * key `simultaneous_uploads`: number of chunks to upload in parallel. Default 3.
+ * key `max_chunk_retries`: number of times to retry the upload of a chunk in case of failure. Default 1.
+
+ # Returns
+ `Model`: The model metadata object.
+ """
+ return self._model_engine.save(
+ model_instance=self,
+ model_path=model_path,
+ await_registration=await_registration,
+ keep_original_files=keep_original_files,
+ upload_configuration=upload_configuration,
+ )
+
+ def download(self):
+ """Download the model files.
+
+ # Returns
+ `str`: Absolute path to local folder containing the model files.
+ """
+ return self._model_engine.download(model_instance=self)
+
+ def delete(self):
+ """Delete the model
+
+ !!! danger "Potentially dangerous operation"
+ This operation drops all metadata associated with **this version** of the
+ model **and** deletes the model files.
+
+ # Raises
+ `RestAPIError`.
+ """
+ self._model_engine.delete(model_instance=self)
+
+ def deploy(
+ self,
+ name: Optional[str] = None,
+ description: Optional[str] = None,
+ artifact_version: Optional[str] = ARTIFACT_VERSION.CREATE,
+ serving_tool: Optional[str] = None,
+ script_file: Optional[str] = None,
+ resources: Optional[Union[PredictorResources, dict]] = None,
+ inference_logger: Optional[Union[InferenceLogger, dict]] = None,
+ inference_batcher: Optional[Union[InferenceBatcher, dict]] = None,
+ transformer: Optional[Union[Transformer, dict]] = None,
+ api_protocol: Optional[str] = IE.API_PROTOCOL_REST,
+ ):
+ """Deploy the model.
+
+ !!! example
+ ```python
+
+ import hopsworks
+
+ project = hopsworks.login()
+
+ # get Hopsworks Model Registry handle
+ mr = project.get_model_registry()
+
+ # retrieve the trained model you want to deploy
+ my_model = mr.get_model("my_model", version=1)
+
+ my_deployment = my_model.deploy()
+ ```
+ # Arguments
+ name: Name of the deployment.
+ description: Description of the deployment.
+ artifact_version: Version number of the model artifact to deploy, `CREATE` to create a new model artifact
+ or `MODEL-ONLY` to reuse the shared artifact containing only the model files.
+ serving_tool: Serving tool used to deploy the model server.
+ script_file: Path to a custom predictor script implementing the Predict class.
+ resources: Resources to be allocated for the predictor.
+ inference_logger: Inference logger configuration.
+ inference_batcher: Inference batcher configuration.
+ transformer: Transformer to be deployed together with the predictor.
+ api_protocol: API protocol to be enabled in the deployment (i.e., 'REST' or 'GRPC'). Defaults to 'REST'.
+
+ # Returns
+ `Deployment`: The deployment metadata object of a new or existing deployment.
+ """
+
+ if name is None:
+ name = self._name
+
+ predictor = Predictor.for_model(
+ self,
+ name=name,
+ description=description,
+ artifact_version=artifact_version,
+ serving_tool=serving_tool,
+ script_file=script_file,
+ resources=resources,
+ inference_logger=inference_logger,
+ inference_batcher=inference_batcher,
+ transformer=transformer,
+ api_protocol=api_protocol,
+ )
+
+ return predictor.deploy()
+
+ def set_tag(self, name: str, value: Union[str, dict]):
+ """Attach a tag to a model.
+
+ A tag consists of a name/value pair. Tag names are unique identifiers across the whole cluster.
+ The value of a tag can be any valid json - primitives, arrays or json objects.
+
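+ !!! example
+ ```python
+ # sketch: attach a simple key/value tag; the name and value are illustrative
+ model.set_tag("task", "fraud_detection")
+ ```
+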
+ # Arguments
+ name: Name of the tag to be added.
+ value: Value of the tag to be added.
+ # Raises
+ `RestAPIError` in case the backend fails to add the tag.
+ """
+
+ self._model_engine.set_tag(model_instance=self, name=name, value=value)
+
+ def delete_tag(self, name: str):
+ """Delete a tag attached to a model.
+
+ # Arguments
+ name: Name of the tag to be removed.
+ # Raises
+ `RestAPIError` in case the backend fails to delete the tag.
+ """
+ self._model_engine.delete_tag(model_instance=self, name=name)
+
+ def get_tag(self, name: str):
+ """Get the tags of a model.
+
+ # Arguments
+ name: Name of the tag to get.
+ # Returns
+ tag value
+ # Raises
+ `RestAPIError` in case the backend fails to retrieve the tag.
+ """
+ return self._model_engine.get_tag(model_instance=self, name=name)
+
+ def get_tags(self):
+ """Retrieves all tags attached to a model.
+
+ # Returns
+ `Dict[str, obj]` of tags.
+ # Raises
+ `RestAPIError` in case the backend fails to retrieve the tags.
+ """
+ return self._model_engine.get_tags(model_instance=self)
+
+ def get_url(self):
+ path = (
+ "/p/"
+ + str(client.get_instance()._project_id)
+ + "/models/"
+ + str(self.name)
+ + "/"
+ + str(self.version)
+ )
+ return util.get_hostname_replaced_url(sub_path=path)
+
+ def get_feature_view(self, init: bool = True, online: Optional[bool] = None):
+ """Get the parent feature view of this model, based on explicit provenance.
+ Only accessible, usable feature view objects are returned; otherwise an exception is raised.
+ For more details, call the base method `get_feature_view_provenance`.
+
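+ !!! example
+ ```python
+ # sketch: fetch the parent feature view and initialize it for online retrieval of feature vectors
+ fv = model.get_feature_view(init=True, online=True)
+ ```
+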
+ # Returns
+ `FeatureView`: Feature View Object.
+ # Raises
+ `Exception` in case the backend fails to retrieve the feature view.
+ """
+ fv_prov = self.get_feature_view_provenance()
+ fv = explicit_provenance.Links.get_one_accessible_parent(fv_prov)
+ if fv is None:
+ return None
+ if init:
+ td_prov = self.get_training_dataset_provenance()
+ td = explicit_provenance.Links.get_one_accessible_parent(td_prov)
+ is_deployment = "DEPLOYMENT_NAME" in os.environ
+ if online or is_deployment:
+ _logger.info(
+ "Initializing for batch and online retrieval of feature vectors"
+ + (" - within a deployment" if is_deployment else "")
+ )
+ fv.init_serving(training_dataset_version=td.version)
+ elif online is False:
+ _logger.info("Initializing for batch retrieval of feature vectors")
+ fv.init_batch_scoring(training_dataset_version=td.version)
+ return fv
+
+ def get_feature_view_provenance(self):
+ """Get the parent feature view of this model, based on explicit provenance.
+ This feature view can be accessible, deleted or inaccessible.
+ For deleted and inaccessible feature views, only minimal information is
+ returned.
+
+ # Returns
+ `ProvenanceLinks`: Object containing the section of provenance graph requested.
+ """
+ return self._model_engine.get_feature_view_provenance(model_instance=self)
+
+ def get_training_dataset_provenance(self):
+ """Get the parent training dataset of this model, based on explicit provenance.
+ This training dataset can be accessible, deleted or inaccessible.
+ For deleted and inaccessible training datasets, only minimal information is
+ returned.
+
+ # Returns
+ `ProvenanceLinks`: Object containing the section of provenance graph requested.
+ """
+ return self._model_engine.get_training_dataset_provenance(model_instance=self)
+
+ @classmethod
+ def from_response_json(cls, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ if "count" in json_decamelized:
+ if json_decamelized["count"] == 0:
+ return []
+ return [util.set_model_class(model) for model in json_decamelized["items"]]
+ else:
+ return util.set_model_class(json_decamelized)
+
+ def update_from_response_json(self, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ if "type" in json_decamelized: # backwards compatibility
+ _ = json_decamelized.pop("type")
+ self.__init__(**json_decamelized)
+ return self
+
+ def json(self):
+ return json.dumps(self, cls=util.MLEncoder)
+
+ def to_dict(self):
+ return {
+ "id": self._name + "_" + str(self._version),
+ "experimentId": self._experiment_id,
+ "projectName": self._project_name,
+ "experimentProjectName": self._experiment_project_name,
+ "name": self._name,
+ "modelSchema": self._model_schema,
+ "version": self._version,
+ "description": self._description,
+ "inputExample": self._input_example,
+ "framework": self._framework,
+ "metrics": self._training_metrics,
+ "trainingDataset": self._training_dataset,
+ "environment": self._environment,
+ "program": self._program,
+ "featureView": util.feature_view_to_json(self._feature_view),
+ "trainingDatasetVersion": self._training_dataset_version,
+ }
+
+ @property
+ def id(self):
+ """Id of the model."""
+ return self._id
+
+ @id.setter
+ def id(self, id):
+ self._id = id
+
+ @property
+ def name(self):
+ """Name of the model."""
+ return self._name
+
+ @name.setter
+ def name(self, name):
+ self._name = name
+
+ @property
+ def version(self):
+ """Version of the model."""
+ return self._version
+
+ @version.setter
+ def version(self, version):
+ self._version = version
+
+ @property
+ def description(self):
+ """Description of the model."""
+ return self._description
+
+ @description.setter
+ def description(self, description):
+ self._description = description
+
+ @property
+ def created(self):
+ """Creation date of the model."""
+ return self._created
+
+ @created.setter
+ def created(self, created):
+ self._created = created
+
+ @property
+ def creator(self):
+ """Creator of the model."""
+ return self._creator
+
+ @creator.setter
+ def creator(self, creator):
+ self._creator = creator
+
+ @property
+ def environment(self):
+ """Input example of the model."""
+ if self._environment is not None:
+ return self._model_engine.read_file(
+ model_instance=self, resource="environment.yml"
+ )
+ return self._environment
+
+ @environment.setter
+ def environment(self, environment):
+ self._environment = environment
+
+ @property
+ def experiment_id(self):
+ """Experiment Id of the model."""
+ return self._experiment_id
+
+ @experiment_id.setter
+ def experiment_id(self, experiment_id):
+ self._experiment_id = experiment_id
+
+ @property
+ def training_metrics(self):
+ """Training metrics of the model."""
+ return self._training_metrics
+
+ @training_metrics.setter
+ def training_metrics(self, training_metrics):
+ self._training_metrics = training_metrics
+
+ @property
+ def program(self):
+ """Executable used to export the model."""
+ if self._program is not None:
+ return self._model_engine.read_file(
+ model_instance=self, resource=self._program
+ )
+
+ @program.setter
+ def program(self, program):
+ self._program = program
+
+ @property
+ def user(self):
+ """user of the model."""
+ return self._user_full_name
+
+ @user.setter
+ def user(self, user_full_name):
+ self._user_full_name = user_full_name
+
+ @property
+ def input_example(self):
+ """input_example of the model."""
+ return self._model_engine.read_json(
+ model_instance=self, resource="input_example.json"
+ )
+
+ @input_example.setter
+ def input_example(self, input_example):
+ self._input_example = input_example
+
+ @property
+ def framework(self):
+ """framework of the model."""
+ return self._framework
+
+ @framework.setter
+ def framework(self, framework):
+ self._framework = framework
+
+ @property
+ def model_schema(self):
+ """model schema of the model."""
+ return self._model_engine.read_json(
+ model_instance=self, resource="model_schema.json"
+ )
+
+ @model_schema.setter
+ def model_schema(self, model_schema):
+ self._model_schema = model_schema
+
+ @property
+ def training_dataset(self):
+ """training_dataset of the model."""
+ return self._training_dataset
+
+ @training_dataset.setter
+ def training_dataset(self, training_dataset):
+ self._training_dataset = training_dataset
+
+ @property
+ def project_name(self):
+ """project_name of the model."""
+ return self._project_name
+
+ @project_name.setter
+ def project_name(self, project_name):
+ self._project_name = project_name
+
+ @property
+ def model_registry_id(self):
+ """model_registry_id of the model."""
+ return self._model_registry_id
+
+ @model_registry_id.setter
+ def model_registry_id(self, model_registry_id):
+ self._model_registry_id = model_registry_id
+
+ @property
+ def experiment_project_name(self):
+ """experiment_project_name of the model."""
+ return self._experiment_project_name
+
+ @experiment_project_name.setter
+ def experiment_project_name(self, experiment_project_name):
+ self._experiment_project_name = experiment_project_name
+
+ @property
+ def model_path(self):
+ """path of the model with version folder omitted. Resolves to /Projects/{project_name}/Models/{name}"""
+ return "/Projects/{}/Models/{}".format(self.project_name, self.name)
+
+ @property
+ def version_path(self):
+ """path of the model including version folder. Resolves to /Projects/{project_name}/Models/{name}/{version}"""
+ return "{}/{}".format(self.model_path, str(self.version))
+
+ @property
+ def shared_registry_project_name(self):
+ """shared_registry_project_name of the model."""
+ return self._shared_registry_project_name
+
+ @shared_registry_project_name.setter
+ def shared_registry_project_name(self, shared_registry_project_name):
+ self._shared_registry_project_name = shared_registry_project_name
+
+ def __repr__(self):
+ return f"Model(name: {self._name!r}, version: {self._version!r})"
diff --git a/hsml/python/hsml/model_registry.py b/hsml/python/hsml/model_registry.py
new file mode 100644
index 000000000..4a7f3443b
--- /dev/null
+++ b/hsml/python/hsml/model_registry.py
@@ -0,0 +1,196 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import warnings
+
+import humps
+from hsml import util
+from hsml.core import model_api
+from hsml.python import signature as python_signature # noqa: F401
+from hsml.sklearn import signature as sklearn_signature # noqa: F401
+from hsml.tensorflow import signature as tensorflow_signature # noqa: F401
+from hsml.torch import signature as torch_signature # noqa: F401
+
+
+class ModelRegistry:
+ DEFAULT_VERSION = 1
+
+ def __init__(
+ self,
+ project_name,
+ project_id,
+ model_registry_id,
+ shared_registry_project_name=None,
+ **kwargs,
+ ):
+ self._project_name = project_name
+ self._project_id = project_id
+
+ self._shared_registry_project_name = shared_registry_project_name
+ self._model_registry_id = model_registry_id
+
+ self._model_api = model_api.ModelApi()
+
+ self._tensorflow = tensorflow_signature
+ self._python = python_signature
+ self._sklearn = sklearn_signature
+ self._torch = torch_signature
+
+ tensorflow_signature._mr = self
+ python_signature._mr = self
+ sklearn_signature._mr = self
+ torch_signature._mr = self
+
+ @classmethod
+ def from_response_json(cls, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ return cls(**json_decamelized)
+
+ def get_model(self, name: str, version: int = None):
+ """Get a model entity from the model registry.
+ Getting a model from the Model Registry means getting its metadata handle
+ so you can subsequently download the model directory.
+
+ # Arguments
+ name: Name of the model to get.
+ version: Version of the model to retrieve. Defaults to `None`, in which case
+ version `1` is returned.
+ # Returns
+ `Model`: The model metadata object.
+ # Raises
+ `RestAPIError`: If unable to retrieve model from the model registry.
+ """
+
+ if version is None:
+ warnings.warn(
+ "No version provided for getting model `{}`, defaulting to `{}`.".format(
+ name, self.DEFAULT_VERSION
+ ),
+ util.VersionWarning,
+ stacklevel=1,
+ )
+ version = self.DEFAULT_VERSION
+
+ return self._model_api.get(
+ name,
+ version,
+ self.model_registry_id,
+ shared_registry_project_name=self.shared_registry_project_name,
+ )
+
+ def get_models(self, name: str):
+ """Get all model entities from the model registry for a specified name.
+ Getting all models from the Model Registry for a given name returns a list of model entities, one for each version registered under
+ the specified model name.
+
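+ !!! example
+ ```python
+ # sketch: list every registered version of a model and pick the latest one
+ models = mr.get_models("my_model")
+ latest = max(models, key=lambda m: m.version)
+ ```
+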
+ # Arguments
+ name: Name of the model to get.
+ # Returns
+ `List[Model]`: A list of model metadata objects.
+ # Raises
+ `RestAPIError`: If unable to retrieve model versions from the model registry.
+ """
+
+ return self._model_api.get_models(
+ name,
+ self.model_registry_id,
+ shared_registry_project_name=self.shared_registry_project_name,
+ )
+
+ def get_best_model(self, name: str, metric: str, direction: str):
+ """Get the best performing model entity from the model registry.
+ Getting the best performing model from the Model Registry requires specifying, in addition to the name, a metric
+ name corresponding to one of the keys in the model's training_metrics dict and a direction. For example, to
+ get the model version with the highest accuracy, specify metric='accuracy' and direction='max'.
+
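+ !!! example
+ ```python
+ # sketch: retrieve the version with the highest accuracy, as described above
+ best_model = mr.get_best_model("my_model", "accuracy", "max")
+ ```
+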
+ # Arguments
+ name: Name of the model to get.
+ metric: Name of the key in the training metrics field to compare.
+ direction: 'max' to get the model entity with the highest value of the set metric, or 'min' for the lowest.
+ # Returns
+ `Model`: The model metadata object.
+ # Raises
+ `RestAPIError`: If unable to retrieve model from the model registry.
+ """
+
+ model = self._model_api.get_models(
+ name,
+ self.model_registry_id,
+ shared_registry_project_name=self.shared_registry_project_name,
+ metric=metric,
+ direction=direction,
+ )
+ if isinstance(model, list) and len(model) > 0:
+ return model[0]
+ else:
+ return None
+
+ @property
+ def project_name(self):
+ """Name of the project the registry is connected to."""
+ return self._project_name
+
+ @property
+ def project_path(self):
+ """Path of the project the registry is connected to."""
+ return "/Projects/{}".format(self._project_name)
+
+ @property
+ def project_id(self):
+ """Id of the project the registry is connected to."""
+ return self._project_id
+
+ @property
+ def shared_registry_project_name(self):
+ """Name of the project the shared model registry originates from."""
+ return self._shared_registry_project_name
+
+ @property
+ def model_registry_id(self):
+ """Id of the model registry."""
+ return self._model_registry_id
+
+ @property
+ def tensorflow(self):
+ """Module for exporting a TensorFlow model."""
+
+ return tensorflow_signature
+
+ @property
+ def sklearn(self):
+ """Module for exporting a sklearn model."""
+
+ return sklearn_signature
+
+ @property
+ def torch(self):
+ """Module for exporting a torch model."""
+
+ return torch_signature
+
+ @property
+ def python(self):
+ """Module for exporting a generic Python model."""
+
+ return python_signature
+
+ def __repr__(self):
+ project_name = (
+ self._shared_registry_project_name
+ if self._shared_registry_project_name is not None
+ else self._project_name
+ )
+ return f"ModelRegistry(project: {project_name!r})"
diff --git a/hsml/python/hsml/model_schema.py b/hsml/python/hsml/model_schema.py
new file mode 100644
index 000000000..7af3999ca
--- /dev/null
+++ b/hsml/python/hsml/model_schema.py
@@ -0,0 +1,64 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+from typing import Optional
+
+from hsml.schema import Schema
+
+
+class ModelSchema:
+ """Create a schema for a model.
+
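+ !!! example
+ ```python
+ import numpy as np
+ from hsml.schema import Schema
+ from hsml.model_schema import ModelSchema
+
+ # sketch assuming Schema objects can be built from sample arrays; shapes are illustrative
+ model_schema = ModelSchema(input_schema=Schema(np.zeros((1, 4))), output_schema=Schema(np.zeros((1, 1))))
+ ```
+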
+ # Arguments
+ input_schema: Schema to describe the inputs.
+ output_schema: Schema to describe the outputs.
+
+ # Returns
+ `ModelSchema`. The model schema object.
+ """
+
+ def __init__(
+ self,
+ input_schema: Optional[Schema] = None,
+ output_schema: Optional[Schema] = None,
+ **kwargs,
+ ):
+ if input_schema is not None:
+ self.input_schema = input_schema
+
+ if output_schema is not None:
+ self.output_schema = output_schema
+
+ def json(self):
+ return json.dumps(
+ self, default=lambda o: getattr(o, "__dict__", o), sort_keys=True, indent=2
+ )
+
+ def to_dict(self):
+ """
+ Get dict representation of the ModelSchema.
+ """
+ return json.loads(self.json())
+
+ def __repr__(self):
+ input_type = (
+ self.input_schema._get_type() if hasattr(self, "input_schema") else None
+ )
+ output_type = (
+ self.output_schema._get_type() if hasattr(self, "output_schema") else None
+ )
+ return f"ModelSchema(input: {input_type!r}, output: {output_type!r})"
diff --git a/hsml/python/hsml/model_serving.py b/hsml/python/hsml/model_serving.py
new file mode 100644
index 000000000..a256fdc13
--- /dev/null
+++ b/hsml/python/hsml/model_serving.py
@@ -0,0 +1,374 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+from typing import Optional, Union
+
+from hsml import util
+from hsml.constants import ARTIFACT_VERSION, PREDICTOR_STATE
+from hsml.constants import INFERENCE_ENDPOINTS as IE
+from hsml.core import serving_api
+from hsml.deployment import Deployment
+from hsml.inference_batcher import InferenceBatcher
+from hsml.inference_logger import InferenceLogger
+from hsml.model import Model
+from hsml.predictor import Predictor
+from hsml.resources import PredictorResources
+from hsml.transformer import Transformer
+
+
+class ModelServing:
+ DEFAULT_VERSION = 1
+
+ def __init__(
+ self,
+ project_name: str,
+ project_id: int,
+ **kwargs,
+ ):
+ self._project_name = project_name
+ self._project_id = project_id
+
+ self._serving_api = serving_api.ServingApi()
+
+ def get_deployment_by_id(self, id: int):
+ """Get a deployment by id from Model Serving.
+ Getting a deployment from Model Serving means getting its metadata handle
+ so you can subsequently operate on it (e.g., start or stop).
+
+ !!! example
+ ```python
+ # login and get Hopsworks Model Serving handle using .login() and .get_model_serving()
+
+ # get a deployment by id
+ my_deployment = ms.get_deployment_by_id(1)
+ ```
+
+ # Arguments
+ id: Id of the deployment to get.
+ # Returns
+ `Deployment`: The deployment metadata object.
+ # Raises
+ `RestAPIError`: If unable to retrieve deployment from model serving.
+ """
+
+ return self._serving_api.get_by_id(id)
+
+ def get_deployment(self, name: str = None):
+ """Get a deployment by name from Model Serving.
+
+ !!! example
+ ```python
+ # login and get Hopsworks Model Serving handle using .login() and .get_model_serving()
+
+ # get a deployment by name
+ my_deployment = ms.get_deployment('deployment_name')
+ ```
+
+ Getting a deployment from Model Serving means getting its metadata handle
+ so you can subsequently operate on it (e.g., start or stop).
+
+ # Arguments
+ name: Name of the deployment to get.
+ # Returns
+ `Deployment`: The deployment metadata object.
+ # Raises
+ `RestAPIError`: If unable to retrieve deployment from model serving.
+ """
+
+ if name is None and ("DEPLOYMENT_NAME" in os.environ):
+ name = os.environ["DEPLOYMENT_NAME"]
+ return self._serving_api.get(name)
+
+ def get_deployments(self, model: Model = None, status: str = None):
+ """Get all deployments from model serving.
+ !!! example
+ ```python
+ # login into Hopsworks using hopsworks.login()
+
+ # get Hopsworks Model Registry handle
+ mr = project.get_model_registry()
+
+ # get Hopsworks Model Serving handle
+ ms = project.get_model_serving()
+
+ # retrieve the trained model you want to deploy
+ my_model = mr.get_model("my_model", version=1)
+
+ list_deployments = ms.get_deployments(model=my_model)
+
+ for deployment in list_deployments:
+ print(deployment.get_state())
+ ```
+ # Arguments
+ model: Filter by model served in the deployments
+ status: Filter by status of the deployments
+ # Returns
+ `List[Deployment]`: A list of deployments.
+ # Raises
+ `RestAPIError`: If unable to retrieve deployments from model serving.
+ """
+
+ model_name = model.name if model is not None else None
+ if status is not None:
+ self._validate_deployment_status(status)
+
+ return self._serving_api.get_all(model_name, status)
+
+ def _validate_deployment_status(self, status):
+ statuses = list(util.get_members(PREDICTOR_STATE, prefix="STATUS"))
+ status = status.upper()
+ if status not in statuses:
+ raise ValueError(
+ "Deployment status '{}' is not valid. Possible values are '{}'".format(
+ status, ", ".join(statuses)
+ )
+ )
+ return status
+
+ def get_inference_endpoints(self):
+ """Get all inference endpoints available in the current project.
+
+ # Returns
+ `List[InferenceEndpoint]`: Inference endpoints for model inference
+ """
+
+ return self._serving_api.get_inference_endpoints()
+
+ def create_predictor(
+ self,
+ model: Model,
+ name: Optional[str] = None,
+ artifact_version: Optional[str] = ARTIFACT_VERSION.CREATE,
+ serving_tool: Optional[str] = None,
+ script_file: Optional[str] = None,
+ resources: Optional[Union[PredictorResources, dict]] = None,
+ inference_logger: Optional[Union[InferenceLogger, dict, str]] = None,
+ inference_batcher: Optional[Union[InferenceBatcher, dict]] = None,
+ transformer: Optional[Union[Transformer, dict]] = None,
+ api_protocol: Optional[str] = IE.API_PROTOCOL_REST,
+ ):
+ """Create a Predictor metadata object.
+
+ !!! example
+ ```python
+ # login into Hopsworks using hopsworks.login()
+
+ # get Hopsworks Model Registry handle
+ mr = project.get_model_registry()
+
+ # retrieve the trained model you want to deploy
+ my_model = mr.get_model("my_model", version=1)
+
+ # get Hopsworks Model Serving handle
+ ms = project.get_model_serving()
+
+ my_predictor = ms.create_predictor(my_model)
+
+ my_deployment = my_predictor.deploy()
+ ```
+
+ !!! note "Lazy"
+ This method is lazy and does not persist any metadata or deploy any model on its own.
+ To create a deployment using this predictor, call the `deploy()` method.
+
+ # Arguments
+ model: Model to be deployed.
+ name: Name of the predictor.
+ artifact_version: Version number of the model artifact to deploy, `CREATE` to create a new model artifact
+ or `MODEL-ONLY` to reuse the shared artifact containing only the model files.
+ serving_tool: Serving tool used to deploy the model server.
+ script_file: Path to a custom predictor script implementing the Predict class.
+ resources: Resources to be allocated for the predictor.
+ inference_logger: Inference logger configuration.
+ inference_batcher: Inference batcher configuration.
+ transformer: Transformer to be deployed together with the predictor.
+ api_protocol: API protocol to be enabled in the deployment (i.e., 'REST' or 'GRPC'). Defaults to 'REST'.
+
+ # Returns
+ `Predictor`. The predictor metadata object.
+ """
+
+ if name is None:
+ name = model.name
+
+ return Predictor.for_model(
+ model,
+ name=name,
+ artifact_version=artifact_version,
+ serving_tool=serving_tool,
+ script_file=script_file,
+ resources=resources,
+ inference_logger=inference_logger,
+ inference_batcher=inference_batcher,
+ transformer=transformer,
+ api_protocol=api_protocol,
+ )
+
+ def create_transformer(
+ self,
+ script_file: Optional[str] = None,
+ resources: Optional[Union[PredictorResources, dict]] = None,
+ ):
+ """Create a Transformer metadata object.
+
+ !!! example
+ ```python
+ # login into Hopsworks using hopsworks.login()
+
+ # get Dataset API instance
+ dataset_api = project.get_dataset_api()
+
+ # get Hopsworks Model Serving handle
+ ms = project.get_model_serving()
+
+ # create my_transformer.py Python script
+ class Transformer(object):
+ def __init__(self):
+ ''' Initialization code goes here '''
+ pass
+
+ def preprocess(self, inputs):
+ ''' Transform the requests inputs here. The object returned by this method will be used as model input to make predictions. '''
+ return inputs
+
+ def postprocess(self, outputs):
+ ''' Transform the predictions computed by the model before returning a response '''
+ return outputs
+
+ uploaded_file_path = dataset_api.upload("my_transformer.py", "Resources", overwrite=True)
+ transformer_script_path = os.path.join("/Projects", project.name, uploaded_file_path)
+
+ my_transformer = ms.create_transformer(script_file=uploaded_file_path)
+
+ # or
+
+ from hsml.transformer import Transformer
+
+ my_transformer = Transformer(script_file)
+ ```
+
+ !!! example "Create a deployment with the transformer"
+ ```python
+
+ my_predictor = ms.create_predictor(my_model, transformer=my_transformer)
+ my_deployment = my_predictor.deploy()
+
+ # or
+ my_deployment = ms.create_deployment(my_predictor)
+ my_deployment.save()
+ ```
+
+ !!! note "Lazy"
+ This method is lazy and does not persist any metadata or deploy any transformer. To create a deployment using this transformer, set it in the `predictor.transformer` property.
+
+ # Arguments
+ script_file: Path to a custom predictor script implementing the Transformer class.
+ resources: Resources to be allocated for the transformer.
+
+ # Returns
+ `Transformer`. The transformer metadata object.
+ """
+
+ return Transformer(script_file=script_file, resources=resources)
+
+ def create_deployment(self, predictor: Predictor, name: Optional[str] = None):
+ """Create a Deployment metadata object.
+
+ !!! example
+ ```python
+ # login into Hopsworks using hopsworks.login()
+
+ # get Hopsworks Model Registry handle
+ mr = project.get_model_registry()
+
+ # retrieve the trained model you want to deploy
+ my_model = mr.get_model("my_model", version=1)
+
+ # get Hopsworks Model Serving handle
+ ms = project.get_model_serving()
+
+ my_predictor = ms.create_predictor(my_model)
+
+ my_deployment = ms.create_deployment(my_predictor)
+ my_deployment.save()
+ ```
+
+ !!! example "Using the model object"
+ ```python
+ # login into Hopsworks using hopsworks.login()
+
+ # get Hopsworks Model Registry handle
+ mr = project.get_model_registry()
+
+ # retrieve the trained model you want to deploy
+ my_model = mr.get_model("my_model", version=1)
+
+ my_deployment = my_model.deploy()
+
+ my_deployment.get_state().describe()
+ ```
+
+ !!! example "Using the Model Serving handle"
+ ```python
+ # log in to Hopsworks using hopsworks.login()
+
+ # get Hopsworks Model Registry handle
+ mr = project.get_model_registry()
+
+ # retrieve the trained model you want to deploy
+ my_model = mr.get_model("my_model", version=1)
+
+ # get Hopsworks Model Serving handle
+ ms = project.get_model_serving()
+
+ my_predictor = ms.create_predictor(my_model)
+
+ my_deployment = my_predictor.deploy()
+
+ my_deployment.get_state().describe()
+ ```
+
+ !!! note "Lazy"
+ This method is lazy and does not persist any metadata or deploy any model. To create a deployment, call the `save()` method.
+
+ # Arguments
+ predictor: Predictor to be used in the deployment.
+ name: Name of the deployment.
+
+ # Returns
+ `Deployment`. The deployment metadata object.
+ """
+
+ return Deployment(predictor=predictor, name=name)
+
+ @property
+ def project_name(self):
+ """Name of the project in which Model Serving is located."""
+ return self._project_name
+
+ @property
+ def project_path(self):
+ """Path of the project the registry is connected to."""
+ return "/Projects/{}".format(self._project_name)
+
+ @property
+ def project_id(self):
+ """Id of the project in which Model Serving is located."""
+ return self._project_id
+
+ def __repr__(self):
+ return f"ModelServing(project: {self._project_name!r})"
diff --git a/hsml/python/hsml/predictor.py b/hsml/python/hsml/predictor.py
new file mode 100644
index 000000000..10cc29f41
--- /dev/null
+++ b/hsml/python/hsml/predictor.py
@@ -0,0 +1,466 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import Optional, Union
+
+import humps
+from hsml import client, deployment, util
+from hsml.constants import (
+ ARTIFACT_VERSION,
+ INFERENCE_ENDPOINTS,
+ MODEL,
+ PREDICTOR,
+ RESOURCES,
+)
+from hsml.deployable_component import DeployableComponent
+from hsml.inference_batcher import InferenceBatcher
+from hsml.inference_logger import InferenceLogger
+from hsml.predictor_state import PredictorState
+from hsml.resources import PredictorResources
+from hsml.transformer import Transformer
+
+
+class Predictor(DeployableComponent):
+ """Metadata object representing a predictor in Model Serving."""
+
+ def __init__(
+ self,
+ name: str,
+ model_name: str,
+ model_path: str,
+ model_version: int,
+ model_framework: str, # MODEL.FRAMEWORK
+ artifact_version: Union[int, str],
+ model_server: str,
+ serving_tool: Optional[str] = None,
+ script_file: Optional[str] = None,
+ resources: Optional[Union[PredictorResources, dict]] = None, # base
+ inference_logger: Optional[Union[InferenceLogger, dict]] = None, # base
+ inference_batcher: Optional[Union[InferenceBatcher, dict]] = None, # base
+ transformer: Optional[Union[Transformer, dict]] = None,
+ id: Optional[int] = None,
+ description: Optional[str] = None,
+ created_at: Optional[str] = None,
+ creator: Optional[str] = None,
+ api_protocol: Optional[str] = INFERENCE_ENDPOINTS.API_PROTOCOL_REST,
+ **kwargs,
+ ):
+ serving_tool = (
+ self._validate_serving_tool(serving_tool)
+ or self._get_default_serving_tool()
+ )
+ resources = self._validate_resources(
+ util.get_obj_from_json(resources, PredictorResources), serving_tool
+ ) or self._get_default_resources(serving_tool)
+
+ super().__init__(
+ script_file,
+ resources,
+ inference_batcher,
+ )
+
+ self._name = name
+ self._model_name = model_name
+ self._model_path = model_path
+ self._model_version = model_version
+ self._model_framework = model_framework
+ self._artifact_version = artifact_version
+ self._serving_tool = serving_tool
+ self._model_server = model_server
+ self._id = id
+ self._description = description
+ self._created_at = created_at
+ self._creator = creator
+
+ self._inference_logger = util.get_obj_from_json(
+ inference_logger, InferenceLogger
+ )
+ self._transformer = util.get_obj_from_json(transformer, Transformer)
+ self._validate_script_file(self._model_framework, self._script_file)
+ self._api_protocol = api_protocol
+
+ def deploy(self):
+ """Create a deployment for this predictor and persists it in the Model Serving.
+
+ !!! example
+ ```python
+
+ import hopsworks
+
+ project = hopsworks.login()
+
+ # get Hopsworks Model Registry handle
+ mr = project.get_model_registry()
+
+ # retrieve the trained model you want to deploy
+ my_model = mr.get_model("my_model", version=1)
+
+ # get Hopsworks Model Serving handle
+ ms = project.get_model_serving()
+
+ my_predictor = ms.create_predictor(my_model)
+ my_deployment = my_predictor.deploy()
+
+ print(my_deployment.get_state())
+ ```
+
+ # Returns
+ `Deployment`. The deployment metadata object of a new or existing deployment.
+ """
+
+ _deployment = deployment.Deployment(
+ predictor=self, name=self._name, description=self._description
+ )
+ _deployment.save()
+
+ return _deployment
+
+ def describe(self):
+ """Print a description of the predictor"""
+ util.pretty_print(self)
+
+ def _set_state(self, state: PredictorState):
+ """Set the state of the predictor"""
+ self._state = state
+
+ @classmethod
+ def _validate_serving_tool(cls, serving_tool):
+ if serving_tool is not None:
+ if client.is_saas_connection():
+ # only kserve supported in saasy hopsworks
+ if serving_tool != PREDICTOR.SERVING_TOOL_KSERVE:
+ raise ValueError(
+ "KServe deployments are the only supported in Serverless Hopsworks"
+ )
+ return serving_tool
+ # if not saas, check valid serving_tool
+ serving_tools = list(util.get_members(PREDICTOR, prefix="SERVING_TOOL"))
+ if serving_tool not in serving_tools:
+ raise ValueError(
+ "Serving tool '{}' is not valid. Possible values are '{}'".format(
+ serving_tool, ", ".join(serving_tools)
+ )
+ )
+ return serving_tool
+
+ @classmethod
+ def _validate_script_file(cls, model_framework, script_file):
+ if model_framework == MODEL.FRAMEWORK_PYTHON and script_file is None:
+ raise ValueError(
+ "Predictor scripts are required in deployments for custom Python models"
+ )
+
+ @classmethod
+ def _infer_model_server(cls, model_framework):
+ return (
+ PREDICTOR.MODEL_SERVER_TF_SERVING
+ if model_framework == MODEL.FRAMEWORK_TENSORFLOW
+ else PREDICTOR.MODEL_SERVER_PYTHON
+ )
+
+ @classmethod
+ def _get_default_serving_tool(cls):
+ # set kserve as default if it is available
+ return (
+ PREDICTOR.SERVING_TOOL_KSERVE
+ if client.is_kserve_installed()
+ else PREDICTOR.SERVING_TOOL_DEFAULT
+ )
+
+ @classmethod
+ def _validate_resources(cls, resources, serving_tool):
+ if resources is not None:
+ # ensure scale-to-zero for kserve deployments when required
+ if (
+ serving_tool == PREDICTOR.SERVING_TOOL_KSERVE
+ and resources.num_instances != 0
+ and client.is_scale_to_zero_required()
+ ):
+ raise ValueError(
+ "Scale-to-zero is required for KServe deployments in this cluster. Please, set the number of instances to 0."
+ )
+ return resources
+
+ @classmethod
+ def _get_default_resources(cls, serving_tool):
+ num_instances = (
+ 0 # enable scale-to-zero by default if required
+ if serving_tool == PREDICTOR.SERVING_TOOL_KSERVE
+ and client.is_scale_to_zero_required()
+ else RESOURCES.MIN_NUM_INSTANCES
+ )
+ return PredictorResources(num_instances)
+
+ @classmethod
+ def for_model(cls, model, **kwargs):
+ kwargs["model_name"] = model.name
+ kwargs["model_path"] = model.model_path
+ kwargs["model_version"] = model.version
+
+ # get predictor for specific model, includes model type-related validations
+ return util.get_predictor_for_model(model=model, **kwargs)
+
+ @classmethod
+ def from_response_json(cls, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ if isinstance(json_decamelized, list):
+ if len(json_decamelized) == 0:
+ return []
+ return [cls.from_json(predictor) for predictor in json_decamelized]
+ else:
+ if "count" in json_decamelized:
+ if json_decamelized["count"] == 0:
+ return []
+ return [
+ cls.from_json(predictor) for predictor in json_decamelized["items"]
+ ]
+ return cls.from_json(json_decamelized)
+
+ @classmethod
+ def from_json(cls, json_decamelized):
+ predictor = Predictor(**cls.extract_fields_from_json(json_decamelized))
+ predictor._set_state(PredictorState.from_response_json(json_decamelized))
+ return predictor
+
+ @classmethod
+ def extract_fields_from_json(cls, json_decamelized):
+ kwargs = {}
+ kwargs["name"] = json_decamelized.pop("name")
+ kwargs["description"] = util.extract_field_from_json(
+ json_decamelized, "description"
+ )
+ kwargs["model_name"] = util.extract_field_from_json(
+ json_decamelized, "model_name", default=kwargs["name"]
+ )
+ kwargs["model_path"] = json_decamelized.pop("model_path")
+ kwargs["model_version"] = json_decamelized.pop("model_version")
+ kwargs["model_framework"] = (
+ json_decamelized.pop("model_framework")
+ if "model_framework" in json_decamelized
+ else MODEL.FRAMEWORK_SKLEARN # backward compatibility
+ )
+ kwargs["artifact_version"] = util.extract_field_from_json(
+ json_decamelized, "artifact_version"
+ )
+ kwargs["model_server"] = json_decamelized.pop("model_server")
+ kwargs["serving_tool"] = json_decamelized.pop("serving_tool")
+ kwargs["script_file"] = util.extract_field_from_json(
+ json_decamelized, "predictor"
+ )
+ kwargs["resources"] = PredictorResources.from_json(json_decamelized)
+ kwargs["inference_logger"] = InferenceLogger.from_json(json_decamelized)
+ kwargs["inference_batcher"] = InferenceBatcher.from_json(json_decamelized)
+ kwargs["transformer"] = Transformer.from_json(json_decamelized)
+ kwargs["id"] = json_decamelized.pop("id")
+ kwargs["created_at"] = json_decamelized.pop("created")
+ kwargs["creator"] = json_decamelized.pop("creator")
+ kwargs["api_protocol"] = json_decamelized.pop("api_protocol")
+ return kwargs
+
+ def update_from_response_json(self, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ self.__init__(**self.extract_fields_from_json(json_decamelized))
+ self._set_state(PredictorState.from_response_json(json_decamelized))
+ return self
+
+ def json(self):
+ return json.dumps(self, cls=util.MLEncoder)
+
+ def to_dict(self):
+ json = {
+ "id": self._id,
+ "name": self._name,
+ "description": self._description,
+ "modelName": self._model_name,
+ "modelPath": self._model_path,
+ "modelVersion": self._model_version,
+ "modelFramework": self._model_framework,
+ "artifactVersion": self._artifact_version,
+ "created": self._created_at,
+ "creator": self._creator,
+ "modelServer": self._model_server,
+ "servingTool": self._serving_tool,
+ "predictor": self._script_file,
+ "apiProtocol": self._api_protocol,
+ }
+ if self._resources is not None:
+ json = {**json, **self._resources.to_dict()}
+ if self._inference_logger is not None:
+ json = {**json, **self._inference_logger.to_dict()}
+ if self._inference_batcher is not None:
+ json = {**json, **self._inference_batcher.to_dict()}
+ if self._transformer is not None:
+ json = {**json, **self._transformer.to_dict()}
+ return json
+
+ @property
+ def id(self):
+ """Id of the predictor."""
+ return self._id
+
+ @property
+ def name(self):
+ """Name of the predictor."""
+ return self._name
+
+ @name.setter
+ def name(self, name: str):
+ self._name = name
+
+ @property
+ def description(self):
+ """Description of the predictor."""
+ return self._description
+
+ @description.setter
+ def description(self, description: str):
+ self._description = description
+
+ @property
+ def model_name(self):
+ """Name of the model deployed by the predictor."""
+ return self._model_name
+
+ @model_name.setter
+ def model_name(self, model_name: str):
+ self._model_name = model_name
+
+ @property
+ def model_path(self):
+ """Model path deployed by the predictor."""
+ return self._model_path
+
+ @model_path.setter
+ def model_path(self, model_path: str):
+ self._model_path = model_path
+
+ @property
+ def model_version(self):
+ """Model version deployed by the predictor."""
+ return self._model_version
+
+ @model_version.setter
+ def model_version(self, model_version: int):
+ self._model_version = model_version
+
+ @property
+ def model_framework(self):
+ """Model framework of the model to be deployed by the predictor."""
+ return self._model_framework
+
+ @model_framework.setter
+ def model_framework(self, model_framework: str):
+ self._model_framework = model_framework
+ self._model_server = self._infer_model_server(model_framework)
+
+ @property
+ def artifact_version(self):
+ """Artifact version deployed by the predictor."""
+ return self._artifact_version
+
+ @artifact_version.setter
+ def artifact_version(self, artifact_version: Union[int, str]):
+ self._artifact_version = artifact_version
+
+ @property
+ def artifact_path(self):
+ """Path of the model artifact deployed by the predictor. Resolves to /Projects/{project_name}/Models/{name}/{version}/Artifacts/{artifact_version}/{name}_{version}_{artifact_version}.zip"""
+ artifact_name = "{}_{}_{}.zip".format(
+ self._model_name, str(self._model_version), str(self._artifact_version)
+ )
+ return "{}/{}/Artifacts/{}/{}".format(
+ self._model_path,
+ str(self._model_version),
+ str(self._artifact_version),
+ artifact_name,
+ )
+
+ @property
+ def model_server(self):
+ """Model server used by the predictor."""
+ return self._model_server
+
+ @property
+ def serving_tool(self):
+ """Serving tool used to run the model server."""
+ return self._serving_tool
+
+ @serving_tool.setter
+ def serving_tool(self, serving_tool: str):
+ self._serving_tool = serving_tool
+
+ @property
+ def script_file(self):
+ """Script file used to load and run the model."""
+ return self._script_file
+
+ @script_file.setter
+ def script_file(self, script_file: str):
+ self._script_file = script_file
+ self._artifact_version = ARTIFACT_VERSION.CREATE
+
+ @property
+ def inference_logger(self):
+ """Configuration of the inference logger attached to this predictor."""
+ return self._inference_logger
+
+ @inference_logger.setter
+ def inference_logger(self, inference_logger: InferenceLogger):
+ self._inference_logger = inference_logger
+
+ @property
+ def transformer(self):
+ """Transformer configuration attached to the predictor."""
+ return self._transformer
+
+ @transformer.setter
+ def transformer(self, transformer: Transformer):
+ self._transformer = transformer
+
+ @property
+ def created_at(self):
+ """Created at date of the predictor."""
+ return self._created_at
+
+ @property
+ def creator(self):
+ """Creator of the predictor."""
+ return self._creator
+
+ @property
+ def requested_instances(self):
+ """Total number of requested instances in the predictor."""
+ num_instances = self._resources.num_instances
+ if self._transformer is not None:
+ num_instances += self._transformer.resources.num_instances
+ return num_instances
+
+ @property
+ def api_protocol(self):
+ """API protocol enabled in the predictor (e.g., HTTP or GRPC)."""
+ return self._api_protocol
+
+ @api_protocol.setter
+ def api_protocol(self, api_protocol):
+ self._api_protocol = api_protocol
+
+ def __repr__(self):
+ desc = (
+ f", description: {self._description!r}"
+ if self._description is not None
+ else ""
+ )
+ return f"Predictor(name: {self._name!r}" + desc + ")"
diff --git a/hsml/python/hsml/predictor_state.py b/hsml/python/hsml/predictor_state.py
new file mode 100644
index 000000000..b145993e1
--- /dev/null
+++ b/hsml/python/hsml/predictor_state.py
@@ -0,0 +1,147 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+import humps
+from hsml import util
+from hsml.predictor_state_condition import PredictorStateCondition
+
+
+class PredictorState:
+ """State of a predictor."""
+
+ def __init__(
+ self,
+ available_predictor_instances: int,
+ available_transformer_instances: Optional[int],
+ hopsworks_inference_path: str,
+ model_server_inference_path: str,
+ internal_port: Optional[int],
+ revision: Optional[int],
+ deployed: Optional[bool],
+ condition: Optional[PredictorStateCondition],
+ status: str,
+ **kwargs,
+ ):
+ self._available_predictor_instances = available_predictor_instances
+ self._available_transformer_instances = available_transformer_instances
+ self._hopsworks_inference_path = hopsworks_inference_path
+ self._model_server_inference_path = model_server_inference_path
+ self._internal_port = internal_port
+ self._revision = revision
+ self._deployed = deployed if deployed is not None else False
+ self._condition = condition
+ self._status = status
+
+ def describe(self):
+ """Print a description of the deployment state"""
+ util.pretty_print(self)
+
+ @classmethod
+ def from_response_json(cls, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ return PredictorState(*cls.extract_fields_from_json(json_decamelized))
+
+ @classmethod
+ def extract_fields_from_json(cls, json_decamelized):
+ ai = util.extract_field_from_json(json_decamelized, "available_instances")
+ ati = util.extract_field_from_json(
+ json_decamelized, "available_transformer_instances"
+ )
+ hip = util.extract_field_from_json(json_decamelized, "hopsworks_inference_path")
+ msip = util.extract_field_from_json(
+ json_decamelized, "model_server_inference_path"
+ )
+ ipt = util.extract_field_from_json(json_decamelized, "internal_port")
+ r = util.extract_field_from_json(json_decamelized, "revision")
+ d = util.extract_field_from_json(json_decamelized, "deployed")
+ c = util.extract_field_from_json(
+ json_decamelized, "condition", as_instance_of=PredictorStateCondition
+ )
+ s = util.extract_field_from_json(json_decamelized, "status")
+
+ return ai, ati, hip, msip, ipt, r, d, c, s
+
+ def to_dict(self):
+ json = {
+ "availableInstances": self._available_predictor_instances,
+ "hopsworksInferencePath": self._hopsworks_inference_path,
+ "modelServerInferencePath": self._model_server_inference_path,
+ "status": self._status,
+ }
+
+ if self._available_transformer_instances is not None:
+ json["availableTransformerInstances"] = (
+ self._available_transformer_instances
+ )
+ if self._internal_port is not None:
+ json["internalPort"] = self._internal_port
+ if self._revision is not None:
+ json["revision"] = self._revision
+ if self._deployed is not None:
+ json["deployed"] = self._deployed
+ if self._condition is not None:
+ json = {**json, **self._condition.to_dict()}
+
+ return json
+
+ @property
+ def available_predictor_instances(self):
+ """Available predicotr instances."""
+ return self._available_predictor_instances
+
+ @property
+ def available_transformer_instances(self):
+ """Available transformer instances."""
+ return self._available_transformer_instances
+
+ @property
+ def hopsworks_inference_path(self):
+ """Inference path in the Hopsworks REST API."""
+ return self._hopsworks_inference_path
+
+ @property
+ def model_server_inference_path(self):
+ """Inference path in the model server"""
+ return self._model_server_inference_path
+
+ @property
+ def internal_port(self):
+ """Internal port for the predictor."""
+ return self._internal_port
+
+ @property
+ def revision(self):
+ """Last revision of the predictor."""
+ return self._revision
+
+ @property
+ def deployed(self):
+ """Whether the predictor is deployed or not."""
+ return self._deployed
+
+ @property
+ def condition(self):
+ """Condition of the current state of predictor."""
+ return self._condition
+
+ @property
+ def status(self):
+ """Status of the predictor."""
+ return self._status
+
+ def __repr__(self):
+ return f"PredictorState(status: {self.status.capitalize()!r})"
diff --git a/hsml/python/hsml/predictor_state_condition.py b/hsml/python/hsml/predictor_state_condition.py
new file mode 100644
index 000000000..cf1c58934
--- /dev/null
+++ b/hsml/python/hsml/predictor_state_condition.py
@@ -0,0 +1,90 @@
+#
+# Copyright 2022 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from typing import Optional
+
+import humps
+from hsml import util
+
+
+class PredictorStateCondition:
+ """Condition of a predictor state."""
+
+ def __init__(
+ self,
+ type: str,
+ status: Optional[bool] = None,
+ reason: Optional[str] = None,
+ **kwargs,
+ ):
+ self._type = type
+ self._status = status
+ self._reason = reason
+
+ def describe(self):
+ util.pretty_print(self)
+
+ @classmethod
+ def from_response_json(cls, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ return cls.from_json(json_decamelized)
+
+ @classmethod
+ def from_json(cls, json_decamelized):
+ return PredictorStateCondition(**cls.extract_fields_from_json(json_decamelized))
+
+ @classmethod
+ def extract_fields_from_json(cls, json_decamelized):
+ kwargs = {}
+ kwargs["type"] = json_decamelized.pop("type") # required
+ kwargs["status"] = util.extract_field_from_json(json_decamelized, "status")
+ kwargs["reason"] = util.extract_field_from_json(json_decamelized, "reason")
+ return kwargs
+
+ def update_from_response_json(self, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ self.__init__(**self.extract_fields_from_json(json_decamelized))
+ return self
+
+ def json(self):
+ return json.dumps(self, cls=util.MLEncoder)
+
+ def to_dict(self):
+ return {
+ "condition": {
+ "type": self._type,
+ "status": self._status,
+ "reason": self._reason,
+ }
+ }
+
+ @property
+ def type(self):
+ """Condition type of the predictor state."""
+ return self._type
+
+ @property
+ def status(self):
+ """Condition status of the predictor state."""
+ return self._status
+
+ @property
+ def reason(self):
+ """Condition reason of the predictor state."""
+ return self._reason
+
+ def __repr__(self):
+ return f"PredictorStateCondition(type: {self.type.capitalize()!r}, status: {self.status!r})"
diff --git a/hsml/python/hsml/python/__init__.py b/hsml/python/hsml/python/__init__.py
new file mode 100644
index 000000000..ff0a6f046
--- /dev/null
+++ b/hsml/python/hsml/python/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/hsml/python/hsml/python/model.py b/hsml/python/hsml/python/model.py
new file mode 100644
index 000000000..ce2f0f984
--- /dev/null
+++ b/hsml/python/hsml/python/model.py
@@ -0,0 +1,79 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import humps
+from hsml.constants import MODEL
+from hsml.model import Model
+
+
+class Model(Model):
+ """Metadata object representing a generic python model in the Model Registry."""
+
+ def __init__(
+ self,
+ id,
+ name,
+ version=None,
+ created=None,
+ creator=None,
+ environment=None,
+ description=None,
+ experiment_id=None,
+ project_name=None,
+ experiment_project_name=None,
+ metrics=None,
+ program=None,
+ user_full_name=None,
+ model_schema=None,
+ training_dataset=None,
+ input_example=None,
+ model_registry_id=None,
+ tags=None,
+ href=None,
+ feature_view=None,
+ training_dataset_version=None,
+ **kwargs,
+ ):
+ super().__init__(
+ id,
+ name,
+ version=version,
+ created=created,
+ creator=creator,
+ environment=environment,
+ description=description,
+ experiment_id=experiment_id,
+ project_name=project_name,
+ experiment_project_name=experiment_project_name,
+ metrics=metrics,
+ program=program,
+ user_full_name=user_full_name,
+ model_schema=model_schema,
+ training_dataset=training_dataset,
+ input_example=input_example,
+ framework=MODEL.FRAMEWORK_PYTHON,
+ model_registry_id=model_registry_id,
+ feature_view=feature_view,
+ training_dataset_version=training_dataset_version,
+ )
+
+ def update_from_response_json(self, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ json_decamelized.pop("framework")
+ if "type" in json_decamelized: # backwards compatibility
+ _ = json_decamelized.pop("type")
+ self.__init__(**json_decamelized)
+ return self
diff --git a/hsml/python/hsml/python/predictor.py b/hsml/python/hsml/python/predictor.py
new file mode 100644
index 000000000..a3ca1643f
--- /dev/null
+++ b/hsml/python/hsml/python/predictor.py
@@ -0,0 +1,33 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from hsml.constants import MODEL, PREDICTOR
+from hsml.predictor import Predictor
+
+
+class Predictor(Predictor):
+ """Configuration for a predictor running a python model."""
+
+ def __init__(self, **kwargs):
+ kwargs["model_framework"] = MODEL.FRAMEWORK_PYTHON
+ kwargs["model_server"] = PREDICTOR.MODEL_SERVER_PYTHON
+
+ if kwargs["script_file"] is None:
+ raise ValueError(
+ "Predictor scripts are required in deployments for custom Python models"
+ )
+
+ super().__init__(**kwargs)
diff --git a/hsml/python/hsml/python/signature.py b/hsml/python/hsml/python/signature.py
new file mode 100644
index 000000000..94b154abf
--- /dev/null
+++ b/hsml/python/hsml/python/signature.py
@@ -0,0 +1,75 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional, Union
+
+import numpy
+import pandas
+from hsml.model_schema import ModelSchema
+from hsml.python.model import Model
+
+
+_mr = None
+
+
+def create_model(
+ name: str,
+ version: Optional[int] = None,
+ metrics: Optional[dict] = None,
+ description: Optional[str] = None,
+ input_example: Optional[
+ Union[pandas.DataFrame, pandas.Series, numpy.ndarray, list]
+ ] = None,
+ model_schema: Optional[ModelSchema] = None,
+ feature_view=None,
+ training_dataset_version: Optional[int] = None,
+):
+ """Create a generic Python model metadata object.
+
+ !!! note "Lazy"
+ This method is lazy and does not persist any metadata or upload any model artifacts to the
+ model registry on its own. To save the model object and the model artifacts, call the `save()` method with a
+ local file path to the directory containing the model artifacts.
+
+ # Arguments
+ name: Name of the model to create.
+ version: Optionally version of the model to create, defaults to `None` and
+ will create the model with incremented version from the last
+ version in the model registry.
+ metrics: Optionally a dictionary with model evaluation metrics (e.g., accuracy, MAE).
+ description: Optionally a string describing the model, defaults to empty string
+ `""`.
+ input_example: Optionally an input example that represents a single input for the model, defaults to `None`.
+ model_schema: Optionally a model schema for the model inputs and/or outputs.
+
+ # Returns
+ `Model`. The model metadata object.
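+
+ !!! example "Usage sketch"
+ A minimal, illustrative sketch assuming a model registry handle `mr` obtained via `project.get_model_registry()`; the metric values and export directory are placeholders.
+ ```python
+ model = mr.python.create_model(
+ name="my_model",
+ metrics={"accuracy": 0.94},
+ description="A generic Python model",
+ )
+ model.save("/tmp/model_export_dir")
+ ```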
+ """
+ model = Model(
+ id=None,
+ name=name,
+ version=version,
+ description=description,
+ metrics=metrics,
+ input_example=input_example,
+ model_schema=model_schema,
+ feature_view=feature_view,
+ training_dataset_version=training_dataset_version,
+ )
+ model._shared_registry_project_name = _mr.shared_registry_project_name
+ model._model_registry_id = _mr.model_registry_id
+
+ return model
diff --git a/hsml/python/hsml/resources.py b/hsml/python/hsml/resources.py
new file mode 100644
index 000000000..039aa263a
--- /dev/null
+++ b/hsml/python/hsml/resources.py
@@ -0,0 +1,394 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from abc import ABC, abstractmethod
+from typing import Optional, Union
+
+import humps
+from hsml import client, util
+from hsml.constants import RESOURCES
+
+
+class Resources:
+ """Resource configuration for a predictor or transformer.
+
+ # Arguments
+ cores: Number of CPUs.
+ memory: Memory (MB) resources.
+ gpus: Number of GPUs.
+ # Returns
+ `Resources`. Resource configuration for a predictor or transformer.
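+
+ !!! example "Construction sketch"
+ A minimal, illustrative sketch; the values below are arbitrary examples, not defaults.
+ ```python
+ from hsml.resources import Resources
+
+ # 2 CPU cores, 1024 MB of memory and no GPUs per instance
+ resources = Resources(cores=2, memory=1024, gpus=0)
+ resources.describe()
+ ```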
+ """
+
+ def __init__(
+ self,
+ cores: int,
+ memory: int,
+ gpus: int,
+ **kwargs,
+ ):
+ self._cores = cores
+ self._memory = memory
+ self._gpus = gpus
+
+ def describe(self):
+ """Print a description of the resource configuration"""
+ util.pretty_print(self)
+
+ @classmethod
+ def from_response_json(cls, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ return cls.from_json(json_decamelized)
+
+ @classmethod
+ def from_json(cls, json_decamelized):
+ return Resources(**cls.extract_fields_from_json(json_decamelized))
+
+ @classmethod
+ def extract_fields_from_json(cls, json_decamelized):
+ kwargs = {}
+ kwargs["cores"] = util.extract_field_from_json(json_decamelized, "cores")
+ kwargs["memory"] = util.extract_field_from_json(json_decamelized, "memory")
+ kwargs["gpus"] = util.extract_field_from_json(json_decamelized, "gpus")
+ return kwargs
+
+ def json(self):
+ return json.dumps(self, cls=util.MLEncoder)
+
+ def to_dict(self):
+ return {"cores": self._cores, "memory": self._memory, "gpus": self._gpus}
+
+ @property
+ def cores(self):
+ """Number of CPUs to be allocated per instance"""
+ return self._cores
+
+ @cores.setter
+ def cores(self, cores: int):
+ self._cores = cores
+
+ @property
+ def memory(self):
+ """Memory resources to be allocated per instance"""
+ return self._memory
+
+ @memory.setter
+ def memory(self, memory: int):
+ self._memory = memory
+
+ @property
+ def gpus(self):
+ """Number of GPUs to be allocated per instance"""
+ return self._gpus
+
+ @gpus.setter
+ def gpus(self, gpus: int):
+ self._gpus = gpus
+
+ def __repr__(self):
+ return f"Resources(cores: {self._cores!r}, memory: {self._memory!r}, gpus: {self._gpus!r})"
+
+
+class ComponentResources(ABC):
+ """Resource configuration for a predictor or transformer.
+
+ # Arguments
+ num_instances: Number of instances.
+ requests: Minimum resources to allocate for a deployment.
+ limits: Maximum resources to allocate for a deployment.
+ # Returns
+ `ComponentResources`. Resource configuration for a predictor or transformer.
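+
+ !!! example "Predictor resources sketch"
+ A minimal, illustrative sketch using the `PredictorResources` subclass; the request and limit values are arbitrary examples, and an active Hopsworks connection is assumed since limits are validated against the cluster.
+ ```python
+ from hsml.resources import PredictorResources, Resources
+
+ predictor_resources = PredictorResources(
+ num_instances=1,
+ requests=Resources(cores=1, memory=512, gpus=0),
+ limits=Resources(cores=2, memory=1024, gpus=0),
+ )
+ ```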
+ """
+
+ def __init__(
+ self,
+ num_instances: int,
+ requests: Optional[Union[Resources, dict]] = None,
+ limits: Optional[Union[Resources, dict]] = None,
+ ):
+ self._num_instances = num_instances
+ self._requests = util.get_obj_from_json(requests, Resources) or Resources(
+ RESOURCES.MIN_CORES, RESOURCES.MIN_MEMORY, RESOURCES.MIN_GPUS
+ )
+ self._fill_missing_resources(
+ self._requests,
+ RESOURCES.MIN_CORES,
+ RESOURCES.MIN_MEMORY,
+ RESOURCES.MIN_GPUS,
+ )
+ self._limits = util.get_obj_from_json(limits, Resources) or Resources(
+ *self._get_default_resource_limits()
+ )
+ self._fill_missing_resources(self._limits, *self._get_default_resource_limits())
+
+ # validate both requests and limits
+ self._validate_resources(self._requests, self._limits)
+
+ def describe(self):
+ """Print a description of the resource configuration"""
+ util.pretty_print(self)
+
+ def _get_default_resource_limits(self):
+ max_resources = client.get_serving_resource_limits()
+ # cores limit
+ if max_resources["cores"] == -1: # no limit
+ max_cores = (
+ RESOURCES.MAX_CORES
+ if RESOURCES.MAX_CORES >= self._requests.cores
+ else self._requests.cores
+ )
+ else:
+ max_cores = (
+ RESOURCES.MAX_CORES
+ if RESOURCES.MAX_CORES <= max_resources["cores"]
+ and RESOURCES.MAX_CORES >= self._requests.cores
+ else max_resources["cores"]
+ )
+ # memory limit
+ if max_resources["memory"] == -1: # no limit
+ max_memory = (
+ RESOURCES.MAX_MEMORY
+ if RESOURCES.MAX_MEMORY >= self._requests.memory
+ else self._requests.memory
+ )
+ else:
+ max_memory = (
+ RESOURCES.MAX_MEMORY
+ if RESOURCES.MAX_MEMORY <= max_resources["memory"]
+ and RESOURCES.MAX_MEMORY >= self._requests.memory
+ else max_resources["memory"]
+ )
+ # gpus limit
+ if max_resources["gpus"] == -1: # no limit
+ max_gpus = (
+ RESOURCES.MAX_GPUS
+ if RESOURCES.MAX_GPUS >= self._requests.gpus
+ else self._requests.gpus
+ )
+ else:
+ max_gpus = (
+ RESOURCES.MAX_GPUS
+ if RESOURCES.MAX_GPUS <= max_resources["gpus"]
+ and RESOURCES.MAX_GPUS >= self._requests.gpus
+ else max_resources["gpus"]
+ )
+ return max_cores, max_memory, max_gpus
+
+ @classmethod
+ def _fill_missing_resources(cls, resources, cores, memory, gpus):
+ if resources.cores is None:
+ resources.cores = cores
+ if resources.memory is None:
+ resources.memory = memory
+ if resources.gpus is None:
+ resources.gpus = gpus
+
+ @classmethod
+ def _validate_resources(cls, requests, limits):
+ # limits
+ max_resources = client.get_serving_resource_limits()
+ if max_resources["cores"] > -1:
+ if limits.cores <= 0:
+ raise ValueError("Limit number of cores must be greater than 0 cores.")
+ if limits.cores > max_resources["cores"]:
+ raise ValueError(
+ "Limit number of cores cannot exceed the maximum of "
+ + str(max_resources["cores"])
+ + " cores."
+ )
+ if max_resources["memory"] > -1:
+ if limits.memory <= 0:
+ raise ValueError("Limit memory resources must be greater than 0 MB.")
+ if limits.memory > max_resources["memory"]:
+ raise ValueError(
+ "Limit memory resources cannot exceed the maximum of "
+ + str(max_resources["memory"])
+ + " MB."
+ )
+ if max_resources["gpus"] > -1:
+ if limits.gpus < 0:
+ raise ValueError(
+ "Limit number of gpus must be greater than or equal to 0 gpus."
+ )
+ if limits.gpus > max_resources["gpus"]:
+ raise ValueError(
+ "Limit number of gpus cannot exceed the maximum of "
+ + str(max_resources["gpus"])
+ + " gpus."
+ )
+
+ # requests
+ if requests.cores <= 0:
+ raise ValueError("Requested number of cores must be greater than 0 cores.")
+ if limits.cores > -1 and requests.cores > limits.cores:
+ raise ValueError(
+ "Requested number of cores cannot exceed the limit of "
+ + str(limits.cores)
+ + " cores."
+ )
+ if requests.memory <= 0:
+ raise ValueError("Requested memory resources must be greater than 0 MB.")
+ if limits.memory > -1 and requests.memory > limits.memory:
+ raise ValueError(
+ "Requested memory resources cannot exceed the limit of "
+ + str(limits.memory)
+ + " MB."
+ )
+ if requests.gpus < 0:
+ raise ValueError(
+ "Requested number of gpus must be greater than or equal to 0 gpus."
+ )
+ if limits.gpus > -1 and requests.gpus > limits.gpus:
+ raise ValueError(
+ "Requested number of gpus cannot exceed the limit of "
+ + str(limits.gpus)
+ + " gpus."
+ )
+
+ @classmethod
+ def from_response_json(cls, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ return cls.from_json(json_decamelized)
+
+ @classmethod
+ @abstractmethod
+ def from_json(cls, json_decamelized):
+ pass
+
+ @classmethod
+ def extract_fields_from_json(cls, json_decamelized):
+ kwargs = {}
+
+ # extract resources
+ if cls.RESOURCES_CONFIG_KEY in json_decamelized:
+ resources = json_decamelized.pop(cls.RESOURCES_CONFIG_KEY)
+ elif "resources" in json_decamelized:
+ resources = json_decamelized.pop("resources")
+ else:
+ resources = json_decamelized
+
+ # extract resource fields
+ kwargs["requests"] = util.extract_field_from_json(
+ resources, "requests", as_instance_of=Resources
+ )
+ kwargs["limits"] = util.extract_field_from_json(
+ resources, "limits", as_instance_of=Resources
+ )
+
+ # extract num instances
+ if cls.NUM_INSTANCES_KEY in json_decamelized:
+ kwargs["num_instances"] = json_decamelized.pop(cls.NUM_INSTANCES_KEY)
+ elif "num_instances" in json_decamelized:
+ kwargs["num_instances"] = json_decamelized.pop("num_instances")
+ else:
+ kwargs["num_instances"] = util.extract_field_from_json(
+ resources, [cls.NUM_INSTANCES_KEY, "num_instances"]
+ )
+
+ return kwargs
+
+ def json(self):
+ return json.dumps(self, cls=util.MLEncoder)
+
+ @abstractmethod
+ def to_dict(self):
+ pass
+
+ @property
+ def num_instances(self):
+ """Number of instances"""
+ return self._num_instances
+
+ @num_instances.setter
+ def num_instances(self, num_instances: int):
+ self._num_instances = num_instances
+
+ @property
+ def requests(self):
+ """Minimum resources to allocate"""
+ return self._requests
+
+ @requests.setter
+ def requests(self, requests: Resources):
+ self._requests = requests
+
+ @property
+ def limits(self):
+ """Maximum resources to allocate"""
+ return self._limits
+
+ @limits.setter
+ def limits(self, limits: Resources):
+ self._limits = limits
+
+ def __repr__(self):
+ return f"ComponentResources(num_instances: {self._num_instances!r}, requests: {self._requests is not None!r}, limits: {self._limits is not None!r})"
+
+
+class PredictorResources(ComponentResources):
+ RESOURCES_CONFIG_KEY = "predictor_resources"
+ NUM_INSTANCES_KEY = "requested_instances"
+
+ def __init__(
+ self,
+ num_instances: int,
+ requests: Optional[Union[Resources, dict]] = None,
+ limits: Optional[Union[Resources, dict]] = None,
+ ):
+ super().__init__(num_instances, requests, limits)
+
+ @classmethod
+ def from_json(cls, json_decamelized):
+ return PredictorResources(**cls.extract_fields_from_json(json_decamelized))
+
+ def to_dict(self):
+ return {
+ humps.camelize(self.NUM_INSTANCES_KEY): self._num_instances,
+ humps.camelize(self.RESOURCES_CONFIG_KEY): {
+ "requests": (
+ self._requests.to_dict() if self._requests is not None else None
+ ),
+ "limits": self._limits.to_dict() if self._limits is not None else None,
+ },
+ }
+
+
+class TransformerResources(ComponentResources):
+ RESOURCES_CONFIG_KEY = "transformer_resources"
+ NUM_INSTANCES_KEY = "requested_transformer_instances"
+
+ def __init__(
+ self,
+ num_instances: int,
+ requests: Optional[Union[Resources, dict]] = None,
+ limits: Optional[Union[Resources, dict]] = None,
+ ):
+ super().__init__(num_instances, requests, limits)
+
+ @classmethod
+ def from_json(cls, json_decamelized):
+ return TransformerResources(**cls.extract_fields_from_json(json_decamelized))
+
+ def to_dict(self):
+ return {
+ humps.camelize(self.NUM_INSTANCES_KEY): self._num_instances,
+ humps.camelize(self.RESOURCES_CONFIG_KEY): {
+ "requests": (
+ self._requests.to_dict() if self._requests is not None else None
+ ),
+ "limits": self._limits.to_dict() if self._limits is not None else None,
+ },
+ }
diff --git a/hsml/python/hsml/schema.py b/hsml/python/hsml/schema.py
new file mode 100644
index 000000000..22e46aed1
--- /dev/null
+++ b/hsml/python/hsml/schema.py
@@ -0,0 +1,83 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+from typing import Optional, TypeVar, Union
+
+import numpy
+import pandas
+from hsml.utils.schema.columnar_schema import ColumnarSchema
+from hsml.utils.schema.tensor_schema import TensorSchema
+
+
+class Schema:
+ """Create a schema for a model input or output.
+
+ # Arguments
+ object: The object to construct the schema from.
+
+ # Returns
+ `Schema`. The schema object.
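+
+ !!! example "Schema sketch"
+ A minimal, illustrative sketch building a columnar schema from a pandas DataFrame and a tensor schema from a numpy array; column names and shapes are placeholders.
+ ```python
+ import numpy as np
+ import pandas as pd
+ from hsml.schema import Schema
+
+ columnar_schema = Schema(pd.DataFrame({"age": [30], "income": [1500.0]}))
+ tensor_schema = Schema(np.zeros((1, 28, 28)))
+ ```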
+ """
+
+ def __init__(
+ self,
+ object: Optional[
+ Union[
+ pandas.DataFrame,
+ pandas.Series,
+ TypeVar("pyspark.sql.dataframe.DataFrame"), # noqa: F821
+ TypeVar("hsfs.training_dataset.TrainingDataset"), # noqa: F821
+ numpy.ndarray,
+ list,
+ ]
+ ] = None,
+ **kwargs,
+ ):
+ # A tensor schema is either ndarray of a list containing name, type and shape dicts
+ if isinstance(object, numpy.ndarray) or (
+ isinstance(object, list) and all(["shape" in entry for entry in object])
+ ):
+ self.tensor_schema = self._convert_tensor_to_schema(object).tensors
+ else:
+ self.columnar_schema = self._convert_columnar_to_schema(object).columns
+
+ def _convert_columnar_to_schema(self, object):
+ return ColumnarSchema(object)
+
+ def _convert_tensor_to_schema(self, object):
+ return TensorSchema(object)
+
+ def _get_type(self):
+ if hasattr(self, "tensor_schema"):
+ return "tensor"
+ if hasattr(self, "columnar_schema"):
+ return "columnar"
+ return None
+
+ def json(self):
+ return json.dumps(
+ self, default=lambda o: getattr(o, "__dict__", o), sort_keys=True, indent=2
+ )
+
+ def to_dict(self):
+ """
+ Get dict representation of the Schema.
+ """
+ return json.loads(self.json())
+
+ def __repr__(self):
+ return f"Schema(type: {self._get_type()!r})"
diff --git a/hsml/python/hsml/sklearn/__init__.py b/hsml/python/hsml/sklearn/__init__.py
new file mode 100644
index 000000000..ff0a6f046
--- /dev/null
+++ b/hsml/python/hsml/sklearn/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/hsml/python/hsml/sklearn/model.py b/hsml/python/hsml/sklearn/model.py
new file mode 100644
index 000000000..900a36204
--- /dev/null
+++ b/hsml/python/hsml/sklearn/model.py
@@ -0,0 +1,79 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import humps
+from hsml.constants import MODEL
+from hsml.model import Model
+
+
+class Model(Model):
+ """Metadata object representing an sklearn model in the Model Registry."""
+
+ def __init__(
+ self,
+ id,
+ name,
+ version=None,
+ created=None,
+ creator=None,
+ environment=None,
+ description=None,
+ experiment_id=None,
+ project_name=None,
+ experiment_project_name=None,
+ metrics=None,
+ program=None,
+ user_full_name=None,
+ model_schema=None,
+ training_dataset=None,
+ input_example=None,
+ model_registry_id=None,
+ tags=None,
+ href=None,
+ feature_view=None,
+ training_dataset_version=None,
+ **kwargs,
+ ):
+ super().__init__(
+ id,
+ name,
+ version=version,
+ created=created,
+ creator=creator,
+ environment=environment,
+ description=description,
+ experiment_id=experiment_id,
+ project_name=project_name,
+ experiment_project_name=experiment_project_name,
+ metrics=metrics,
+ program=program,
+ user_full_name=user_full_name,
+ model_schema=model_schema,
+ training_dataset=training_dataset,
+ input_example=input_example,
+ framework=MODEL.FRAMEWORK_SKLEARN,
+ model_registry_id=model_registry_id,
+ feature_view=feature_view,
+ training_dataset_version=training_dataset_version,
+ )
+
+ def update_from_response_json(self, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ json_decamelized.pop("framework")
+ if "type" in json_decamelized: # backwards compatibility
+ _ = json_decamelized.pop("type")
+ self.__init__(**json_decamelized)
+ return self
diff --git a/hsml/python/hsml/sklearn/predictor.py b/hsml/python/hsml/sklearn/predictor.py
new file mode 100644
index 000000000..1d43c66f7
--- /dev/null
+++ b/hsml/python/hsml/sklearn/predictor.py
@@ -0,0 +1,28 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from hsml.constants import MODEL, PREDICTOR
+from hsml.predictor import Predictor
+
+
+class Predictor(Predictor):
+ """Configuration for a predictor running a sklearn model."""
+
+ def __init__(self, **kwargs):
+ kwargs["model_framework"] = MODEL.FRAMEWORK_SKLEARN
+ kwargs["model_server"] = PREDICTOR.MODEL_SERVER_PYTHON
+
+ super().__init__(**kwargs)
diff --git a/hsml/python/hsml/sklearn/signature.py b/hsml/python/hsml/sklearn/signature.py
new file mode 100644
index 000000000..ef2ab74d2
--- /dev/null
+++ b/hsml/python/hsml/sklearn/signature.py
@@ -0,0 +1,75 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional, Union
+
+import numpy
+import pandas
+from hsml.model_schema import ModelSchema
+from hsml.sklearn.model import Model
+
+
+_mr = None
+
+
+def create_model(
+ name: str,
+ version: Optional[int] = None,
+ metrics: Optional[dict] = None,
+ description: Optional[str] = None,
+ input_example: Optional[
+ Union[pandas.DataFrame, pandas.Series, numpy.ndarray, list]
+ ] = None,
+ model_schema: Optional[ModelSchema] = None,
+ feature_view=None,
+ training_dataset_version: Optional[int] = None,
+):
+ """Create an SkLearn model metadata object.
+
+ !!! note "Lazy"
+ This method is lazy and does not persist any metadata or upload any model artifacts to the
+ model registry on its own. To save the model object and the model artifacts, call the `save()` method with a
+ local file path to the directory containing the model artifacts.
+
+ # Arguments
+ name: Name of the model to create.
+ version: Optionally version of the model to create, defaults to `None` and
+ will create the model with incremented version from the last
+ version in the model registry.
+ metrics: Optionally a dictionary with model evaluation metrics (e.g., accuracy, MAE).
+ description: Optionally a string describing the model, defaults to empty string
+ `""`.
+ input_example: Optionally an input example that represents a single input for the model, defaults to `None`.
+ model_schema: Optionally a model schema for the model inputs and/or outputs.
+
+ # Returns
+ `Model`. The model metadata object.
+ """
+ model = Model(
+ id=None,
+ name=name,
+ version=version,
+ description=description,
+ metrics=metrics,
+ input_example=input_example,
+ model_schema=model_schema,
+ feature_view=feature_view,
+ training_dataset_version=training_dataset_version,
+ )
+ model._shared_registry_project_name = _mr.shared_registry_project_name
+ model._model_registry_id = _mr.model_registry_id
+
+ return model
diff --git a/hsml/python/hsml/tag.py b/hsml/python/hsml/tag.py
new file mode 100644
index 000000000..aecf2ed74
--- /dev/null
+++ b/hsml/python/hsml/tag.py
@@ -0,0 +1,77 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+
+import humps
+from hsml import util
+
+
+class Tag:
+ def __init__(
+ self,
+ name,
+ value,
+ schema=None,
+ href=None,
+ expand=None,
+ items=None,
+ count=None,
+ type=None,
+ **kwargs,
+ ):
+ self._name = name
+ self._value = value
+
+ def to_dict(self):
+ return {
+ "name": self._name,
+ "value": self._value,
+ }
+
+ def json(self):
+ return json.dumps(self, cls=util.MLEncoder)
+
+ @classmethod
+ def from_response_json(cls, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ if "count" not in json_decamelized or json_decamelized["count"] == 0:
+ return []
+ return [cls(**tag) for tag in json_decamelized["items"]]
+
+ @property
+ def name(self):
+ """Name of the tag."""
+ return self._name
+
+ @name.setter
+ def name(self, name):
+ self._name = name
+
+ @property
+ def value(self):
+ """Value of tag."""
+ return self._value
+
+ @value.setter
+ def value(self, value):
+ self._value = value
+
+ def __str__(self):
+ return self.json()
+
+ def __repr__(self):
+ return f"Tag({self._name!r}, {self._value!r})"
diff --git a/hsml/python/hsml/tensorflow/__init__.py b/hsml/python/hsml/tensorflow/__init__.py
new file mode 100644
index 000000000..ff0a6f046
--- /dev/null
+++ b/hsml/python/hsml/tensorflow/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/hsml/python/hsml/tensorflow/model.py b/hsml/python/hsml/tensorflow/model.py
new file mode 100644
index 000000000..c09ccf10d
--- /dev/null
+++ b/hsml/python/hsml/tensorflow/model.py
@@ -0,0 +1,79 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import humps
+from hsml.constants import MODEL
+from hsml.model import Model
+
+
+class Model(Model):
+ """Metadata object representing a tensorflow model in the Model Registry."""
+
+ def __init__(
+ self,
+ id,
+ name,
+ version=None,
+ created=None,
+ creator=None,
+ environment=None,
+ description=None,
+ experiment_id=None,
+ project_name=None,
+ experiment_project_name=None,
+ metrics=None,
+ program=None,
+ user_full_name=None,
+ model_schema=None,
+ training_dataset=None,
+ input_example=None,
+ model_registry_id=None,
+ tags=None,
+ href=None,
+ feature_view=None,
+ training_dataset_version=None,
+ **kwargs,
+ ):
+ super().__init__(
+ id,
+ name,
+ version=version,
+ created=created,
+ creator=creator,
+ environment=environment,
+ description=description,
+ experiment_id=experiment_id,
+ project_name=project_name,
+ experiment_project_name=experiment_project_name,
+ metrics=metrics,
+ program=program,
+ user_full_name=user_full_name,
+ model_schema=model_schema,
+ training_dataset=training_dataset,
+ input_example=input_example,
+ framework=MODEL.FRAMEWORK_TENSORFLOW,
+ model_registry_id=model_registry_id,
+ feature_view=feature_view,
+ training_dataset_version=training_dataset_version,
+ )
+
+ def update_from_response_json(self, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ json_decamelized.pop("framework")
+ if "type" in json_decamelized: # backwards compatibility
+ _ = json_decamelized.pop("type")
+ self.__init__(**json_decamelized)
+ return self
diff --git a/hsml/python/hsml/tensorflow/predictor.py b/hsml/python/hsml/tensorflow/predictor.py
new file mode 100644
index 000000000..045aadf3a
--- /dev/null
+++ b/hsml/python/hsml/tensorflow/predictor.py
@@ -0,0 +1,33 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from hsml.constants import MODEL, PREDICTOR
+from hsml.predictor import Predictor
+
+
+class Predictor(Predictor):
+ """Configuration for a predictor running a tensorflow model."""
+
+ def __init__(self, **kwargs):
+ kwargs["model_framework"] = MODEL.FRAMEWORK_TENSORFLOW
+ kwargs["model_server"] = PREDICTOR.MODEL_SERVER_TF_SERVING
+
+ if kwargs["script_file"] is not None:
+ raise ValueError(
+ "Predictor scripts are not supported in deployments for Tensorflow models"
+ )
+
+ super().__init__(**kwargs)
diff --git a/hsml/python/hsml/tensorflow/signature.py b/hsml/python/hsml/tensorflow/signature.py
new file mode 100644
index 000000000..88b0f0fc4
--- /dev/null
+++ b/hsml/python/hsml/tensorflow/signature.py
@@ -0,0 +1,75 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional, Union
+
+import numpy
+import pandas
+from hsml.model_schema import ModelSchema
+from hsml.tensorflow.model import Model
+
+
+_mr = None
+
+
+def create_model(
+ name: str,
+ version: Optional[int] = None,
+ metrics: Optional[dict] = None,
+ description: Optional[str] = None,
+ input_example: Optional[
+ Union[pandas.DataFrame, pandas.Series, numpy.ndarray, list]
+ ] = None,
+ model_schema: Optional[ModelSchema] = None,
+ feature_view=None,
+ training_dataset_version: Optional[int] = None,
+):
+ """Create a TensorFlow model metadata object.
+
+ !!! note "Lazy"
+ This method is lazy and does not persist any metadata or upload any model artifacts to the
+ model registry on its own. To save the model object and the model artifacts, call the `save()` method with a
+ local file path to the directory containing the model artifacts.
+
+ # Arguments
+ name: Name of the model to create.
+ version: Optionally the version of the model to create, defaults to `None`, in which
+ case the version is incremented from the latest version in the model
+ registry.
+ metrics: Optionally a dictionary with model evaluation metrics (e.g., accuracy, MAE).
+ description: Optionally a string describing the model, defaults to empty string
+ `""`.
+ input_example: Optionally an input example that represents a single input for the model, defaults to `None`.
+ model_schema: Optionally a model schema for the model inputs and/or outputs.
+ feature_view: Optionally the feature view used to train the model, linked for provenance.
+ training_dataset_version: Optionally the version of the training dataset used to train the model.
+
+ # Returns
+ `Model`. The model metadata object.
+ """
+ model = Model(
+ id=None,
+ name=name,
+ version=version,
+ description=description,
+ metrics=metrics,
+ input_example=input_example,
+ model_schema=model_schema,
+ feature_view=feature_view,
+ training_dataset_version=training_dataset_version,
+ )
+ model._shared_registry_project_name = _mr.shared_registry_project_name
+ model._model_registry_id = _mr.model_registry_id
+
+ return model
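
# A minimal usage sketch for the signature module above. It assumes the registry handle
# obtained from hsml.connection() exposes this module as `mr.tensorflow` (as the `_mr`
# handle above suggests); the model name, metrics and the "./model" export directory
# are hypothetical placeholders.
import hsml

connection = hsml.connection()           # connect to Hopsworks; credentials resolved from env/config
mr = connection.get_model_registry()     # obtain the model registry handle

# create the lazy metadata object, then persist it together with the exported artifacts
tf_model = mr.tensorflow.create_model(
    name="mnist_classifier",
    metrics={"accuracy": 0.94},
    description="CNN trained on MNIST",
)
tf_model.save("./model")                 # uploads the SavedModel directory and registers the version
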
diff --git a/hsml/python/hsml/torch/__init__.py b/hsml/python/hsml/torch/__init__.py
new file mode 100644
index 000000000..ff0a6f046
--- /dev/null
+++ b/hsml/python/hsml/torch/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/hsml/python/hsml/torch/model.py b/hsml/python/hsml/torch/model.py
new file mode 100644
index 000000000..59102119c
--- /dev/null
+++ b/hsml/python/hsml/torch/model.py
@@ -0,0 +1,79 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import humps
+from hsml.constants import MODEL
+from hsml.model import Model
+
+
+class Model(Model):
+ """Metadata object representing a torch model in the Model Registry."""
+
+ def __init__(
+ self,
+ id,
+ name,
+ version=None,
+ created=None,
+ creator=None,
+ environment=None,
+ description=None,
+ experiment_id=None,
+ project_name=None,
+ experiment_project_name=None,
+ metrics=None,
+ program=None,
+ user_full_name=None,
+ model_schema=None,
+ training_dataset=None,
+ input_example=None,
+ model_registry_id=None,
+ tags=None,
+ href=None,
+ feature_view=None,
+ training_dataset_version=None,
+ **kwargs,
+ ):
+ super().__init__(
+ id,
+ name,
+ version=version,
+ created=created,
+ creator=creator,
+ environment=environment,
+ description=description,
+ experiment_id=experiment_id,
+ project_name=project_name,
+ experiment_project_name=experiment_project_name,
+ metrics=metrics,
+ program=program,
+ user_full_name=user_full_name,
+ model_schema=model_schema,
+ training_dataset=training_dataset,
+ input_example=input_example,
+ framework=MODEL.FRAMEWORK_TORCH,
+ model_registry_id=model_registry_id,
+ feature_view=feature_view,
+ training_dataset_version=training_dataset_version,
+ )
+
+ def update_from_response_json(self, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ json_decamelized.pop("framework")
+ if "type" in json_decamelized: # backwards compatibility
+ _ = json_decamelized.pop("type")
+ self.__init__(**json_decamelized)
+ return self
diff --git a/hsml/python/hsml/torch/predictor.py b/hsml/python/hsml/torch/predictor.py
new file mode 100644
index 000000000..5f7ed5d7a
--- /dev/null
+++ b/hsml/python/hsml/torch/predictor.py
@@ -0,0 +1,33 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from hsml.constants import MODEL, PREDICTOR
+from hsml.predictor import Predictor
+
+
+class Predictor(Predictor):
+ """Configuration for a predictor running a torch model."""
+
+ def __init__(self, **kwargs):
+ kwargs["model_framework"] = MODEL.FRAMEWORK_PYTHON
+ kwargs["model_server"] = PREDICTOR.MODEL_SERVER_PYTHON
+
+ if kwargs["script_file"] is None:
+ raise ValueError(
+ "Predictor scripts are required in deployments for Torch models"
+ )
+
+ super().__init__(**kwargs)
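
# A short sketch contrasting the two predictor flavours defined so far, assuming the
# `Model.deploy()` call used elsewhere in hsml accepts `name` and `script_file`; the
# model objects and file names are placeholders. Torch models run on the default Python
# model server, so a predictor script is mandatory, while TensorFlow models run on
# TensorFlow Serving and reject one.
torch_deployment = torch_model.deploy(
    name="torchclassifier",
    script_file="predictor.py",   # required: loads the model and implements predict()
)

tf_deployment = tf_model.deploy(name="tfclassifier")  # passing a script_file here would raise
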
diff --git a/hsml/python/hsml/torch/signature.py b/hsml/python/hsml/torch/signature.py
new file mode 100644
index 000000000..32ab27d37
--- /dev/null
+++ b/hsml/python/hsml/torch/signature.py
@@ -0,0 +1,75 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import Optional, Union
+
+import numpy
+import pandas
+from hsml.model_schema import ModelSchema
+from hsml.torch.model import Model
+
+
+_mr = None
+
+
+def create_model(
+ name: str,
+ version: Optional[int] = None,
+ metrics: Optional[dict] = None,
+ description: Optional[str] = None,
+ input_example: Optional[
+ Union[pandas.DataFrame, pandas.Series, numpy.ndarray, list]
+ ] = None,
+ model_schema: Optional[ModelSchema] = None,
+ feature_view=None,
+ training_dataset_version: Optional[int] = None,
+):
+ """Create a Torch model metadata object.
+
+ !!! note "Lazy"
+ This method is lazy and does not persist any metadata or upload any model artifacts to the
+ model registry on its own. To save the model object and the model artifacts, call the `save()` method with a
+ local file path to the directory containing the model artifacts.
+
+ # Arguments
+ name: Name of the model to create.
+ version: Optionally the version of the model to create, defaults to `None`, in which
+ case the version is incremented from the latest version in the model
+ registry.
+ metrics: Optionally a dictionary with model evaluation metrics (e.g., accuracy, MAE).
+ description: Optionally a string describing the model, defaults to empty string
+ `""`.
+ input_example: Optionally an input example that represents a single input for the model, defaults to `None`.
+ model_schema: Optionally a model schema for the model inputs and/or outputs.
+ feature_view: Optionally the feature view used to train the model, linked for provenance.
+ training_dataset_version: Optionally the version of the training dataset used to train the model.
+
+ # Returns
+ `Model`. The model metadata object.
+ """
+ model = Model(
+ id=None,
+ name=name,
+ version=version,
+ description=description,
+ metrics=metrics,
+ input_example=input_example,
+ model_schema=model_schema,
+ feature_view=feature_view,
+ training_dataset_version=training_dataset_version,
+ )
+ model._shared_registry_project_name = _mr.shared_registry_project_name
+ model._model_registry_id = _mr.model_registry_id
+
+ return model
diff --git a/hsml/python/hsml/transformer.py b/hsml/python/hsml/transformer.py
new file mode 100644
index 000000000..4121d106d
--- /dev/null
+++ b/hsml/python/hsml/transformer.py
@@ -0,0 +1,93 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional, Union
+
+import humps
+from hsml import client, util
+from hsml.constants import RESOURCES
+from hsml.deployable_component import DeployableComponent
+from hsml.resources import TransformerResources
+
+
+class Transformer(DeployableComponent):
+ """Metadata object representing a transformer to be used in a predictor."""
+
+ def __init__(
+ self,
+ script_file: str,
+ resources: Optional[Union[TransformerResources, dict]] = None, # base
+ **kwargs,
+ ):
+ resources = (
+ self._validate_resources(
+ util.get_obj_from_json(resources, TransformerResources)
+ )
+ or self._get_default_resources()
+ )
+ if resources.num_instances is None:
+ resources.num_instances = self._get_default_num_instances()
+
+ super().__init__(script_file, resources)
+
+ def describe(self):
+ """Print a description of the transformer"""
+ util.pretty_print(self)
+
+ @classmethod
+ def _validate_resources(cls, resources):
+ if resources is not None:
+ # ensure scale-to-zero for kserve deployments when required
+ if resources.num_instances != 0 and client.is_scale_to_zero_required():
+ raise ValueError(
+ "Scale-to-zero is required for KServe deployments in this cluster. Please, set the number of transformer instances to 0."
+ )
+ return resources
+
+ @classmethod
+ def _get_default_num_instances(cls):
+ return (
+ 0 # enable scale-to-zero by default if required
+ if client.is_scale_to_zero_required()
+ else RESOURCES.MIN_NUM_INSTANCES
+ )
+
+ @classmethod
+ def _get_default_resources(cls):
+ return TransformerResources(cls._get_default_num_instances())
+
+ @classmethod
+ def from_json(cls, json_decamelized):
+ sf, rc = cls.extract_fields_from_json(json_decamelized)
+ return Transformer(sf, rc) if sf is not None else None
+
+ @classmethod
+ def extract_fields_from_json(cls, json_decamelized):
+ sf = util.extract_field_from_json(
+ json_decamelized, ["transformer", "script_file"]
+ )
+ rc = TransformerResources.from_json(json_decamelized)
+ return sf, rc
+
+ def update_from_response_json(self, json_dict):
+ json_decamelized = humps.decamelize(json_dict)
+ self.__init__(*self.extract_fields_from_json(json_decamelized))
+ return self
+
+ def to_dict(self):
+ return {"transformer": self._script_file, **self._resources.to_dict()}
+
+ def __repr__(self):
+ return f"Transformer({self._script_file!r})"
diff --git a/hsml/python/hsml/util.py b/hsml/python/hsml/util.py
new file mode 100644
index 000000000..96380b6f4
--- /dev/null
+++ b/hsml/python/hsml/util.py
@@ -0,0 +1,347 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import annotations
+
+import datetime
+import inspect
+import os
+import shutil
+from json import JSONEncoder, dumps
+from urllib.parse import urljoin, urlparse
+
+import humps
+import numpy as np
+import pandas as pd
+from hsml import client
+from hsml.constants import DEFAULT, MODEL, PREDICTOR
+from hsml.model import Model as BaseModel
+from hsml.predictor import Predictor as BasePredictor
+from hsml.python.model import Model as PyModel
+from hsml.python.predictor import Predictor as PyPredictor
+from hsml.sklearn.model import Model as SkLearnModel
+from hsml.sklearn.predictor import Predictor as SkLearnPredictor
+from hsml.tensorflow.model import Model as TFModel
+from hsml.tensorflow.predictor import Predictor as TFPredictor
+from hsml.torch.model import Model as TorchModel
+from hsml.torch.predictor import Predictor as TorchPredictor
+from six import string_types
+
+
+class VersionWarning(Warning):
+ pass
+
+
+class ProvenanceWarning(Warning):
+ pass
+
+
+class MLEncoder(JSONEncoder):
+ def default(self, obj):
+ try:
+ return obj.to_dict()
+ except AttributeError:
+ return super().default(obj)
+
+
+class NumpyEncoder(JSONEncoder):
+ """Special json encoder for numpy types.
+ Note that some numpy types doesn't have native python equivalence,
+ hence json.dumps will raise TypeError.
+ In this case, you'll need to convert your numpy types into its closest python equivalence.
+ """
+
+ def convert(self, obj):
+ import base64
+
+ import numpy as np
+ import pandas as pd
+
+ def encode_binary(x):
+ return base64.encodebytes(x).decode("ascii")
+
+ if isinstance(obj, np.ndarray):
+ if obj.dtype == object: # np.object was removed in NumPy 1.24
+ return [self.convert(x)[0] for x in obj.tolist()], True
+ elif obj.dtype == np.bytes_:
+ return np.vectorize(encode_binary)(obj), True
+ else:
+ return obj.tolist(), True
+
+ if isinstance(obj, (pd.Timestamp, datetime.date)):
+ return obj.isoformat(), True
+ if isinstance(obj, bytes) or isinstance(obj, bytearray):
+ return encode_binary(obj), True
+ if isinstance(obj, np.generic):
+ return obj.item(), True
+ if isinstance(obj, np.datetime64):
+ return np.datetime_as_string(obj), True
+ return obj, False
+
+ def default(self, obj): # pylint: disable=E0202
+ res, converted = self.convert(obj)
+ if converted:
+ return res
+ else:
+ return super().default(obj)
+
+
+# Model registry
+
+# - schema and types
+
+
+def set_model_class(model):
+ if "href" in model:
+ _ = model.pop("href")
+ if "type" in model: # backwards compatibility
+ _ = model.pop("type")
+ if "tags" in model:
+ _ = model.pop("tags") # tags are always retrieved from backend
+
+ if "framework" not in model:
+ return BaseModel(**model)
+
+ framework = model.pop("framework")
+ if framework == MODEL.FRAMEWORK_TENSORFLOW:
+ return TFModel(**model)
+ if framework == MODEL.FRAMEWORK_TORCH:
+ return TorchModel(**model)
+ if framework == MODEL.FRAMEWORK_SKLEARN:
+ return SkLearnModel(**model)
+ elif framework == MODEL.FRAMEWORK_PYTHON:
+ return PyModel(**model)
+ else:
+ raise ValueError(
+ "framework {} is not a supported framework".format(str(framework))
+ )
+
+
+def input_example_to_json(input_example):
+ if isinstance(input_example, np.ndarray):
+ if input_example.size > 0:
+ return _handle_tensor_input(input_example)
+ else:
+ raise ValueError(
+ "input_example of type {} can not be empty".format(type(input_example))
+ )
+ elif isinstance(input_example, dict):
+ return _handle_dict_input(input_example)
+ else:
+ return _handle_dataframe_input(input_example)
+
+
+def _handle_tensor_input(input_tensor):
+ return input_tensor.tolist()
+
+
+def _handle_dataframe_input(input_ex):
+ if isinstance(input_ex, pd.DataFrame):
+ if not input_ex.empty:
+ return input_ex.iloc[0].tolist()
+ else:
+ raise ValueError(
+ "input_example of type {} can not be empty".format(type(input_ex))
+ )
+ elif isinstance(input_ex, pd.Series):
+ if not input_ex.empty:
+ return input_ex.tolist()
+ else:
+ raise ValueError(
+ "input_example of type {} can not be empty".format(type(input_ex))
+ )
+ elif isinstance(input_ex, list):
+ if len(input_ex) > 0:
+ return input_ex
+ else:
+ raise ValueError(
+ "input_example of type {} can not be empty".format(type(input_ex))
+ )
+ else:
+ raise TypeError(
+ "{} is not a supported input example type".format(type(input_ex))
+ )
+
+
+def _handle_dict_input(input_ex):
+ return input_ex
+
+
+# - artifacts
+
+
+def compress(archive_out_path, archive_name, path_to_archive):
+ if os.path.isdir(path_to_archive):
+ return shutil.make_archive(
+ os.path.join(archive_out_path, archive_name), "gztar", path_to_archive
+ )
+ else:
+ return shutil.make_archive(
+ os.path.join(archive_out_path, archive_name),
+ "gztar",
+ os.path.dirname(path_to_archive),
+ os.path.basename(path_to_archive),
+ )
+
+
+def decompress(archive_file_path, extract_dir=None):
+ return shutil.unpack_archive(archive_file_path, extract_dir=extract_dir)
+
+
+# - export models
+
+
+def validate_metrics(metrics):
+ if metrics is not None:
+ if not isinstance(metrics, dict):
+ raise TypeError(
+ "provided metrics is of instance {}, expected a dict".format(
+ type(metrics)
+ )
+ )
+
+ for metric in metrics:
+ # Validate key is a string
+ if not isinstance(metric, string_types):
+ raise TypeError(
+ "provided metrics key is of instance {}, expected a string".format(
+ type(metric)
+ )
+ )
+ # Validate value is a number
+ try:
+ float(metrics[metric])
+ except ValueError as err:
+ raise ValueError(
+ "{} is not a number, only numbers can be attached as metadata for models.".format(
+ str(metrics[metric])
+ )
+ ) from err
+
+
+# Model serving
+
+
+def get_predictor_for_model(model, **kwargs):
+ if not isinstance(model, BaseModel):
+ raise ValueError(
+ "model is of type {}, but an instance of {} class is expected".format(
+ type(model), BaseModel
+ )
+ )
+
+ if type(model) == TFModel:
+ return TFPredictor(**kwargs)
+ if type(model) == TorchModel:
+ return TorchPredictor(**kwargs)
+ if type(model) == SkLearnModel:
+ return SkLearnPredictor(**kwargs)
+ if type(model) == PyModel:
+ return PyPredictor(**kwargs)
+ if type(model) == BaseModel:
+ return BasePredictor( # python as default framework and model server
+ model_framework=MODEL.FRAMEWORK_PYTHON,
+ model_server=PREDICTOR.MODEL_SERVER_PYTHON,
+ **kwargs,
+ )
+
+
+def get_hostname_replaced_url(sub_path: str):
+ """
+ construct and return an url with public hopsworks hostname and sub path
+ :param self:
+ :param sub_path: url sub-path after base url
+ :return: href url
+ """
+ href = urljoin(client.get_instance()._base_url, sub_path)
+ url_parsed = client.get_instance()._replace_public_host(urlparse(href))
+ return url_parsed.geturl()
+
+
+# General
+
+
+def pretty_print(obj):
+ if isinstance(obj, list):
+ for logs in obj:
+ pretty_print(logs)
+ else:
+ json_decamelized = humps.decamelize(obj.to_dict())
+ print(dumps(json_decamelized, indent=4, sort_keys=True))
+
+
+def get_members(cls, prefix=None):
+ for m in inspect.getmembers(cls, lambda m: not (inspect.isroutine(m))):
+ n = m[0] # name
+ if (prefix is not None and n.startswith(prefix)) or (
+ prefix is None and not (n.startswith("__") and n.endswith("__"))
+ ):
+ yield m[1] # value
+
+
+# - json
+
+
+def extract_field_from_json(obj, fields, default=None, as_instance_of=None):
+ if isinstance(fields, list):
+ for field in fields:
+ value = extract_field_from_json(obj, field, default, as_instance_of)
+ if value is not None:
+ break
+ else:
+ value = obj.pop(fields) if fields in obj else default
+ if as_instance_of is not None:
+ if isinstance(value, list):
+ # if the field is a list, get all obj
+ value = [
+ get_obj_from_json(obj=subvalue, cls=as_instance_of)
+ for subvalue in value
+ ]
+ else:
+ # otherwise, get single obj
+ value = get_obj_from_json(obj=value, cls=as_instance_of)
+ return value
+
+
+def get_obj_from_json(obj, cls):
+ if obj is not None:
+ if isinstance(obj, cls):
+ return obj
+ if isinstance(obj, dict):
+ if obj is DEFAULT:
+ return cls()
+ return cls.from_json(obj)
+ raise ValueError(
+ "Object of type {} cannot be converted to class {}".format(type(obj), cls)
+ )
+ return obj
+
+
+def feature_view_to_json(obj):
+ if obj is None:
+ return None
+ import importlib.util
+
+ if importlib.util.find_spec("hsfs"):
+ from hsfs import feature_view
+
+ if isinstance(obj, feature_view.FeatureView):
+ import json
+
+ import humps
+
+ return humps.camelize(json.loads(obj.json()))
+ return None
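
# A small self-contained sketch of two helpers defined above: NumpyEncoder makes numpy
# values JSON-serializable, and extract_field_from_json pops the first matching key from
# a response dict. The sample payloads here are made up for illustration.
import json

import numpy as np
from hsml import util

payload = {"score": np.float32(0.87), "embedding": np.array([1, 2, 3])}
print(json.dumps(payload, cls=util.NumpyEncoder))    # {"score": 0.87, "embedding": [1, 2, 3]}

response = {"transformer": "transformer.py", "script_file": None}
script = util.extract_field_from_json(response, ["transformer", "script_file"])
print(script)                                        # "transformer.py" (first non-None match wins)
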
diff --git a/hsml/python/hsml/utils/__init__.py b/hsml/python/hsml/utils/__init__.py
new file mode 100644
index 000000000..7fa8fd556
--- /dev/null
+++ b/hsml/python/hsml/utils/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/hsml/python/hsml/utils/schema/__init__.py b/hsml/python/hsml/utils/schema/__init__.py
new file mode 100644
index 000000000..7fa8fd556
--- /dev/null
+++ b/hsml/python/hsml/utils/schema/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/hsml/python/hsml/utils/schema/column.py b/hsml/python/hsml/utils/schema/column.py
new file mode 100644
index 000000000..fa5fc3723
--- /dev/null
+++ b/hsml/python/hsml/utils/schema/column.py
@@ -0,0 +1,28 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+class Column:
+ """Metadata object representing a column in the schema for a model."""
+
+ def __init__(self, type, name=None, description=None):
+ self.type = str(type)
+
+ if name is not None:
+ self.name = str(name)
+
+ if description is not None:
+ self.description = str(description)
diff --git a/hsml/python/hsml/utils/schema/columnar_schema.py b/hsml/python/hsml/utils/schema/columnar_schema.py
new file mode 100644
index 000000000..3aa5fde0e
--- /dev/null
+++ b/hsml/python/hsml/utils/schema/columnar_schema.py
@@ -0,0 +1,109 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import importlib.util
+
+import pandas
+from hsml.utils.schema.column import Column
+
+
+try:
+ import hsfs
+except ImportError:
+ pass
+
+try:
+ import pyspark
+except ImportError:
+ pass
+
+
+class ColumnarSchema:
+ """Metadata object representing a columnar schema for a model."""
+
+ def __init__(self, columnar_obj=None):
+ if isinstance(columnar_obj, list):
+ self.columns = self._convert_list_to_schema(columnar_obj)
+ elif isinstance(columnar_obj, pandas.DataFrame):
+ self.columns = self._convert_pandas_df_to_schema(columnar_obj)
+ elif isinstance(columnar_obj, pandas.Series):
+ self.columns = self._convert_pandas_series_to_schema(columnar_obj)
+ elif importlib.util.find_spec("pyspark") is not None and isinstance(
+ columnar_obj, pyspark.sql.dataframe.DataFrame
+ ):
+ self.columns = self._convert_spark_to_schema(columnar_obj)
+ elif importlib.util.find_spec("hsfs") is not None and isinstance(
+ columnar_obj, hsfs.training_dataset.TrainingDataset
+ ):
+ self.columns = self._convert_td_to_schema(columnar_obj)
+ else:
+ raise TypeError(
+ "{} is not supported in a columnar schema.".format(type(columnar_obj))
+ )
+
+ def _convert_list_to_schema(self, columnar_obj):
+ columns = []
+ for column in columnar_obj:
+ columns.append(self._build_column(column))
+ return columns
+
+ def _convert_pandas_df_to_schema(self, pandas_df):
+ pandas_columns = pandas_df.columns
+ pandas_data_types = pandas_df.dtypes
+ columns = []
+ for name in pandas_columns:
+ columns.append(Column(pandas_data_types[name], name=name))
+ return columns
+
+ def _convert_pandas_series_to_schema(self, pandas_series):
+ columns = []
+ columns.append(Column(pandas_series.dtype, name=pandas_series.name))
+ return columns
+
+ def _convert_spark_to_schema(self, spark_df):
+ columns = []
+ types = spark_df.dtypes
+ for dtype in types:
+ name, dtype = dtype
+ columns.append(Column(dtype, name=name))
+ return columns
+
+ def _convert_td_to_schema(self, td):
+ columns = []
+ features = td.schema
+ for feature in features:
+ columns.append(Column(feature.type, name=feature.name))
+ return columns
+
+ def _build_column(self, columnar_obj):
+ type = None
+ name = None
+ description = None
+
+ if "description" in columnar_obj:
+ description = columnar_obj["description"]
+
+ if "name" in columnar_obj:
+ name = columnar_obj["name"]
+
+ if "type" in columnar_obj:
+ type = columnar_obj["type"]
+ else:
+ raise ValueError(
+ "Mandatory 'type' key missing from entry {}".format(columnar_obj)
+ )
+
+ return Column(type, name=name, description=description)
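
# A quick sketch of the ColumnarSchema above; both input forms shown are handled by the
# constructor branches in this file (the column names and types are made up).
import pandas as pd
from hsml.utils.schema.columnar_schema import ColumnarSchema

# from an explicit list of column descriptors ("type" is mandatory, the rest optional)
schema_from_list = ColumnarSchema(
    [
        {"type": "int64", "name": "age"},
        {"type": "float64", "name": "salary", "description": "gross yearly salary"},
    ]
)

# from a pandas DataFrame, where names and dtypes are inferred per column
schema_from_df = ColumnarSchema(pd.DataFrame({"age": [42], "salary": [55000.0]}))
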
diff --git a/hsml/python/hsml/utils/schema/tensor.py b/hsml/python/hsml/utils/schema/tensor.py
new file mode 100644
index 000000000..2722776b9
--- /dev/null
+++ b/hsml/python/hsml/utils/schema/tensor.py
@@ -0,0 +1,30 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+class Tensor:
+ """Metadata object representing a tensor in the schema for a model."""
+
+ def __init__(self, type, shape, name=None, description=None):
+ self.type = str(type)
+
+ self.shape = str(shape)
+
+ if name is not None:
+ self.name = str(name)
+
+ if description is not None:
+ self.description = str(description)
diff --git a/hsml/python/hsml/utils/schema/tensor_schema.py b/hsml/python/hsml/utils/schema/tensor_schema.py
new file mode 100644
index 000000000..da24ba836
--- /dev/null
+++ b/hsml/python/hsml/utils/schema/tensor_schema.py
@@ -0,0 +1,73 @@
+#
+# Copyright 2022 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy
+from hsml.utils.schema.tensor import Tensor
+
+
+class TensorSchema:
+ """Metadata object representing a tensor schema for a model."""
+
+ def __init__(self, tensor_obj=None):
+ if isinstance(tensor_obj, list):
+ self.tensors = self._convert_list_to_schema(tensor_obj)
+ elif isinstance(tensor_obj, numpy.ndarray):
+ self.tensors = self._convert_tensor_to_schema(tensor_obj)
+ else:
+ raise TypeError(
+ "{} is not supported in a tensor schema.".format(type(tensor_obj))
+ )
+
+ def _convert_tensor_to_schema(self, tensor_obj):
+ return Tensor(tensor_obj.dtype, tensor_obj.shape)
+
+ def _convert_list_to_schema(self, tensor_obj):
+ if len(tensor_obj) == 1:
+ return [self._build_tensor(tensor_obj[0])]
+ else:
+ tensors = []
+ for tensor in tensor_obj:
+ tensors.append(self._build_tensor(tensor))
+ return tensors
+
+ def _build_tensor(self, tensor_obj):
+ name = None
+ type = None
+ shape = None
+ description = None
+
+ # Name is optional
+ if "name" in tensor_obj:
+ name = tensor_obj["name"]
+
+ if "description" in tensor_obj:
+ description = tensor_obj["description"]
+
+ if "type" in tensor_obj:
+ type = tensor_obj["type"]
+ else:
+ raise ValueError(
+ "Mandatory 'type' key missing from entry {}".format(tensor_obj)
+ )
+
+ if "shape" in tensor_obj:
+ shape = tensor_obj["shape"]
+ else:
+ raise ValueError(
+ "Mandatory 'shape' key missing from entry {}".format(tensor_obj)
+ )
+
+ return Tensor(type, shape, name=name, description=description)
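
# Companion sketch for the TensorSchema above: a single ndarray yields one unnamed tensor
# with inferred dtype and shape, while a list of descriptors must carry both "type" and
# "shape" per entry (names and shapes below are illustrative).
import numpy as np
from hsml.utils.schema.tensor_schema import TensorSchema

schema_from_array = TensorSchema(np.zeros((28, 28), dtype="float32"))

schema_from_list = TensorSchema(
    [
        {"name": "image", "type": "float32", "shape": [28, 28]},
        {"name": "label", "type": "int64", "shape": [1]},
    ]
)
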
diff --git a/hsml/python/hsml/version.py b/hsml/python/hsml/version.py
new file mode 100644
index 000000000..a7136ad06
--- /dev/null
+++ b/hsml/python/hsml/version.py
@@ -0,0 +1,17 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+__version__ = "4.0.0.dev1"
diff --git a/hsml/python/pyproject.toml b/hsml/python/pyproject.toml
new file mode 100644
index 000000000..e4770cd4a
--- /dev/null
+++ b/hsml/python/pyproject.toml
@@ -0,0 +1,136 @@
+[project]
+name="hsml"
+dynamic = ["version"]
+requires-python = ">=3.8,<3.13"
+readme = "README.md"
+description = "HSML Python SDK to interact with Hopsworks Model Registry"
+keywords = ["Hopsworks", "Model Registry", "hsml", "Models", "ML", "Machine Learning Models", "TensorFlow", "PyTorch", "Machine Learning", "MLOps", "DataOps"]
+authors = [{name = "Hopsworks AB", email = "robin@hopsworks.ai"}]
+license = { text = "Apache-2.0" }
+
+classifiers = [
+ "Development Status :: 5 - Production/Stable",
+ "Topic :: Utilities",
+ "License :: OSI Approved :: Apache Software License",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+ "Intended Audience :: Developers",
+]
+
+dependencies = [
+ "pyhumps==1.6.1",
+ "requests",
+ "furl",
+ "boto3",
+ "pandas",
+ "numpy",
+ "pyjks",
+ "mock",
+ "tqdm",
+ "grpcio>=1.49.1,<2.0.0", # ^1.49.1
+ "protobuf>=3.19.0,<4.0.0", # ^3.19.0
+]
+
+[project.optional-dependencies]
+dev = ["pytest==7.4.4", "pytest-mock==3.12.0", "ruff"]
+
+[build-system]
+requires = ["setuptools", "wheel"]
+build-backend = "setuptools.build_meta"
+
+
+[tool.setuptools.packages.find]
+exclude = ["tests*"]
+include = ["../Readme.md", "../LICENSE", "hsml", "hsml.*"]
+
+
+[tool.setuptools.dynamic]
+version = {attr = "hsml.version.__version__"}
+
+[project.urls]
+Documentation = "https://docs.hopsworks.ai/latest"
+Repository = "https://github.com/logicalclocks/machine-learning-api"
+Homepage = "https://www.hopsworks.ai"
+Community = "https://community.hopsworks.ai"
+
+
+[tool.ruff]
+# Exclude a variety of commonly ignored directories.
+exclude = [
+ ".bzr",
+ ".direnv",
+ ".eggs",
+ ".git",
+ ".git-rewrite",
+ ".hg",
+ ".ipynb_checkpoints",
+ ".mypy_cache",
+ ".nox",
+ ".pants.d",
+ ".pyenv",
+ ".pytest_cache",
+ ".pytype",
+ ".ruff_cache",
+ ".svn",
+ ".tox",
+ ".venv",
+ ".vscode",
+ "__pypackages__",
+ "_build",
+ "buck-out",
+ "build",
+ "dist",
+ "node_modules",
+ "site-packages",
+ "venv",
+ "java",
+]
+
+# Same as Black.
+line-length = 88
+indent-width = 4
+
+# Assume Python 3.8+ syntax.
+target-version = "py38"
+
+[tool.ruff.lint]
+# 1. Enable flake8-bugbear (`B`) rules, in addition to the defaults.
+select = ["E4", "E7", "E9", "F", "B", "I", "W"]#, "ANN"]
+ignore = [
+ "B905", # zip has no strict kwarg until Python 3.10
+ "ANN101", # Missing type annotation for self in method
+ "ANN102", # Missing type annotation for cls in classmethod
+ "ANN003", # Missing type annotation for **kwarg in function
+ "ANN002", # Missing type annotation for *args in function
+ "ANN401", # Allow Any in type annotations
+ "W505", # Doc line too long
+]
+
+# Allow fix for all enabled rules (when `--fix` is provided).
+fixable = ["ALL"]
+unfixable = []
+
+# Allow unused variables when underscore-prefixed.
+dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
+
+[tool.ruff.lint.isort]
+lines-after-imports = 2
+known-third-party = ["hopsworks", "hsfs", "hsml"]
+
+
+[tool.ruff.format]
+# Like Black, use double quotes for strings.
+quote-style = "double"
+
+# Like Black, indent with spaces, rather than tabs.
+indent-style = "space"
+
+# Like Black, respect magic trailing commas.
+skip-magic-trailing-comma = false
+
+# Like Black, automatically detect the appropriate line ending.
+line-ending = "auto"
diff --git a/hsml/python/setup.py b/hsml/python/setup.py
new file mode 100644
index 000000000..cb916d7e6
--- /dev/null
+++ b/hsml/python/setup.py
@@ -0,0 +1,19 @@
+#
+# Copyright 2021 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from setuptools import setup
+
+
+setup()
diff --git a/hsml/python/tests/__init__.py b/hsml/python/tests/__init__.py
new file mode 100644
index 000000000..5b0cd48e7
--- /dev/null
+++ b/hsml/python/tests/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2024 Logical Clocks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/hsml/python/tests/conftest.py b/hsml/python/tests/conftest.py
new file mode 100644
index 000000000..00d23a9fc
--- /dev/null
+++ b/hsml/python/tests/conftest.py
@@ -0,0 +1,20 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+pytest_plugins = [
+ "tests.fixtures.backend_fixtures",
+ "tests.fixtures.model_fixtures",
+]
diff --git a/hsml/python/tests/fixtures/__init__.py b/hsml/python/tests/fixtures/__init__.py
new file mode 100644
index 000000000..ff8055b9b
--- /dev/null
+++ b/hsml/python/tests/fixtures/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/hsml/python/tests/fixtures/backend_fixtures.py b/hsml/python/tests/fixtures/backend_fixtures.py
new file mode 100644
index 000000000..c79bc6ddb
--- /dev/null
+++ b/hsml/python/tests/fixtures/backend_fixtures.py
@@ -0,0 +1,45 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import json
+import os
+
+import pytest
+
+
+FIXTURES_DIR = os.path.dirname(os.path.abspath(__file__))
+
+FIXTURES = [
+ "tag",
+ "model",
+ "resources",
+ "transformer",
+ "predictor",
+ "kafka_topic",
+ "inference_logger",
+ "inference_batcher",
+ "inference_endpoint",
+]
+
+backend_fixtures_json = {}
+for fixture in FIXTURES:
+ with open(os.path.join(FIXTURES_DIR, f"{fixture}_fixtures.json"), "r") as json_file:
+ backend_fixtures_json[fixture] = json.load(json_file)
+
+
+@pytest.fixture
+def backend_fixtures():
+ return backend_fixtures_json
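
# Illustrative test consuming the `backend_fixtures` fixture above: it hands each test the
# parsed JSON payloads keyed by fixture name; the assertions target the "get_base" entry of
# model_fixtures.json shown further down in this diff.
def test_base_model_fixture_shape(backend_fixtures):
    item = backend_fixtures["model"]["get_base"]["response"]["items"][0]
    assert item["name"] == "basemodel"
    assert "framework" not in item  # the base model entry carries no framework key
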
diff --git a/hsml/python/tests/fixtures/inference_batcher_fixtures.json b/hsml/python/tests/fixtures/inference_batcher_fixtures.json
new file mode 100644
index 000000000..0fc8bbc15
--- /dev/null
+++ b/hsml/python/tests/fixtures/inference_batcher_fixtures.json
@@ -0,0 +1,54 @@
+{
+ "get_enabled": {
+ "response": {
+ "enabled": true
+ },
+ "response_nested": {
+ "batching_configuration": {
+ "enabled": true
+ }
+ }
+ },
+ "get_disabled": {
+ "response": {
+ "enabled": false
+ },
+ "response_nested": {
+ "batching_configuration": {
+ "enabled": true
+ }
+ }
+ },
+ "get_enabled_with_config": {
+ "response": {
+ "enabled": true,
+ "max_batch_size": 1,
+ "max_latency": 2,
+ "timeout": 3
+ },
+ "response_nested": {
+ "batching_configuration": {
+ "enabled": true,
+ "max_batch_size": 1,
+ "max_latency": 2,
+ "timeout": 3
+ }
+ }
+ },
+ "get_disabled_with_config": {
+ "response": {
+ "enabled": false,
+ "max_batch_size": 1,
+ "max_latency": 2,
+ "timeout": 3
+ },
+ "response_nested": {
+ "batching_configuration": {
+ "enabled": false,
+ "max_batch_size": 1,
+ "max_latency": 2,
+ "timeout": 3
+ }
+ }
+ }
+}
diff --git a/hsml/python/tests/fixtures/inference_endpoint_fixtures.json b/hsml/python/tests/fixtures/inference_endpoint_fixtures.json
new file mode 100644
index 000000000..7dce25daf
--- /dev/null
+++ b/hsml/python/tests/fixtures/inference_endpoint_fixtures.json
@@ -0,0 +1,66 @@
+{
+ "get_port": {
+ "response": {
+ "name": "port_name",
+ "number": 12345
+ }
+ },
+ "get_empty": {
+ "response": {
+ "count": 0,
+ "items": []
+ }
+ },
+ "get_singleton": {
+ "response": {
+ "count": 1,
+ "items": [
+ {
+ "type": "LOAD_BALANCER",
+ "hosts": ["host1", "host2", "host3"],
+ "ports": [
+ {
+ "name": "port_name",
+ "number": 12345
+ }
+ ]
+ }
+ ]
+ }
+ },
+ "get_list": {
+ "response": {
+ "count": 2,
+ "items": [
+ {
+ "type": "LOAD_BALANCER",
+ "hosts": ["host1", "host2"],
+ "ports": [
+ {
+ "name": "port_name",
+ "number": 12345
+ }
+ ]
+ },
+ {
+ "type": "NODE",
+ "hosts": 54321,
+ "ports": [
+ {
+ "name": "port_name",
+ "number": 12345
+ },
+ {
+ "name": "port_name_2",
+ "number": 54321
+ },
+ {
+ "name": "port_name_3",
+ "number": 15243
+ }
+ ]
+ }
+ ]
+ }
+ }
+}
diff --git a/hsml/python/tests/fixtures/inference_logger_fixtures.json b/hsml/python/tests/fixtures/inference_logger_fixtures.json
new file mode 100644
index 000000000..ae22a8c44
--- /dev/null
+++ b/hsml/python/tests/fixtures/inference_logger_fixtures.json
@@ -0,0 +1,96 @@
+{
+ "get_mode_all": {
+ "response": {
+ "inference_logging": "ALL"
+ },
+ "init_args": {
+ "mode": "ALL"
+ }
+ },
+ "get_mode_inputs": {
+ "response": {
+ "inference_logging": "MODEL_INPUTS"
+ },
+ "init_args": {
+ "mode": "MODEL_INPUTS"
+ }
+ },
+ "get_mode_outputs": {
+ "response": {
+ "inference_logging": "MODEL_OUTPUTS"
+ },
+ "init_args": {
+ "mode": "MODEL_OUTPUTS"
+ }
+ },
+ "get_mode_none": {
+ "response": {
+ "inference_logging": "NONE"
+ },
+ "init_args": {
+ "mode": "NONE"
+ }
+ },
+ "get_mode_all_with_kafka_topic": {
+ "response": {
+ "inference_logging": "ALL",
+ "kafka_topic_dto": {
+ "name": "topic"
+ }
+ },
+ "init_args": {
+ "mode": "ALL",
+ "kafka_topic": {
+ "name": "topic"
+ }
+ }
+ },
+ "get_mode_inputs_with_kafka_topic": {
+ "response": {
+ "inference_logging": "MODEL_INPUTS",
+ "kafka_topic_dto": {
+ "name": "topic"
+ }
+ },
+ "init_args": {
+ "mode": "MODEL_INPUTS",
+ "kafka_topic": {
+ "name": "topic",
+ "num_replicas": 1,
+ "num_partitions": 2
+ }
+ }
+ },
+ "get_mode_outputs_with_kafka_topic": {
+ "response": {
+ "inference_logging": "MODEL_OUTPUTS",
+ "kafka_topic_dto": {
+ "name": "topic"
+ }
+ },
+ "init_args": {
+ "mode": "MODEL_OUTPUTS",
+ "kafka_topic": {
+ "name": "topic",
+ "num_replicas": 1,
+ "num_partitions": 2
+ }
+ }
+ },
+ "get_mode_none_with_kafka_topic": {
+ "response": {
+ "inference_logging": "NONE",
+ "kafka_topic_dto": {
+ "name": "topic"
+ }
+ },
+ "init_args": {
+ "mode": "NONE",
+ "kafka_topic": {
+ "name": "topic",
+ "num_replicas": 1,
+ "num_partitions": 2
+ }
+ }
+ }
+}
diff --git a/hsml/python/tests/fixtures/kafka_topic_fixtures.json b/hsml/python/tests/fixtures/kafka_topic_fixtures.json
new file mode 100644
index 000000000..f69d1e567
--- /dev/null
+++ b/hsml/python/tests/fixtures/kafka_topic_fixtures.json
@@ -0,0 +1,59 @@
+{
+ "get_existing_with_name_only": {
+ "response": {
+ "kafka_topic_dto": {
+ "name": "topic"
+ }
+ }
+ },
+ "get_existing_with_name_and_config": {
+ "response": {
+ "kafka_topic_dto": {
+ "name": "topic",
+ "num_replicas": 1,
+ "num_partitions": 2
+ }
+ }
+ },
+ "get_existing_with_name_and_config_alternative": {
+ "response": {
+ "kafka_topic_dto": {
+ "name": "topic",
+ "num_of_replicas": 1,
+ "num_of_partitions": 2
+ }
+ }
+ },
+ "get_none": {
+ "response": {
+ "kafka_topic_dto": {
+ "name": "NONE"
+ }
+ }
+ },
+ "get_none_with_config": {
+ "response": {
+ "kafka_topic_dto": {
+ "name": "NONE",
+ "num_replicas": 1,
+ "num_partitions": 2
+ }
+ }
+ },
+ "get_create_with_name_only": {
+ "response": {
+ "kafka_topic_dto": {
+ "name": "CREATE"
+ }
+ }
+ },
+ "get_create_with_name_and_config": {
+ "response": {
+ "kafka_topic_dto": {
+ "name": "CREATE",
+ "num_replicas": 1,
+ "num_partitions": 2
+ }
+ }
+ }
+}
diff --git a/hsml/python/tests/fixtures/model_fixtures.json b/hsml/python/tests/fixtures/model_fixtures.json
new file mode 100644
index 000000000..79f31ed72
--- /dev/null
+++ b/hsml/python/tests/fixtures/model_fixtures.json
@@ -0,0 +1,203 @@
+{
+ "get_base": {
+ "response": {
+ "count": 1,
+ "items": [
+ {
+ "id": "0",
+ "name": "basemodel",
+ "version": 0,
+ "created": "created",
+ "creator": "creator",
+ "environment": "environment.yml",
+ "description": "description",
+ "experiment_id": 1,
+ "project_name": "myproject",
+ "experiment_project_name": "myexperimentproject",
+ "metrics": { "acc": 0.7 },
+ "program": "program",
+ "user_full_name": "Full Name",
+ "model_schema": "model_schema.json",
+ "training_dataset": "training_dataset",
+ "input_example": "input_example.json",
+ "model_registry_id": 1,
+ "tags": [],
+ "href": "test_href"
+ }
+ ]
+ }
+ },
+ "get_python": {
+ "response": {
+ "count": 1,
+ "items": [
+ {
+ "id": "1",
+ "name": "pythonmodel",
+ "version": 0,
+ "created": "created",
+ "creator": "creator",
+ "environment": "environment.yml",
+ "description": "description",
+ "experiment_id": 1,
+ "project_name": "myproject",
+ "experiment_project_name": "myexperimentproject",
+ "metrics": { "acc": 0.7 },
+ "program": "program",
+ "user_full_name": "Full Name",
+ "model_schema": "model_schema.json",
+ "training_dataset": "training_dataset",
+ "input_example": "input_example.json",
+ "model_registry_id": 1,
+ "tags": [],
+ "framework": "PYTHON",
+ "href": "test_href"
+ }
+ ]
+ }
+ },
+ "get_sklearn": {
+ "response": {
+ "count": 1,
+ "items": [
+ {
+ "id": "2",
+ "name": "sklearnmodel",
+ "version": 0,
+ "created": "created",
+ "creator": "creator",
+ "environment": "environment.yml",
+ "description": "description",
+ "experiment_id": 1,
+ "project_name": "myproject",
+ "experiment_project_name": "myexperimentproject",
+ "metrics": { "acc": 0.7 },
+ "program": "program",
+ "user_full_name": "Full Name",
+ "model_schema": "model_schema.json",
+ "training_dataset": "training_dataset",
+ "input_example": "input_example.json",
+ "model_registry_id": 1,
+ "tags": [],
+ "framework": "SKLEARN",
+ "href": "test_href"
+ }
+ ]
+ }
+ },
+ "get_tensorflow": {
+ "response": {
+ "count": 1,
+ "items": [
+ {
+ "id": "3",
+ "name": "tensorflowmodel",
+ "version": 0,
+ "created": "created",
+ "creator": "creator",
+ "environment": "environment.yml",
+ "description": "description",
+ "experiment_id": 1,
+ "project_name": "myproject",
+ "experiment_project_name": "myexperimentproject",
+ "metrics": { "acc": 0.7 },
+ "program": "program",
+ "user_full_name": "Full Name",
+ "model_schema": "model_schema.json",
+ "training_dataset": "training_dataset",
+ "input_example": "input_example.json",
+ "model_registry_id": 1,
+ "tags": [],
+ "framework": "TENSORFLOW",
+ "href": "test_href"
+ }
+ ]
+ }
+ },
+ "get_torch": {
+ "response": {
+ "count": 1,
+ "items": [
+ {
+ "id": "4",
+ "name": "torchmodel",
+ "version": 0,
+ "created": "created",
+ "creator": "creator",
+ "environment": "environment.yml",
+ "description": "description",
+ "experiment_id": 1,
+ "project_name": "myproject",
+ "experiment_project_name": "myexperimentproject",
+ "metrics": { "acc": 0.7 },
+ "program": "program",
+ "user_full_name": "Full Name",
+ "model_schema": "model_schema.json",
+ "training_dataset": "training_dataset",
+ "input_example": "input_example.json",
+ "model_registry_id": 1,
+ "tags": [],
+ "framework": "TORCH",
+ "href": "test_href"
+ }
+ ]
+ }
+ },
+ "get_list": {
+ "response": {
+ "count": 2,
+ "items": [
+ {
+ "id": "1",
+ "name": "pythonmodel",
+ "version": 0,
+ "created": "created",
+ "creator": "creator",
+ "environment": "environment.yml",
+ "description": "description",
+ "experiment_id": 1,
+ "project_name": "myproject",
+ "experiment_project_name": "myexperimentproject",
+ "metrics": { "acc": 0.7 },
+ "program": "program",
+ "user_full_name": "Full Name",
+ "model_schema": "model_schema.json",
+ "training_dataset": "training_dataset",
+ "input_example": "input_example.json",
+ "model_registry_id": 1,
+ "tags": [],
+ "framework": "PYTHON",
+ "href": "test_href"
+ },
+ {
+ "id": "2",
+ "name": "pythonmodel",
+ "version": 0,
+ "created": "created",
+ "creator": "creator",
+ "environment": "environment.yml",
+ "description": "description",
+ "experiment_id": 2,
+ "project_name": "myproject",
+ "experiment_project_name": "myexperimentproject",
+ "metrics": { "acc": 0.7 },
+ "program": "program",
+ "user_full_name": "Full Name",
+ "model_schema": "model_schema.json",
+ "training_dataset": "training_dataset",
+ "input_example": "input_example.json",
+ "model_registry_id": 1,
+ "tags": [],
+ "framework": "PYTHON",
+ "href": "test_href"
+ }
+ ]
+ }
+ },
+ "get_empty": {
+ "response": {
+ "count": 0,
+ "items": []
+ }
+ }
+}
diff --git a/hsml/python/tests/fixtures/model_fixtures.py b/hsml/python/tests/fixtures/model_fixtures.py
new file mode 100644
index 000000000..32fe396de
--- /dev/null
+++ b/hsml/python/tests/fixtures/model_fixtures.py
@@ -0,0 +1,125 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+import pandas as pd
+import pytest
+from hsml.model import Model as BaseModel
+from hsml.python.model import Model as PythonModel
+from hsml.sklearn.model import Model as SklearnModel
+from hsml.tensorflow.model import Model as TensorflowModel
+from hsml.torch.model import Model as TorchModel
+
+
+MODEL_BASE_ID = 0
+MODEL_PYTHON_ID = 1
+MODEL_SKLEARN_ID = 2
+MODEL_TENSORFLOW_ID = 3
+MODEL_TORCH_ID = 4
+
+MODEL_BASE_NAME = "basemodel"
+MODEL_PYTHON_NAME = "pythonmodel"
+MODEL_SKLEARN_NAME = "sklearnmodel"
+MODEL_TENSORFLOW_NAME = "tensorflowmodel"
+MODEL_TORCH_NAME = "torchmodel"
+
+# models
+
+
+@pytest.fixture
+def model_base():
+ return BaseModel(MODEL_BASE_ID, MODEL_BASE_NAME)
+
+
+@pytest.fixture
+def model_python():
+ return PythonModel(MODEL_PYTHON_ID, MODEL_PYTHON_NAME)
+
+
+@pytest.fixture
+def model_sklearn():
+ return SklearnModel(MODEL_SKLEARN_ID, MODEL_SKLEARN_NAME)
+
+
+@pytest.fixture
+def model_tensorflow():
+ return TensorflowModel(MODEL_TENSORFLOW_ID, MODEL_TENSORFLOW_NAME)
+
+
+@pytest.fixture
+def model_torch():
+ return TorchModel(MODEL_TORCH_ID, MODEL_TORCH_NAME)
+
+
+# input example
+
+
+@pytest.fixture
+def input_example_numpy():
+ return np.array([1, 2, 3, 4])
+
+
+@pytest.fixture
+def input_example_dict():
+ return {"instances": [[1, 2, 3, 4]]}
+
+
+@pytest.fixture
+def input_example_dataframe_pandas_dataframe():
+ return pd.DataFrame({"a": [1], "b": [2], "c": [3], "d": [4]})
+
+
+@pytest.fixture
+def input_example_dataframe_pandas_dataframe_empty():
+ return pd.DataFrame()
+
+
+@pytest.fixture
+def input_example_dataframe_pandas_series():
+ return pd.Series([1, 2, 3, 4])
+
+
+@pytest.fixture
+def input_example_dataframe_pandas_series_empty():
+ return pd.Series()
+
+
+@pytest.fixture
+def input_example_dataframe_list():
+ return [1, 2, 3, 4]
+
+
+# metrics
+
+
+@pytest.fixture
+def model_metrics():
+ return {"accuracy": 0.4, "rmse": 0.6}
+
+
+@pytest.fixture
+def model_metrics_wrong_type():
+ return [0.4, 0.6]
+
+
+@pytest.fixture
+def model_metrics_wrong_metric_type():
+ return {1: 0.4, 2: 0.6}
+
+
+@pytest.fixture
+def model_metrics_wrong_metric_value():
+ return {"accuracy": "non-number", "rmse": 0.4}
diff --git a/hsml/python/tests/fixtures/predictor_fixtures.json b/hsml/python/tests/fixtures/predictor_fixtures.json
new file mode 100644
index 000000000..b0b7b2fcc
--- /dev/null
+++ b/hsml/python/tests/fixtures/predictor_fixtures.json
@@ -0,0 +1,400 @@
+{
+ "get_deployments_singleton": {
+ "response": {
+ "count": 1,
+ "items": [
+ {
+ "id": 1,
+ "name": "test",
+ "description": "test_desc",
+ "created": "",
+ "creator": "",
+ "model_path": "test_model_path",
+ "model_name": "test_model_name",
+ "model_version": 1,
+ "model_framework": "PYTHON",
+ "model_server": "PYTHON",
+ "serving_tool": "KSERVE",
+ "api_protocol": "REST",
+ "artifact_version": 2,
+ "predictor": "predictor_file",
+ "transformer": "transformer_file",
+ "requested_instances": 1,
+ "requested_transformer_instances": 1,
+ "predictor_resources": {
+ "requested_instances": 1,
+ "requests": { "cores": 0.2, "memory": 16, "gpus": 1 },
+ "limits": { "cores": 0.3, "memory": 17, "gpus": 2 }
+ },
+ "transformer_resources": {
+ "requested_instances": 1,
+ "requests": { "cores": 0.2, "memory": 16, "gpus": 1 },
+ "limits": { "cores": 0.3, "memory": 17, "gpus": 2 }
+ },
+ "batching_configuration": {
+ "batching_enabled": true,
+ "max_batch_size": 1000,
+ "max_latency": 1000,
+ "timeout": 1000
+ },
+ "inference_logging": "ALL",
+ "kafka_topic_dto": {
+ "name": "topic"
+ }
+ }
+ ]
+ }
+ },
+ "get_deployments_empty": {
+ "response": {
+ "count": 0,
+ "items": []
+ }
+ },
+ "get_deployments_list": {
+ "response": {
+ "count": 2,
+ "items": [
+ {
+ "id": 1,
+ "name": "test",
+ "description": "test_desc",
+ "created": "",
+ "creator": "",
+ "model_path": "test_model_path",
+ "model_name": "test_model_name",
+ "model_version": 1,
+ "model_framework": "PYTHON",
+ "model_server": "PYTHON",
+ "serving_tool": "KSERVE",
+ "api_protocol": "REST",
+ "artifact_version": 2,
+ "predictor": "predictor_file",
+ "transformer": "transformer_file",
+ "requested_instances": 1,
+ "requested_transformer_instances": 1,
+ "predictor_resources": {
+ "requested_instances": 1,
+ "requests": { "cores": 0.2, "memory": 16, "gpus": 1 },
+ "limits": { "cores": 0.3, "memory": 17, "gpus": 2 }
+ },
+ "transformer_resources": {
+ "requested_instances": 1,
+ "requests": { "cores": 0.2, "memory": 16, "gpus": 1 },
+ "limits": { "cores": 0.3, "memory": 17, "gpus": 2 }
+ },
+ "batching_configuration": {
+ "batching_enabled": true,
+ "max_batch_size": 1000,
+ "max_latency": 1000,
+ "timeout": 1000
+ },
+ "inference_logging": "ALL",
+ "kafka_topic_dto": {
+ "name": "topic"
+ }
+ },
+ {
+ "id": 2,
+ "name": "test_2",
+ "description": "test_desc_2",
+ "created": "",
+ "creator": "",
+ "model_path": "test_model_path",
+ "model_name": "test_model_name",
+ "model_version": 2,
+ "model_framework": "PYTHON",
+ "model_server": "PYTHON",
+ "serving_tool": "KSERVE",
+ "api_protocol": "REST",
+ "artifact_version": 3,
+ "predictor": "predictor_file",
+ "transformer": "transformer_file",
+ "requested_instances": 1,
+ "requested_transformer_instances": 1,
+ "predictor_resources": {
+ "requested_instances": 1,
+ "requests": { "cores": 0.2, "memory": 16, "gpus": 1 },
+ "limits": { "cores": 0.3, "memory": 17, "gpus": 2 }
+ },
+ "transformer_resources": {
+ "requested_instances": 1,
+ "requests": { "cores": 0.2, "memory": 16, "gpus": 1 },
+ "limits": { "cores": 0.3, "memory": 17, "gpus": 2 }
+ },
+ "batching_configuration": {
+ "batching_enabled": true,
+ "max_batch_size": 1000,
+ "max_latency": 1000,
+ "timeout": 1000
+ },
+ "inference_logging": "ALL",
+ "kafka_topic_dto": {
+ "name": "topic"
+ }
+ }
+ ]
+ }
+ },
+ "get_deployment_tf_kserve_rest": {
+ "response": {
+ "id": 1,
+ "name": "test",
+ "description": "test_desc",
+ "created": "",
+ "creator": "",
+ "model_path": "test_model_path",
+ "model_name": "test_model_name",
+ "model_version": 1,
+ "model_framework": "TENSORFLOW",
+ "model_server": "TENSORFLOW_SERVING",
+ "serving_tool": "KSERVE",
+ "api_protocol": "REST",
+ "artifact_version": 2,
+ "requested_instances": 1,
+ "predictor_resources": {
+ "requested_instances": 1,
+ "requests": { "cores": 0.2, "memory": 16, "gpus": 1 },
+ "limits": { "cores": 0.3, "memory": 17, "gpus": 2 }
+ },
+ "inference_logging": "ALL",
+ "kafka_topic_dto": {
+ "name": "topic"
+ }
+ }
+ },
+ "get_deployment_tf_kserve_rest_trans": {
+ "response": {
+ "id": 1,
+ "name": "test",
+ "description": "test_desc",
+ "created": "",
+ "creator": "",
+ "model_path": "test_model_path",
+ "model_name": "test_model_name",
+ "model_version": 1,
+ "model_framework": "TENSORFLOW",
+ "model_server": "TENSORFLOW_SERVING",
+ "serving_tool": "KSERVE",
+ "api_protocol": "REST",
+ "artifact_version": 2,
+ "transformer": "transformer_file",
+ "requested_instances": 1,
+ "requested_transformer_instances": 1,
+ "predictor_resources": {
+ "requested_instances": 1,
+ "requests": { "cores": 0.2, "memory": 16, "gpus": 1 },
+ "limits": { "cores": 0.3, "memory": 17, "gpus": 2 }
+ },
+ "transformer_resources": {
+ "requested_instances": 1,
+ "requests": { "cores": 0.2, "memory": 16, "gpus": 1 },
+ "limits": { "cores": 0.3, "memory": 17, "gpus": 2 }
+ },
+ "batching_configuration": {
+ "batching_enabled": true,
+ "max_batch_size": 1000,
+ "max_latency": 1000,
+ "timeout": 1000
+ },
+ "inference_logging": "ALL",
+ "kafka_topic_dto": {
+ "name": "topic"
+ }
+ }
+ },
+
+ "get_deployment_py_kserve_rest_pred": {
+ "response": {
+ "id": 1,
+ "name": "test",
+ "description": "test_desc",
+ "created": "",
+ "creator": "",
+ "model_path": "test_model_path",
+ "model_name": "test_model_name",
+ "model_version": 1,
+ "model_framework": "PYTHON",
+ "model_server": "PYTHON",
+ "serving_tool": "KSERVE",
+ "api_protocol": "REST",
+ "artifact_version": 2,
+ "predictor": "predictor_file",
+ "requested_instances": 1,
+ "predictor_resources": {
+ "requested_instances": 1,
+ "requests": { "cores": 0.2, "memory": 16, "gpus": 1 },
+ "limits": { "cores": 0.3, "memory": 17, "gpus": 2 }
+ },
+ "batching_configuration": {
+ "batching_enabled": true,
+ "max_batch_size": 1000,
+ "max_latency": 1000,
+ "timeout": 1000
+ },
+ "inference_logging": "ALL",
+ "kafka_topic_dto": {
+ "name": "topic"
+ }
+ }
+ },
+
+ "get_deployment_py_kserve_rest_pred_trans": {
+ "response": {
+ "id": 1,
+ "name": "test",
+ "description": "test_desc",
+ "created": "",
+ "creator": "",
+ "model_path": "test_model_path",
+ "model_name": "test_model_name",
+ "model_version": 1,
+ "model_framework": "PYTHON",
+ "model_server": "PYTHON",
+ "serving_tool": "KSERVE",
+ "api_protocol": "REST",
+ "artifact_version": 2,
+ "predictor": "predictor_file",
+ "transformer": "transformer_file",
+ "requested_instances": 1,
+ "requested_transformer_instances": 1,
+ "predictor_resources": {
+ "requested_instances": 1,
+ "requests": { "cores": 0.2, "memory": 16, "gpus": 1 },
+ "limits": { "cores": 0.3, "memory": 17, "gpus": 2 }
+ },
+ "transformer_resources": {
+ "requested_instances": 1,
+ "requests": { "cores": 0.2, "memory": 16, "gpus": 1 },
+ "limits": { "cores": 0.3, "memory": 17, "gpus": 2 }
+ },
+ "inference_logging": "ALL",
+ "batching_configuration": {
+ "batching_enabled": true,
+ "max_batch_size": 1000,
+ "max_latency": 1000,
+ "timeout": 1000
+ },
+ "kafka_topic_dto": {
+ "name": "topic"
+ }
+ }
+ },
+
+ "get_deployment_py_kserve_grpc_pred": {
+ "response": {
+ "id": 1,
+ "name": "test",
+ "description": "test_desc",
+ "created": "",
+ "creator": "",
+ "model_path": "test_model_path",
+ "model_name": "test_model_name",
+ "model_version": 1,
+ "model_framework": "PYTHON",
+ "model_server": "PYTHON",
+ "serving_tool": "KSERVE",
+ "api_protocol": "GRPC",
+ "artifact_version": 2,
+ "predictor": "predictor_file",
+ "requested_instances": 1,
+ "predictor_resources": {
+ "requested_instances": 1,
+ "requests": { "cores": 0.2, "memory": 16, "gpus": 1 },
+ "limits": { "cores": 0.3, "memory": 17, "gpus": 2 }
+ },
+ "inference_logging": "ALL",
+ "batching_configuration": {
+ "batching_enabled": true,
+ "max_batch_size": 1000,
+ "max_latency": 1000,
+ "timeout": 1000
+ },
+ "kafka_topic_dto": {
+ "name": "topic"
+ }
+ }
+ },
+
+ "get_deployment_py_kserve_grpc_pred_trans": {
+ "response": {
+ "id": 1,
+ "name": "test",
+ "description": "test_desc",
+ "created": "",
+ "creator": "",
+ "model_path": "test_model_path",
+ "model_name": "test_model_name",
+ "model_version": 1,
+ "model_framework": "PYTHON",
+ "model_server": "PYTHON",
+ "serving_tool": "KSERVE",
+ "api_protocol": "GRPC",
+ "artifact_version": 2,
+ "predictor": "predictor_file",
+ "transformer": "transformer_file",
+ "requested_instances": 1,
+ "requested_transformer_instances": 1,
+ "predictor_resources": {
+ "requested_instances": 1,
+ "requests": { "cores": 0.2, "memory": 16, "gpus": 1 },
+ "limits": { "cores": 0.3, "memory": 17, "gpus": 2 }
+ },
+ "transformer_resources": {
+ "requested_instances": 1,
+ "requests": { "cores": 0.2, "memory": 16, "gpus": 1 },
+ "limits": { "cores": 0.3, "memory": 17, "gpus": 2 }
+ },
+ "inference_logging": "ALL",
+ "batching_configuration": {
+ "batching_enabled": true,
+ "max_batch_size": 1000,
+ "max_latency": 1000,
+ "timeout": 1000
+ },
+ "kafka_topic_dto": {
+ "name": "topic"
+ }
+ }
+ },
+ "get_deployment_predictor_state": {
+ "response": {
+ "available_instances": 1,
+ "available_transformer_instances": 1,
+ "hopsworks_inference_path": "hopsworks/api/path",
+ "model_server_inference_path": "model-server/path",
+ "internal_port": 1234,
+ "revision": 1234,
+ "deployed": "1234",
+ "condition": {
+ "type": "TYPE",
+ "status": true,
+ "reason": "REASON"
+ },
+ "status": "RUNNING"
+ }
+ },
+ "get_deployment_component_logs_empty": {
+ "response": []
+ },
+ "get_deployment_component_logs_single": {
+ "response": [
+ {
+ "instance_name": "instance_name",
+ "content": "content"
+ }
+ ]
+ },
+ "get_deployment_component_logs_list": {
+ "response": [
+ {
+ "instance_name": "instance_name_2",
+ "content": "content_2"
+ },
+ {
+ "instance_name": "instance_name_2",
+ "content": "content_2"
+ }
+ ]
+ }
+}
diff --git a/hsml/python/tests/fixtures/resources_fixtures.json b/hsml/python/tests/fixtures/resources_fixtures.json
new file mode 100644
index 000000000..874daf0bf
--- /dev/null
+++ b/hsml/python/tests/fixtures/resources_fixtures.json
@@ -0,0 +1,155 @@
+{
+ "get_only_cores": {
+ "response": {
+ "cores": 0.2
+ }
+ },
+ "get_only_memory": {
+ "response": {
+ "memory": 16
+ }
+ },
+ "get_only_gpus": {
+ "response": {
+ "gpus": 1
+ }
+ },
+
+ "get_cores_and_memory": {
+ "response": {
+ "cores": 0.2,
+ "memory": 16
+ }
+ },
+
+ "get_cores_and_gpus": {
+ "response": {
+ "cores": 0.2,
+ "gpus": 1
+ }
+ },
+
+ "get_memory_and_gpus": {
+ "response": {
+ "memory": 16,
+ "gpus": 1
+ }
+ },
+
+ "get_cores_memory_and_gpus": {
+ "response": {
+ "cores": 0.2,
+ "memory": 16,
+ "gpus": 1
+ }
+ },
+ "get_component_resources_num_instances": {
+ "response": {
+ "num_instances": 1
+ }
+ },
+ "get_component_resources_requests": {
+ "response": {
+ "requests": {
+ "cores": 0.2,
+ "memory": 16,
+ "gpus": 1
+ }
+ }
+ },
+ "get_component_resources_limits": {
+ "response": {
+ "limits": {
+ "cores": 0.3,
+ "memory": 17,
+ "gpus": 2
+ }
+ }
+ },
+ "get_component_resources_num_instances_and_requests": {
+ "response": {
+ "num_instances": 1,
+ "requests": {
+ "cores": 0.2,
+ "memory": 16,
+ "gpus": 1
+ }
+ }
+ },
+ "get_component_resources_num_instances_and_limits": {
+ "response": {
+ "num_instances": 1,
+ "limits": {
+ "cores": 0.3,
+ "memory": 17,
+ "gpus": 2
+ }
+ }
+ },
+ "get_component_resources_num_instances_requests_and_limits": {
+ "response": {
+ "num_instances": 1,
+ "requests": {
+ "cores": 0.2,
+ "memory": 16,
+ "gpus": 1
+ },
+ "limits": {
+ "cores": 0.3,
+ "memory": 17,
+ "gpus": 2
+ }
+ }
+ },
+ "get_component_resources_requested_instances_and_predictor_resources": {
+ "response": {
+ "requested_instances": 1,
+ "predictor_resources": {
+ "requests": {
+ "cores": 0.2,
+ "memory": 16,
+ "gpus": 1
+ },
+ "limits": {
+ "cores": 0.3,
+ "memory": 17,
+ "gpus": 2
+ }
+ }
+ }
+ },
+ "get_component_resources_requested_instances_and_predictor_resources_alternative": {
+ "response": {
+ "num_instances": 1,
+ "resources": {
+ "requests": {
+ "cores": 0.2,
+ "memory": 16,
+ "gpus": 1
+ },
+ "limits": {
+ "cores": 0.3,
+ "memory": 17,
+ "gpus": 2
+ }
+ }
+ }
+ },
+ "get_component_resources_requested_instances_and_transformer_resources": {
+ "response": {
+ "requested_transformer_instances": 1,
+ "transformer_resources": {
+ "requests": {
+ "cores": 0.2,
+ "memory": 16,
+ "gpus": 1
+ },
+ "limits": {
+ "cores": 0.3,
+ "memory": 17,
+ "gpus": 2
+ }
+ }
+ }
+ }
+}
diff --git a/hsml/python/tests/fixtures/tag_fixtures.json b/hsml/python/tests/fixtures/tag_fixtures.json
new file mode 100644
index 000000000..a2516562f
--- /dev/null
+++ b/hsml/python/tests/fixtures/tag_fixtures.json
@@ -0,0 +1,25 @@
+{
+ "get": {
+ "response": {
+ "count": 1,
+ "items": [
+ {
+ "name": "test_name",
+ "value": "test_value",
+ "schema": "test_schema",
+ "href": "test_href",
+ "expand": "test_expand",
+ "items": [],
+ "count": 0,
+ "type": "tagDTO"
+ }
+ ]
+ }
+ },
+ "get_empty": {
+ "response": {
+ "count": 0,
+ "items": []
+ }
+ }
+}
diff --git a/hsml/python/tests/fixtures/transformer_fixtures.json b/hsml/python/tests/fixtures/transformer_fixtures.json
new file mode 100644
index 000000000..269e525de
--- /dev/null
+++ b/hsml/python/tests/fixtures/transformer_fixtures.json
@@ -0,0 +1,63 @@
+{
+ "get_deployment_with_transformer": {
+ "response": {
+ "name": "test",
+ "transformer": "transformer_file_name",
+ "transformer_resources": {
+ "requested_transformer_instances": 1,
+ "requests": {
+ "cores": 0.2,
+ "memory": 16,
+ "gpus": 1
+ },
+ "limits": {
+ "cores": 0.3,
+ "memory": 17,
+ "gpus": 2
+ }
+ }
+ }
+ },
+ "get_deployment_without_transformer": {
+ "response": {
+ "name": "test",
+ "predictor": "predictor_file_name",
+ "predictor_resources": {
+ "num_instances": 1,
+ "requests": {
+ "cores": 0.2,
+ "memory": 16,
+ "gpus": 1
+ },
+ "limits": {
+ "cores": 0.3,
+ "memory": 17,
+ "gpus": 2
+ }
+ }
+ }
+ },
+ "get_transformer_with_resources": {
+ "response": {
+ "script_file": "transformer_file_name",
+ "resources": {
+ "num_instances": 1,
+ "requests": {
+ "cores": 0.2,
+ "memory": 16,
+ "gpus": 1
+ },
+ "limits": {
+ "cores": 0.3,
+ "memory": 17,
+ "gpus": 2
+ }
+ }
+ }
+ },
+ "get_transformer_without_resources": {
+ "response": {
+ "script_file": "transformer_file_name"
+ }
+ }
+}
diff --git a/hsml/python/tests/test_connection.py b/hsml/python/tests/test_connection.py
new file mode 100644
index 000000000..c8d100279
--- /dev/null
+++ b/hsml/python/tests/test_connection.py
@@ -0,0 +1,173 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from hsml.connection import (
+ CONNECTION_SAAS_HOSTNAME,
+ HOPSWORKS_PORT_DEFAULT,
+ HOSTNAME_VERIFICATION_DEFAULT,
+ Connection,
+)
+from hsml.core import model_api, model_registry_api, model_serving_api
+
+
+class TestConnection:
+ # constants
+
+ def test_constants(self):
+ # The purpose of this test is to ensure that (1) we don't make undesired changes to constant values
+ # that might break things somewhere else, and (2) we remember to update the tests accordingly by
+ # adding / removing / updating them, if necessary.
+ assert CONNECTION_SAAS_HOSTNAME == "c.app.hopsworks.ai"
+ assert HOPSWORKS_PORT_DEFAULT == 443
+ assert HOSTNAME_VERIFICATION_DEFAULT
+
+ # constructor
+
+ def test_constructor_default(self, mocker):
+ # Arrange
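+ # A bare mock object stands in for the Connection instance so the real __init__ can run,
+ # while the patched connect() prevents any actual connection attempt.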
+ class MockConnection:
+ pass
+
+ mock_connection = MockConnection()
+ mock_connection.connect = mocker.MagicMock()
+ mock_connection.init = Connection.__init__
+ mock_model_api_init = mocker.patch(
+ "hsml.core.model_api.ModelApi.__init__", return_value=None
+ )
+ mock_model_registry_api = mocker.patch(
+ "hsml.core.model_registry_api.ModelRegistryApi.__init__", return_value=None
+ )
+ mock_model_serving_api = mocker.patch(
+ "hsml.core.model_serving_api.ModelServingApi.__init__", return_value=None
+ )
+
+ # Act
+ mock_connection.init(mock_connection)
+
+ # Assert
+ assert mock_connection._host is None
+ assert mock_connection._port == HOPSWORKS_PORT_DEFAULT
+ assert mock_connection._project is None
+ assert mock_connection._hostname_verification == HOSTNAME_VERIFICATION_DEFAULT
+ assert mock_connection._trust_store_path is None
+ assert mock_connection._api_key_file is None
+ assert mock_connection._api_key_value is None
+ assert isinstance(mock_connection._model_api, model_api.ModelApi)
+ assert isinstance(
+ mock_connection._model_registry_api, model_registry_api.ModelRegistryApi
+ )
+ assert isinstance(
+ mock_connection._model_serving_api, model_serving_api.ModelServingApi
+ )
+ assert not mock_connection._connected
+ mock_model_api_init.assert_called_once()
+ mock_model_registry_api.assert_called_once()
+ mock_model_serving_api.assert_called_once()
+ mock_connection.connect.assert_called_once()
+
+ def test_constructor(self, mocker):
+ # Arrange
+ class MockConnection:
+ pass
+
+ mock_connection = MockConnection()
+ mock_connection.connect = mocker.MagicMock()
+ mock_connection.init = Connection.__init__
+ mock_model_api_init = mocker.patch(
+ "hsml.core.model_api.ModelApi.__init__", return_value=None
+ )
+ mock_model_registry_api = mocker.patch(
+ "hsml.core.model_registry_api.ModelRegistryApi.__init__", return_value=None
+ )
+ mock_model_serving_api = mocker.patch(
+ "hsml.core.model_serving_api.ModelServingApi.__init__", return_value=None
+ )
+
+ # Act
+ mock_connection.init(
+ mock_connection,
+ host="host",
+ port=1234,
+ project="project",
+ hostname_verification=False,
+ trust_store_path="ts_path",
+ api_key_file="ak_file",
+ api_key_value="ak_value",
+ )
+
+ # Assert
+ assert mock_connection._host == "host"
+ assert mock_connection._port == 1234
+ assert mock_connection._project == "project"
+ assert not mock_connection._hostname_verification
+ assert mock_connection._trust_store_path == "ts_path"
+ assert mock_connection._api_key_file == "ak_file"
+ assert mock_connection._api_key_value == "ak_value"
+ assert isinstance(mock_connection._model_api, model_api.ModelApi)
+ assert isinstance(
+ mock_connection._model_registry_api, model_registry_api.ModelRegistryApi
+ )
+ assert isinstance(
+ mock_connection._model_serving_api, model_serving_api.ModelServingApi
+ )
+ assert not mock_connection._connected
+ mock_model_api_init.assert_called_once()
+ mock_model_registry_api.assert_called_once()
+ mock_model_serving_api.assert_called_once()
+ mock_connection.connect.assert_called_once()
+
+ # handlers
+
+ def test_get_model_registry(self, mocker):
+ # Arrange
+ mock_connection = mocker.MagicMock()
+ mock_connection.get_model_registry = Connection.get_model_registry
+ mock_connection._model_registry_api = mocker.MagicMock()
+ mock_connection._model_registry_api.get = mocker.MagicMock(return_value="mr")
+
+ # Act
+ mr = mock_connection.get_model_registry(mock_connection)
+
+ # Assert
+ assert mr == "mr"
+ mock_connection._model_registry_api.get.assert_called_once()
+
+ def test_get_model_serving(self, mocker):
+ # Arrange
+ mock_connection = mocker.MagicMock()
+ mock_connection.get_model_serving = Connection.get_model_serving
+ mock_connection._model_serving_api = mocker.MagicMock()
+ mock_connection._model_serving_api.get = mocker.MagicMock(return_value="ms")
+
+ # Act
+ ms = mock_connection.get_model_serving(mock_connection)
+
+ # Assert
+ assert ms == "ms"
+ mock_connection._model_serving_api.get.assert_called_once()
+
+ # connection
+
+ # TODO: Add tests for connection-related methods
+
+ def test_connect(self, mocker):
+ pass
+
+ def test_close(self, mocker):
+ pass
+
+ def test_connection(self, mocker):
+ pass
diff --git a/hsml/python/tests/test_constants.py b/hsml/python/tests/test_constants.py
new file mode 100644
index 000000000..8c2b21695
--- /dev/null
+++ b/hsml/python/tests/test_constants.py
@@ -0,0 +1,383 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+
+from hsml import constants
+
+
+class TestConstants:
+ # NOTE
+ # This class contains validations for constants and enum values.
+ # The purpose of this class is to ensure that (1) we don't make undesired changes to constant values
+ # that might break things somewhere else, and (2) we remember to update the tests accordingly by
+ # adding / removing / updating them.
+
+ # This class includes the following validations:
+ # - Number of possible values of an Enum (to check for added/removed values)
+ # - Exact values of constants (to check for modified values)
+
+ # MODEL
+
+ def test_model_framework_constants(self):
+ # Arrange
+ model_frameworks = {
+ "FRAMEWORK_TENSORFLOW": "TENSORFLOW",
+ "FRAMEWORK_TORCH": "TORCH",
+ "FRAMEWORK_PYTHON": "PYTHON",
+ "FRAMEWORK_SKLEARN": "SKLEARN",
+ }
+
+ # Assert
+ self._check_added_modified_or_removed_values(
+ constants.MODEL,
+ num_values=len(model_frameworks),
+ expected_constants=model_frameworks,
+ prefix="FRAMEWORK",
+ )
+
+ # MODEL_REGISTRY
+
+ def test_model_registry_constants(self):
+ # Arrange
+ hopsfs_mount_prefix = {"HOPSFS_MOUNT_PREFIX": "/home/yarnapp/hopsfs/"}
+
+ # Assert
+ self._check_added_modified_or_removed_values(
+ constants.MODEL_REGISTRY,
+ num_values=len(hopsfs_mount_prefix),
+ expected_constants=hopsfs_mount_prefix,
+ )
+
+ # MODEL_SERVING
+
+ def test_model_serving_constants(self):
+ # Arrange
+ models_dataset = {"MODELS_DATASET": "Models"}
+
+ # Assert
+ self._check_added_modified_or_removed_values(
+ constants.MODEL_SERVING,
+ num_values=len(models_dataset),
+ expected_constants=models_dataset,
+ )
+
+ # ARTIFACT_VERSION
+
+ def test_artifact_version_constants(self):
+ # Arrange
+ artifact_versions = {"CREATE": "CREATE"}
+
+ # Assert
+ self._check_added_modified_or_removed_values(
+ constants.ARTIFACT_VERSION,
+ num_values=len(artifact_versions),
+ expected_constants=artifact_versions,
+ )
+
+ # RESOURCES
+
+ def test_resources_min_constants(self):
+ # Arrange
+ min_resources = {
+ "MIN_NUM_INSTANCES": 1,
+ "MIN_CORES": 0.2,
+ "MIN_MEMORY": 32,
+ "MIN_GPUS": 0,
+ }
+
+ # Assert
+ self._check_added_modified_or_removed_values(
+ constants.RESOURCES,
+ num_values=len(min_resources),
+ expected_constants=min_resources,
+ prefix="MIN",
+ )
+
+ def test_resources_max_constants(self):
+ # Arrange
+ max_resources = {
+ "MAX_CORES": 2,
+ "MAX_MEMORY": 1024,
+ "MAX_GPUS": 0,
+ }
+
+ # Assert
+ self._check_added_modified_or_removed_values(
+ constants.RESOURCES,
+ num_values=len(max_resources),
+ expected_constants=max_resources,
+ prefix="MAX",
+ )
+
+ # KAFKA_TOPIC
+
+ def test_kafka_topic_names_constants(self):
+ # Arrange
+ kafka_topic_cons = {
+ "NONE": "NONE",
+ "CREATE": "CREATE",
+ "NUM_REPLICAS": 1,
+ "NUM_PARTITIONS": 1,
+ }
+
+ # Assert
+ self._check_added_modified_or_removed_values(
+ constants.KAFKA_TOPIC,
+ num_values=len(kafka_topic_cons),
+ expected_constants=kafka_topic_cons,
+ )
+
+ # INFERENCE_LOGGER
+
+ def test_inference_logger_constants(self):
+ # Arrange
+ if_modes = {
+ "MODE_NONE": "NONE",
+ "MODE_ALL": "ALL",
+ "MODE_MODEL_INPUTS": "MODEL_INPUTS",
+ "MODE_PREDICTIONS": "PREDICTIONS",
+ }
+
+ # Assert
+ self._check_added_modified_or_removed_values(
+ constants.INFERENCE_LOGGER,
+ num_values=len(if_modes),
+ expected_constants=if_modes,
+ prefix="MODE",
+ )
+
+ # INFERENCE_BATCHER
+
+ def test_inference_batcher_constants(self):
+ # Arrange
+ if_batcher = {"ENABLED": False}
+
+ # Assert
+ self._check_added_modified_or_removed_values(
+ constants.INFERENCE_BATCHER,
+ num_values=len(if_batcher),
+ expected_constants=if_batcher,
+ )
+
+ # DEPLOYMENT
+
+ def test_deployment_constants(self):
+ # Arrange
+ depl_actions = {"ACTION_START": "START", "ACTION_STOP": "STOP"}
+
+ # Assert
+ self._check_added_modified_or_removed_values(
+ constants.DEPLOYMENT,
+ num_values=len(depl_actions),
+ expected_constants=depl_actions,
+ prefix="ACTION",
+ )
+
+ # PREDICTOR
+
+ def test_predictor_model_server_constants(self):
+ # Arrange
+ model_servers = {
+ "MODEL_SERVER_PYTHON": "PYTHON",
+ "MODEL_SERVER_TF_SERVING": "TENSORFLOW_SERVING",
+ }
+
+ # Assert
+ self._check_added_modified_or_removed_values(
+ constants.PREDICTOR,
+ num_values=len(model_servers),
+ expected_constants=model_servers,
+ prefix="MODEL_SERVER",
+ )
+
+ def test_predictor_serving_tool_constants(self):
+ # Arrange
+ serving_tools = {
+ "SERVING_TOOL_DEFAULT": "DEFAULT",
+ "SERVING_TOOL_KSERVE": "KSERVE",
+ }
+
+ # Assert
+ self._check_added_modified_or_removed_values(
+ constants.PREDICTOR,
+ num_values=len(serving_tools),
+ expected_constants=serving_tools,
+ prefix="SERVING_TOOL",
+ )
+
+ # PREDICTOR_STATE
+
+ def test_predictor_state_status_constants(self):
+ # Arrange
+ predictor_states = {
+ "STATUS_CREATING": "Creating",
+ "STATUS_CREATED": "Created",
+ "STATUS_STARTING": "Starting",
+ "STATUS_FAILED": "Failed",
+ "STATUS_RUNNING": "Running",
+ "STATUS_IDLE": "Idle",
+ "STATUS_UPDATING": "Updating",
+ "STATUS_STOPPING": "Stopping",
+ "STATUS_STOPPED": "Stopped",
+ }
+
+ # Assert
+ self._check_added_modified_or_removed_values(
+ constants.PREDICTOR_STATE,
+ num_values=len(predictor_states),
+ expected_constants=predictor_states,
+ prefix="STATUS",
+ )
+
+ def test_predictor_state_condition_constants(self):
+ # Arrange
+ predictor_states = {
+ "CONDITION_TYPE_STOPPED": "STOPPED",
+ "CONDITION_TYPE_SCHEDULED": "SCHEDULED",
+ "CONDITION_TYPE_INITIALIZED": "INITIALIZED",
+ "CONDITION_TYPE_STARTED": "STARTED",
+ "CONDITION_TYPE_READY": "READY",
+ }
+
+ # Assert
+ self._check_added_modified_or_removed_values(
+ constants.PREDICTOR_STATE,
+ num_values=len(predictor_states),
+ expected_constants=predictor_states,
+ prefix="CONDITION",
+ )
+
+ # INFERENCE_ENDPOINTS
+
+ def test_inference_endpoints_type_constants(self):
+ # Arrange
+ ie_types = {
+ "ENDPOINT_TYPE_NODE": "NODE",
+ "ENDPOINT_TYPE_KUBE_CLUSTER": "KUBE_CLUSTER",
+ "ENDPOINT_TYPE_LOAD_BALANCER": "LOAD_BALANCER",
+ }
+
+ # Assert
+ self._check_added_modified_or_removed_values(
+ constants.INFERENCE_ENDPOINTS,
+ num_values=len(ie_types),
+ expected_constants=ie_types,
+ prefix="ENDPOINT_TYPE",
+ )
+
+ def test_inference_endpoints_port_constants(self):
+ # Arrange
+ ie_ports = {
+ "PORT_NAME_HTTP": "HTTP",
+ "PORT_NAME_HTTPS": "HTTPS",
+ "PORT_NAME_STATUS_PORT": "STATUS",
+ "PORT_NAME_TLS": "TLS",
+ }
+
+ # Assert
+ self._check_added_modified_or_removed_values(
+ constants.INFERENCE_ENDPOINTS,
+ num_values=len(ie_ports),
+ expected_constants=ie_ports,
+ prefix="PORT_NAME",
+ )
+
+ def test_inference_endpoints_api_protocol_constants(self):
+ # Arrange
+ ie_api_protocols = {
+ "API_PROTOCOL_REST": "REST",
+ "API_PROTOCOL_GRPC": "GRPC",
+ }
+
+ # Assert
+ self._check_added_modified_or_removed_values(
+ constants.INFERENCE_ENDPOINTS,
+ num_values=len(ie_api_protocols),
+ expected_constants=ie_api_protocols,
+ prefix="API_PROTOCOL",
+ )
+
+ # DEPLOYABLE_COMPONENT
+
+ def test_inference_endpoints_deployable_component_constants(self):
+ # Arrange
+ depl_components = {
+ "PREDICTOR": "predictor",
+ "TRANSFORMER": "transformer",
+ }
+
+ # Assert
+ self._check_added_modified_or_removed_values(
+ constants.DEPLOYABLE_COMPONENT,
+ num_values=len(depl_components),
+ expected_constants=depl_components,
+ )
+
+ # Auxiliary methods
+
+ def _check_added_modified_or_removed_values(
+ self, cls, num_values, expected_constants=None, prefix=None
+ ):
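+ # Compares the constants actually defined on `cls` (optionally filtered by `prefix`)
+ # against the expected name/value mapping, or just their count if no mapping is given.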
+ cname = cls.__name__ + ("." + prefix if prefix is not None else "")
+ const_dict = self._get_constants_name_value_dict(cls, prefix=prefix)
+ # exact constants
+ if expected_constants is not None:
+ # constant names
+ added_cnames = const_dict.keys() - expected_constants.keys()
+ removed_cnames = expected_constants.keys() - const_dict.keys()
+
+ assert len(added_cnames) == 0, (
+ f"One or more constants were added under {cname} with names {added_cnames}. "
+ + "If it was intentional, please add/remove/update tests accordingly (not only in this file, "
+ + "but wherever it corresponds)."
+ )
+
+ assert len(removed_cnames) == 0, (
+ f"One or more constants were removed under {cname} with names {removed_cnames}. "
+ + "If it was intentional, please add/remove/update tests accordingly (not only in this file, "
+ + "but wherever it corresponds)."
+ )
+
+ assert const_dict.keys() == expected_constants.keys(), (
+ f"One or more constants under {cname} were modified from {removed_cnames} to {added_cnames}. "
+ + "If it was intentional, please add/remove/update tests accordingly (not only in this file, "
+ + "but wherever it corresponds)."
+ )
+
+ # constant values
+ for cname, cvalue in expected_constants.items():
+ full_cname = f"{cls.__name__}.{cname}"
+ assert cvalue == const_dict[cname], (
+ f"The constant {full_cname} was modified from {cvalue} to {const_dict[cname]}. "
+ + "If it was intentional, please add/remove/update tests accordingly (not only in this file, "
+ + "but wherever it corresponds)."
+ )
+ else:
+ # number of values
+ assert len(const_dict) == num_values, (
+ f"A constant was added/removed under {cname}. If it was intentional, please "
+ + "add/remove/update tests accordingly (not only in this file, but wherever it corresponds)."
+ )
+
+ def _get_constants_name_value_dict(self, cls, prefix=None) -> dict:
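+ # Returns a dict of attribute name -> value for the non-routine members of `cls`,
+ # filtered by `prefix` if given, otherwise excluding dunder attributes.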
+ const_dict = dict()
+ for m in inspect.getmembers(cls, lambda m: not (inspect.isroutine(m))):
+ n = m[0] # name
+ if (prefix is not None and n.startswith(prefix)) or (
+ prefix is None and not (n.startswith("__") and n.endswith("__"))
+ ):
+ const_dict[n] = m[1] # value
+ return const_dict
diff --git a/hsml/python/tests/test_decorators.py b/hsml/python/tests/test_decorators.py
new file mode 100644
index 000000000..7d17e18ea
--- /dev/null
+++ b/hsml/python/tests/test_decorators.py
@@ -0,0 +1,82 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from hsml.decorators import (
+ HopsworksConnectionError,
+ NoHopsworksConnectionError,
+ connected,
+ not_connected,
+)
+
+
+class TestDecorators:
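+ # `not_connected` raises HopsworksConnectionError when the instance is already connected;
+ # `connected` raises NoHopsworksConnectionError when it is not connected.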
+ # test not connected
+
+ def test_not_connected_valid(self, mocker):
+ # Arrange
+ mock_instance = mocker.MagicMock()
+ mock_instance._connected = False
+
+ @not_connected
+ def assert_not_connected(inst, arg, key_arg):
+ assert not inst._connected
+ assert arg == "arg"
+ assert key_arg == "key_arg"
+
+ # Act
+ assert_not_connected(mock_instance, "arg", key_arg="key_arg")
+
+ def test_not_connected_invalid(self, mocker):
+ # Arrange
+ mock_instance = mocker.MagicMock()
+ mock_instance._connected = True
+
+ @not_connected
+ def assert_not_connected(inst, arg, key_arg):
+ pass
+
+ # Act
+ with pytest.raises(HopsworksConnectionError):
+ assert_not_connected(mock_instance, "arg", key_arg="key_arg")
+
+ # test connected
+
+ def test_connected_valid(self, mocker):
+ # Arrange
+ mock_instance = mocker.MagicMock()
+ mock_instance._connected = True
+
+ @connected
+ def assert_connected(inst, arg, key_arg):
+ assert inst._connected
+ assert arg == "arg"
+ assert key_arg == "key_arg"
+
+ # Act
+ assert_connected(mock_instance, "arg", key_arg="key_arg")
+
+ def test_connected_invalid(self, mocker):
+ # Arrange
+ mock_instance = mocker.MagicMock()
+ mock_instance._connected = False
+
+ @connected
+ def assert_connected(inst, arg, key_arg):
+ pass
+
+ # Act
+ with pytest.raises(NoHopsworksConnectionError):
+ assert_connected(mock_instance, "arg", key_arg="key_arg")
diff --git a/hsml/python/tests/test_deployable_component.py b/hsml/python/tests/test_deployable_component.py
new file mode 100644
index 000000000..97ec67018
--- /dev/null
+++ b/hsml/python/tests/test_deployable_component.py
@@ -0,0 +1,106 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from hsml import deployable_component, inference_batcher
+
+
+class TestDeployableComponent:
+ # from response json
+
+ def test_from_response_json(self, mocker):
+ # Arrange
+ json = {"test": "test"}
+ mock_from_json = mocker.patch(
+ "hsml.deployable_component.DeployableComponent.from_json",
+ return_value="from_json_result",
+ )
+
+ # Act
+ result = deployable_component.DeployableComponent.from_response_json(json)
+
+ # Assert
+ assert result == "from_json_result"
+ mock_from_json.assert_called_once_with(json)
+
+ # constructor
+
+ def test_constructor_default(self, mocker):
+ # Arrange
+ mock_get_obj_from_json = mocker.patch(
+ "hsml.util.get_obj_from_json", return_value=None
+ )
+ mock_ib_init = mocker.patch(
+ "hsml.inference_batcher.InferenceBatcher.__init__", return_value=None
+ )
+
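+ # Minimal concrete subclass stubbing the abstract methods so the base constructor can run.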
+ class DeployableComponentChild(deployable_component.DeployableComponent):
+ def from_json():
+ pass
+
+ def update_from_response_json():
+ pass
+
+ def to_dict():
+ pass
+
+ # Act
+ dc = DeployableComponentChild()
+
+ # Assert
+ assert dc.script_file is None
+ assert dc.resources is None
+ mock_get_obj_from_json.assert_called_once_with(
+ None, inference_batcher.InferenceBatcher
+ )
+ mock_ib_init.assert_called_once()
+
+ def test_constructor_with_params(self, mocker):
+ # Arrange
+ script_file = "script_file"
+ resources = {}
+ inf_batcher = inference_batcher.InferenceBatcher()
+ mock_get_obj_from_json = mocker.patch(
+ "hsml.util.get_obj_from_json", return_value=inf_batcher
+ )
+ mock_ib_init = mocker.patch(
+ "hsml.inference_batcher.InferenceBatcher.__init__", return_value=None
+ )
+
+ class DeployableComponentChild(deployable_component.DeployableComponent):
+ def from_json():
+ pass
+
+ def update_from_response_json():
+ pass
+
+ def to_dict():
+ pass
+
+ # Act
+ dc = DeployableComponentChild(
+ script_file=script_file,
+ resources=resources,
+ inference_batcher=inf_batcher,
+ )
+
+ # Assert
+ assert dc.script_file == script_file
+ assert dc.resources == resources
+ mock_get_obj_from_json.assert_called_once_with(
+ inf_batcher, inference_batcher.InferenceBatcher
+ )
+ assert dc.inference_batcher == inf_batcher
+ mock_ib_init.assert_not_called()
diff --git a/hsml/python/tests/test_deployable_component_logs.py b/hsml/python/tests/test_deployable_component_logs.py
new file mode 100644
index 000000000..3c61aabb0
--- /dev/null
+++ b/hsml/python/tests/test_deployable_component_logs.py
@@ -0,0 +1,110 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import datetime
+
+import humps
+from hsml import deployable_component_logs
+
+
+class TestDeployableComponentLogs:
+ # from response json
+
+ def test_from_response_json(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["predictor"]["get_deployment_component_logs_single"][
+ "response"
+ ]
+ json_camelized = humps.camelize(json)
+ mocker_from_json = mocker.patch(
+ "hsml.deployable_component_logs.DeployableComponentLogs.from_json",
+ return_value=None,
+ )
+
+ # Act
+ dc_logs = deployable_component_logs.DeployableComponentLogs.from_response_json(
+ json_camelized
+ )
+
+ # Assert
+ assert isinstance(dc_logs, list)
+ assert len(dc_logs) == 1
+ mocker_from_json.assert_called_once_with(json[0])
+
+ def test_from_response_json_list(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["predictor"]["get_deployment_component_logs_list"][
+ "response"
+ ]
+ json_camelized = humps.camelize(json)
+ mocker_from_json = mocker.patch(
+ "hsml.deployable_component_logs.DeployableComponentLogs.from_json",
+ return_value=None,
+ )
+
+ # Act
+ dc_logs = deployable_component_logs.DeployableComponentLogs.from_response_json(
+ json_camelized
+ )
+
+ # Assert
+ assert isinstance(dc_logs, list)
+ assert len(dc_logs) == len(json_camelized)
+ assert mocker_from_json.call_count == len(json_camelized)
+
+ def test_from_response_json_empty(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["predictor"]["get_deployment_component_logs_empty"][
+ "response"
+ ]
+ json_camelized = humps.camelize(json)
+ mocker_from_json = mocker.patch(
+ "hsml.deployable_component_logs.DeployableComponentLogs.from_json",
+ return_value=None,
+ )
+
+ # Act
+ dc_logs = deployable_component_logs.DeployableComponentLogs.from_response_json(
+ json_camelized
+ )
+
+ # Assert
+ assert isinstance(dc_logs, list)
+ assert len(dc_logs) == 0
+ mocker_from_json.assert_not_called()
+
+ # constructor
+
+ def test_constructor(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["predictor"]["get_deployment_component_logs_single"][
+ "response"
+ ]
+ instance_name = json[0]["instance_name"]
+ content = json[0]["content"]
+ now = datetime.datetime.now()
+
+ # Act
+ dcl = deployable_component_logs.DeployableComponentLogs(
+ instance_name=instance_name, content=content
+ )
+
+ # Assert
+ assert dcl.instance_name == instance_name
+ assert dcl.content == content
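+ # created_at is set at construction time, so it should fall within one second of `now`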
+ assert (dcl.created_at >= now) and (
+ dcl.created_at < (now + datetime.timedelta(seconds=1))
+ )
diff --git a/hsml/python/tests/test_deployment.py b/hsml/python/tests/test_deployment.py
new file mode 100644
index 000000000..7e3d7e4a5
--- /dev/null
+++ b/hsml/python/tests/test_deployment.py
@@ -0,0 +1,795 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import pytest
+from hsml import deployment, predictor
+from hsml.client.exceptions import ModelServingException
+from hsml.constants import PREDICTOR_STATE
+from hsml.core import serving_api
+from hsml.engine import serving_engine
+
+
+class TestDeployment:
+ # from response json
+
+ def test_from_response_json_list(self, mocker, backend_fixtures):
+ # Arrange
+ preds = [{"name": "pred_name"}]
+ mock_pred_from_response_json = mocker.patch(
+ "hsml.predictor.Predictor.from_response_json",
+ return_value=preds,
+ )
+ mock_from_predictor = mocker.patch(
+ "hsml.deployment.Deployment.from_predictor", return_value=preds[0]
+ )
+
+ # Act
+ depl = deployment.Deployment.from_response_json(preds)
+
+ # Assert
+ assert isinstance(depl, list)
+ assert depl[0] == preds[0]
+ mock_pred_from_response_json.assert_called_once_with(preds)
+ mock_from_predictor.assert_called_once_with(preds[0])
+
+ def test_from_response_json_single(self, mocker, backend_fixtures):
+ # Arrange
+ pred = {"name": "pred_name"}
+ mock_pred_from_response_json = mocker.patch(
+ "hsml.predictor.Predictor.from_response_json",
+ return_value=pred,
+ )
+ mock_from_predictor = mocker.patch(
+ "hsml.deployment.Deployment.from_predictor", return_value=pred
+ )
+
+ # Act
+ depl = deployment.Deployment.from_response_json(pred)
+
+ # Assert
+ assert depl == pred
+ mock_pred_from_response_json.assert_called_once_with(pred)
+ mock_from_predictor.assert_called_once_with(pred)
+
+ # constructor
+
+ def test_constructor_default(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+
+ # Act
+ d = deployment.Deployment(predictor=p)
+
+ # Assert
+ assert d.name == p.name
+ assert d.description == p.description
+ assert d.predictor == p
+ assert isinstance(d._serving_api, serving_api.ServingApi)
+ assert isinstance(d._serving_engine, serving_engine.ServingEngine)
+
+ def test_constructor(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+
+ # Act
+ d = deployment.Deployment(predictor=p, name=p.name, description=p.description)
+
+ # Assert
+ assert d.name == p.name
+ assert d.description == p.description
+ assert d.predictor == p
+ assert isinstance(d._serving_api, serving_api.ServingApi)
+ assert isinstance(d._serving_engine, serving_engine.ServingEngine)
+
+ def test_constructor_no_predictor(self):
+ # Act
+ with pytest.raises(ModelServingException) as e_info:
+ _ = deployment.Deployment(predictor=None)
+
+ # Assert
+ assert "A predictor is required" in str(e_info.value)
+
+ def test_constructor_wrong_predictor(self):
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ _ = deployment.Deployment(predictor={"wrong": "type"})
+
+ # Assert
+ assert "not an instance of the Predictor class" in str(e_info.value)
+
+ # from predictor
+
+ def test_from_predictor(self, mocker):
+ # Arrange
+ class MockPredictor:
+ _name = "name"
+ _description = "description"
+
+ p = MockPredictor()
+ mock_deployment_init = mocker.patch(
+ "hsml.deployment.Deployment.__init__", return_value=None
+ )
+
+ # Act
+ deployment.Deployment.from_predictor(p)
+
+ # Assert
+ mock_deployment_init.assert_called_once_with(
+ predictor=p, name=p._name, description=p._description
+ )
+
+ # save
+
+ def test_save_default(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+ mock_serving_engine_save = mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.save"
+ )
+
+ # Act
+ d.save()
+
+ # Assert
+ mock_serving_engine_save.assert_called_once_with(d, 60)
+
+ def test_save(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+ mock_serving_engine_save = mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.save"
+ )
+
+ # Act
+ await_update = 120
+ d.save(await_update=await_update)
+
+ # Assert
+ mock_serving_engine_save.assert_called_once_with(d, await_update)
+
+ # start
+
+ def test_start_default(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+ mock_serving_engine_start = mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.start"
+ )
+
+ # Act
+ d.start()
+
+ # Assert
+ mock_serving_engine_start.assert_called_once_with(d, await_status=60)
+
+ def test_start(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+ mock_serving_engine_start = mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.start"
+ )
+
+ # Act
+ await_running = 120
+ d.start(await_running=await_running)
+
+ # Assert
+ mock_serving_engine_start.assert_called_once_with(d, await_status=await_running)
+
+ # stop
+
+ def test_stop_default(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+ mock_serving_engine_stop = mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.stop"
+ )
+
+ # Act
+ d.stop()
+
+ # Assert
+ mock_serving_engine_stop.assert_called_once_with(d, await_status=60)
+
+ def test_stop(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+ mock_serving_engine_stop = mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.stop"
+ )
+
+ # Act
+ await_stopped = 120
+ d.stop(await_stopped=await_stopped)
+
+ # Assert
+ mock_serving_engine_stop.assert_called_once_with(d, await_status=await_stopped)
+
+ # delete
+
+ def test_delete_default(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+ mock_serving_engine_delete = mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.delete"
+ )
+
+ # Act
+ d.delete()
+
+ # Assert
+ mock_serving_engine_delete.assert_called_once_with(d, False)
+
+ def test_delete(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+ mock_serving_engine_delete = mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.delete"
+ )
+
+ # Act
+ force = True
+ d.delete(force=force)
+
+ # Assert
+ mock_serving_engine_delete.assert_called_once_with(d, force)
+
+ # get state
+
+ def test_get_state(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+ mock_serving_engine_get_state = mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.get_state"
+ )
+
+ # Act
+ d.get_state()
+
+ # Assert
+ mock_serving_engine_get_state.assert_called_once_with(d)
+
+ # status
+
+ # - is created
+
+ def test_is_created_false(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+
+ class MockPredictorState:
+ def __init__(self, status):
+ self.status = status
+
+ mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.get_state",
+ return_value=MockPredictorState(PREDICTOR_STATE.STATUS_CREATING),
+ )
+
+ # Act
+ is_created = d.is_created()
+
+ # Assert
+ assert not is_created
+
+ def test_is_created_true(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+
+ class MockPredictorState:
+ def __init__(self, status):
+ self.status = status
+
+ valid_statuses = [
+ PREDICTOR_STATE.STATUS_CREATED,
+ PREDICTOR_STATE.STATUS_FAILED,
+ PREDICTOR_STATE.STATUS_IDLE,
+ PREDICTOR_STATE.STATUS_RUNNING,
+ PREDICTOR_STATE.STATUS_STARTING,
+ PREDICTOR_STATE.STATUS_STOPPED,
+ PREDICTOR_STATE.STATUS_STOPPING,
+ PREDICTOR_STATE.STATUS_UPDATING,
+ ]
+
+ for valid_status in valid_statuses:
+ mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.get_state",
+ return_value=MockPredictorState(valid_status),
+ )
+
+ # Act and Assert
+ assert d.is_created()
+
+ # is running
+
+ def test_is_running_true(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+
+ class MockPredictorState:
+ def __init__(self, status):
+ self.status = status
+
+ mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.get_state",
+ return_value=MockPredictorState(PREDICTOR_STATE.STATUS_RUNNING),
+ )
+
+ # Act and Assert
+ assert d.is_running()
+
+ def test_is_running_false(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+
+ class MockPredictorState:
+ def __init__(self, status):
+ self.status = status
+
+ valid_statuses = [
+ PREDICTOR_STATE.STATUS_CREATED,
+ PREDICTOR_STATE.STATUS_CREATING,
+ PREDICTOR_STATE.STATUS_FAILED,
+ PREDICTOR_STATE.STATUS_IDLE,
+ PREDICTOR_STATE.STATUS_STARTING,
+ # PREDICTOR_STATE.STATUS_RUNNING,
+ PREDICTOR_STATE.STATUS_STOPPED,
+ PREDICTOR_STATE.STATUS_STOPPING,
+ PREDICTOR_STATE.STATUS_UPDATING,
+ ]
+
+ for valid_status in valid_statuses:
+ mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.get_state",
+ return_value=MockPredictorState(valid_status),
+ )
+
+ # Act and Assert
+ assert not d.is_running(or_idle=False, or_updating=False)
+
+ def test_is_running_or_idle_true(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+
+ class MockPredictorState:
+ def __init__(self, status):
+ self.status = status
+
+ valid_statuses = [
+ PREDICTOR_STATE.STATUS_IDLE,
+ PREDICTOR_STATE.STATUS_RUNNING,
+ ]
+
+ for valid_status in valid_statuses:
+ mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.get_state",
+ return_value=MockPredictorState(valid_status),
+ )
+
+ # Act and Assert
+ assert d.is_running(or_idle=True)
+
+ def test_is_running_or_idle_false(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+
+ class MockPredictorState:
+ def __init__(self, status):
+ self.status = status
+
+ valid_statuses = [
+ PREDICTOR_STATE.STATUS_CREATED,
+ PREDICTOR_STATE.STATUS_CREATING,
+ PREDICTOR_STATE.STATUS_FAILED,
+ # PREDICTOR_STATE.STATUS_IDLE,
+ PREDICTOR_STATE.STATUS_STARTING,
+ # PREDICTOR_STATE.STATUS_RUNNING,
+ PREDICTOR_STATE.STATUS_STOPPED,
+ PREDICTOR_STATE.STATUS_STOPPING,
+ PREDICTOR_STATE.STATUS_UPDATING,
+ ]
+
+ for valid_status in valid_statuses:
+ mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.get_state",
+ return_value=MockPredictorState(valid_status),
+ )
+
+ # Act and Assert
+ assert not d.is_running(or_idle=True, or_updating=False)
+
+ def test_is_running_or_updating_true(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+
+ class MockPredictorState:
+ def __init__(self, status):
+ self.status = status
+
+ valid_statuses = [
+ # PREDICTOR_STATE.STATUS_CREATED,
+ # PREDICTOR_STATE.STATUS_CREATING,
+ # PREDICTOR_STATE.STATUS_FAILED,
+ # PREDICTOR_STATE.STATUS_IDLE,
+ # PREDICTOR_STATE.STATUS_STARTING,
+ PREDICTOR_STATE.STATUS_RUNNING,
+ # PREDICTOR_STATE.STATUS_STOPPED,
+ # PREDICTOR_STATE.STATUS_STOPPING,
+ PREDICTOR_STATE.STATUS_UPDATING,
+ ]
+
+ for valid_status in valid_statuses:
+ mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.get_state",
+ return_value=MockPredictorState(valid_status),
+ )
+
+ # Act and Assert
+ assert d.is_running(or_updating=True)
+
+ def test_is_running_or_updating_false(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+
+ class MockPredictorState:
+ def __init__(self, status):
+ self.status = status
+
+ valid_statuses = [
+ PREDICTOR_STATE.STATUS_CREATED,
+ PREDICTOR_STATE.STATUS_CREATING,
+ PREDICTOR_STATE.STATUS_FAILED,
+ PREDICTOR_STATE.STATUS_IDLE,
+ PREDICTOR_STATE.STATUS_STARTING,
+ # PREDICTOR_STATE.STATUS_RUNNING,
+ PREDICTOR_STATE.STATUS_STOPPED,
+ PREDICTOR_STATE.STATUS_STOPPING,
+ # PREDICTOR_STATE.STATUS_UPDATING,
+ ]
+
+ for valid_status in valid_statuses:
+ mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.get_state",
+ return_value=MockPredictorState(valid_status),
+ )
+
+ # Act and Assert
+ assert not d.is_running(or_idle=False, or_updating=True)
+
+ # - is stopped
+
+ def test_is_stopped_true(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+
+ class MockPredictorState:
+ def __init__(self, status):
+ self.status = status
+
+ mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.get_state",
+ return_value=MockPredictorState(PREDICTOR_STATE.STATUS_STOPPED),
+ )
+
+ # Act and Assert
+ assert d.is_stopped()
+
+ def test_is_stopped_false(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+
+ class MockPredictorState:
+ def __init__(self, status):
+ self.status = status
+
+ valid_statuses = [
+ PREDICTOR_STATE.STATUS_CREATED,
+ PREDICTOR_STATE.STATUS_CREATING,
+ PREDICTOR_STATE.STATUS_FAILED,
+ PREDICTOR_STATE.STATUS_IDLE,
+ PREDICTOR_STATE.STATUS_STARTING,
+ PREDICTOR_STATE.STATUS_RUNNING,
+ # PREDICTOR_STATE.STATUS_STOPPED,
+ PREDICTOR_STATE.STATUS_STOPPING,
+ PREDICTOR_STATE.STATUS_UPDATING,
+ ]
+
+ for valid_status in valid_statuses:
+ mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.get_state",
+ return_value=MockPredictorState(valid_status),
+ )
+
+ # Act and Assert
+ assert not d.is_stopped(or_created=False)
+
+ def test_is_stopped_or_created_true(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+
+ class MockPredictorState:
+ def __init__(self, status):
+ self.status = status
+
+ valid_statuses = [
+ PREDICTOR_STATE.STATUS_CREATED,
+ PREDICTOR_STATE.STATUS_CREATING,
+ # PREDICTOR_STATE.STATUS_FAILED,
+ # PREDICTOR_STATE.STATUS_IDLE,
+ # PREDICTOR_STATE.STATUS_STARTING,
+ # PREDICTOR_STATE.STATUS_RUNNING,
+ PREDICTOR_STATE.STATUS_STOPPED,
+ # PREDICTOR_STATE.STATUS_STOPPING,
+ # PREDICTOR_STATE.STATUS_UPDATING,
+ ]
+
+ for valid_status in valid_statuses:
+ mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.get_state",
+ return_value=MockPredictorState(valid_status),
+ )
+
+ # Act and Assert
+ assert d.is_stopped(or_created=True)
+
+ def test_is_stopped_or_created_false(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+
+ class MockPredictorState:
+ def __init__(self, status):
+ self.status = status
+
+ valid_statuses = [
+ # PREDICTOR_STATE.STATUS_CREATED,
+ # PREDICTOR_STATE.STATUS_CREATING,
+ PREDICTOR_STATE.STATUS_FAILED,
+ PREDICTOR_STATE.STATUS_IDLE,
+ PREDICTOR_STATE.STATUS_STARTING,
+ PREDICTOR_STATE.STATUS_RUNNING,
+ # PREDICTOR_STATE.STATUS_STOPPED,
+ PREDICTOR_STATE.STATUS_STOPPING,
+ PREDICTOR_STATE.STATUS_UPDATING,
+ ]
+
+ for valid_status in valid_statuses:
+ mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.get_state",
+ return_value=MockPredictorState(valid_status),
+ )
+
+ # Act and Assert
+ assert not d.is_stopped(or_created=True)
+
+ # predict
+
+ def test_predict(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+ mock_serving_engine_predict = mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.predict"
+ )
+
+ # Act
+ d.predict("data", "inputs")
+
+ # Assert
+ mock_serving_engine_predict.assert_called_once_with(d, "data", "inputs")
+
+ # download artifact
+
+ def test_download_artifact(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+ mock_serving_engine_download_artifact = mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.download_artifact"
+ )
+
+ # Act
+ d.download_artifact()
+
+ # Assert
+ mock_serving_engine_download_artifact.assert_called_once_with(d)
+
+ # get logs
+
+ def test_get_logs_default(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+ mock_util_get_members = mocker.patch(
+ "hsml.util.get_members", return_value=["predictor"]
+ )
+ mock_print = mocker.patch("builtins.print")
+
+ class MockLogs:
+ instance_name = "instance_name"
+ content = "content"
+
+ mock_logs = [MockLogs()]
+ mock_serving_get_logs = mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.get_logs",
+ return_value=mock_logs,
+ )
+
+ # Act
+ d.get_logs()
+
+ # Assert
+ mock_util_get_members.assert_called_once()
+ mock_serving_get_logs.assert_called_once_with(d, "predictor", 10)
+ assert mock_print.call_count == len(mock_logs)
+
+ def test_get_logs_component_valid(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+ mock_util_get_members = mocker.patch(
+ "hsml.util.get_members", return_value=["valid"]
+ )
+ mock_print = mocker.patch("builtins.print")
+
+ class MockLogs:
+ instance_name = "instance_name"
+ content = "content"
+
+ mock_logs = [MockLogs()]
+ mock_serving_get_logs = mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.get_logs",
+ return_value=mock_logs,
+ )
+
+ # Act
+ d.get_logs(component="valid")
+
+ # Assert
+ mock_util_get_members.assert_called_once()
+ mock_serving_get_logs.assert_called_once_with(d, "valid", 10)
+ assert mock_print.call_count == len(mock_logs)
+
+ def test_get_logs_component_invalid(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ d.get_logs(component="invalid")
+
+ # Assert
+ assert "is not valid" in str(e_info.value)
+
+ def test_get_logs_tail(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+ mock_util_get_members = mocker.patch(
+ "hsml.util.get_members", return_value=["predictor"]
+ )
+ mock_print = mocker.patch("builtins.print")
+
+ class MockLogs:
+ instance_name = "instance_name"
+ content = "content"
+
+ mock_logs = [MockLogs()]
+ mock_serving_get_logs = mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.get_logs",
+ return_value=mock_logs,
+ )
+
+ # Act
+ d.get_logs(tail=40)
+
+ # Assert
+ mock_util_get_members.assert_called_once()
+ mock_serving_get_logs.assert_called_once_with(d, "predictor", 40)
+ assert mock_print.call_count == len(mock_logs)
+
+ def test_get_logs_no_logs(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+ mock_util_get_members = mocker.patch(
+ "hsml.util.get_members", return_value=["predictor"]
+ )
+ mock_print = mocker.patch("builtins.print")
+
+ mock_serving_get_logs = mocker.patch(
+ "hsml.engine.serving_engine.ServingEngine.get_logs",
+ return_value=None,
+ )
+
+ # Act
+ d.get_logs()
+
+ # Assert
+ mock_util_get_members.assert_called_once()
+ mock_serving_get_logs.assert_called_once_with(d, "predictor", 10)
+ assert mock_print.call_count == 0
+
+ # get url
+
+ def test_get_url(self, mocker, backend_fixtures):
+ # Arrange
+ p = self._get_dummy_predictor(mocker, backend_fixtures)
+ d = deployment.Deployment(predictor=p)
+
+ class MockClient:
+ _project_id = "project_id"
+
+ mock_client = MockClient()
+ path = "/p/" + str(mock_client._project_id) + "/deployments/" + str(d.id)
+
+ mock_util_get_hostname_replaced_url = mocker.patch(
+ "hsml.util.get_hostname_replaced_url", return_value="url"
+ )
+ mock_client_get_instance = mocker.patch(
+ "hsml.client.get_instance", return_value=mock_client
+ )
+
+ # Act
+ url = d.get_url()
+
+ # Assert
+ assert url == "url"
+ mock_util_get_hostname_replaced_url.assert_called_once_with(path)
+ mock_client_get_instance.assert_called_once()
+
+ # auxiliary methods
+
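+ # Builds a minimal Predictor from the deployments singleton fixture, with validation and JSON parsing mocked out.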
+ def _get_dummy_predictor(self, mocker, backend_fixtures):
+ p_json = backend_fixtures["predictor"]["get_deployments_singleton"]["response"][
+ "items"
+ ][0]
+ mocker.patch("hsml.predictor.Predictor._validate_serving_tool")
+ mocker.patch("hsml.predictor.Predictor._validate_resources")
+ mocker.patch("hsml.predictor.Predictor._validate_script_file")
+ mocker.patch("hsml.util.get_obj_from_json")
+ return predictor.Predictor(
+ id=p_json["id"],
+ name=p_json["name"],
+ description=p_json["description"],
+ model_name=p_json["model_name"],
+ model_path=p_json["model_path"],
+ model_version=p_json["model_version"],
+ model_framework=p_json["model_framework"],
+ model_server=p_json["model_server"],
+ artifact_version=p_json["artifact_version"],
+ )
diff --git a/hsml/python/tests/test_explicit_provenance.py b/hsml/python/tests/test_explicit_provenance.py
new file mode 100644
index 000000000..eb396b364
--- /dev/null
+++ b/hsml/python/tests/test_explicit_provenance.py
@@ -0,0 +1,78 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+from unittest import mock
+
+from hsml.core import explicit_provenance
+
+
+class TestExplicitProvenance(unittest.TestCase):
+ def test_one_accessible_parent(self):
+ artifact = {"id": 1}
+ links = explicit_provenance.Links(accessible=[artifact])
+ parent = explicit_provenance.Links.get_one_accessible_parent(links)
+ self.assertEqual(artifact["id"], parent["id"])
+
+ def test_one_accessible_parent_none(self):
+ links = explicit_provenance.Links()
+ with mock.patch.object(explicit_provenance._logger, "info") as mock_logger:
+ parent = explicit_provenance.Links.get_one_accessible_parent(links)
+ mock_logger.assert_called_once_with("There is no parent information")
+ self.assertIsNone(parent)
+
+ def test_one_accessible_parent_inaccessible(self):
+ artifact = {"id": 1}
+ links = explicit_provenance.Links(inaccessible=[artifact])
+ with mock.patch.object(explicit_provenance._logger, "info") as mock_logger:
+ parent = explicit_provenance.Links.get_one_accessible_parent(links)
+ mock_logger.assert_called_once_with(
+ "The parent is deleted or inaccessible. For more details get the full provenance from `_provenance` method"
+ )
+ self.assertIsNone(parent)
+
+ def test_one_accessible_parent_deleted(self):
+ artifact = {"id": 1}
+ links = explicit_provenance.Links(deleted=[artifact])
+ with mock.patch.object(explicit_provenance._logger, "info") as mock_logger:
+ parent = explicit_provenance.Links.get_one_accessible_parent(links)
+ mock_logger.assert_called_once_with(
+ "The parent is deleted or inaccessible. For more details get the full provenance from `_provenance` method"
+ )
+ self.assertIsNone(parent)
+
+ def test_one_accessible_parent_too_many(self):
+ artifact1 = {"id": 1}
+ artifact2 = {"id": 2}
+ links = explicit_provenance.Links(accessible=[artifact1, artifact2])
+ with self.assertRaises(Exception) as context:
+ explicit_provenance.Links.get_one_accessible_parent(links)
+ self.assertTrue(
+ "Backend inconsistency - provenance returned more than one parent"
+ in str(context.exception)
+ )
+
+ def test_one_accessible_parent_should_not_be_artifact(self):
+ artifact = explicit_provenance.Artifact(
+ 1, "test", 1, None, explicit_provenance.Artifact.MetaType.NOT_SUPPORTED
+ )
+ links = explicit_provenance.Links(accessible=[artifact])
+ with self.assertRaises(Exception) as context:
+ explicit_provenance.Links.get_one_accessible_parent(links)
+ self.assertTrue(
+ "The returned object is not a valid object. For more details get the full provenance from `_provenance` method"
+ in str(context.exception)
+ )
diff --git a/hsml/python/tests/test_inference_batcher.py b/hsml/python/tests/test_inference_batcher.py
new file mode 100644
index 000000000..441fbff7e
--- /dev/null
+++ b/hsml/python/tests/test_inference_batcher.py
@@ -0,0 +1,234 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import copy
+
+import humps
+from hsml import inference_batcher
+from hsml.constants import INFERENCE_BATCHER
+
+
+class TestInferenceBatcher:
+ # from response json
+
+ def test_from_response_json_enabled(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_batcher"]["get_enabled"]["response"]
+ json_camelized = humps.camelize(json) # as returned by the backend
+ mock_ib_from_json = mocker.patch(
+ "hsml.inference_batcher.InferenceBatcher.from_json"
+ )
+
+ # Act
+ _ = inference_batcher.InferenceBatcher.from_response_json(json_camelized)
+
+ # Assert
+ mock_ib_from_json.assert_called_once_with(json)
+
+ def test_from_response_json_enabled_with_config(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_batcher"]["get_enabled_with_config"][
+ "response"
+ ]
+ json_camelized = humps.camelize(json) # as returned by the backend
+ mock_ib_from_json = mocker.patch(
+ "hsml.inference_batcher.InferenceBatcher.from_json"
+ )
+
+ # Act
+ _ = inference_batcher.InferenceBatcher.from_response_json(json_camelized)
+
+ # Assert
+ mock_ib_from_json.assert_called_once_with(json)
+
+ # from json
+
+ def test_from_json_enabled(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_batcher"]["get_enabled"]["response"]
+ mock_ib_extract_fields = mocker.patch(
+ "hsml.inference_batcher.InferenceBatcher.extract_fields_from_json",
+ return_value=json,
+ )
+ mock_ib_init = mocker.patch(
+ "hsml.inference_batcher.InferenceBatcher.__init__", return_value=None
+ )
+
+ # Act
+ _ = inference_batcher.InferenceBatcher.from_json(json)
+
+ # Assert
+ mock_ib_extract_fields.assert_called_once_with(json)
+ mock_ib_init.assert_called_once_with(**json)
+
+ def test_from_json_enabled_with_config(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_batcher"]["get_enabled_with_config"][
+ "response"
+ ]
+ mock_ib_extract_fields = mocker.patch(
+ "hsml.inference_batcher.InferenceBatcher.extract_fields_from_json",
+ return_value=json,
+ )
+ mock_ib_init = mocker.patch(
+ "hsml.inference_batcher.InferenceBatcher.__init__", return_value=None
+ )
+
+ # Act
+ _ = inference_batcher.InferenceBatcher.from_json(json)
+
+ # Assert
+ mock_ib_extract_fields.assert_called_once_with(json)
+ mock_ib_init.assert_called_once_with(**json)
+
+ # constructor
+
+ def test_constructor_default(self):
+ # Act
+ ib = inference_batcher.InferenceBatcher()
+
+ # Assert
+ assert isinstance(ib, inference_batcher.InferenceBatcher)
+ assert ib.enabled == INFERENCE_BATCHER.ENABLED
+ assert ib.max_batch_size is None
+ assert ib.max_latency is None
+ assert ib.timeout is None
+
+ def test_constructor_enabled(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_batcher"]["get_enabled"]["response"]
+
+ # Act
+ ib = inference_batcher.InferenceBatcher(**json)
+
+ # Assert
+ assert isinstance(ib, inference_batcher.InferenceBatcher)
+ assert ib.enabled == json["enabled"]
+ assert ib.max_batch_size is None
+ assert ib.max_latency is None
+ assert ib.timeout is None
+
+ def test_constructor_enabled_with_config(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_batcher"]["get_enabled_with_config"][
+ "response"
+ ]
+
+ # Act
+ ib = inference_batcher.InferenceBatcher(**json)
+
+ # Assert
+ assert isinstance(ib, inference_batcher.InferenceBatcher)
+ assert ib.enabled == json["enabled"]
+ assert ib.max_batch_size == json["max_batch_size"]
+ assert ib.max_latency == json["max_latency"]
+ assert ib.timeout == json["timeout"]
+
+ def test_constructor_disabled(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_batcher"]["get_disabled"]["response"]
+
+ # Act
+ ib = inference_batcher.InferenceBatcher(**json)
+
+ # Assert
+ assert isinstance(ib, inference_batcher.InferenceBatcher)
+ assert ib.enabled == json["enabled"]
+ assert ib.max_batch_size is None
+ assert ib.max_latency is None
+ assert ib.timeout is None
+
+ def test_constructor_disabled_with_config(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_batcher"]["get_disabled_with_config"][
+ "response"
+ ]
+
+ # Act
+ ib = inference_batcher.InferenceBatcher(**json)
+
+ # Assert
+ assert isinstance(ib, inference_batcher.InferenceBatcher)
+ assert ib.enabled == json["enabled"]
+ assert ib.max_batch_size == json["max_batch_size"]
+ assert ib.max_latency == json["max_latency"]
+ assert ib.timeout == json["timeout"]
+
+ # extract fields from json
+
+ def test_extract_fields_from_json_enabled(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_batcher"]["get_enabled"]["response"]
+ json_copy = copy.deepcopy(json)
+
+ # Act
+ kwargs = inference_batcher.InferenceBatcher.extract_fields_from_json(json_copy)
+
+ # Assert
+ assert kwargs["enabled"] == json["enabled"]
+ assert kwargs["max_batch_size"] is None
+ assert kwargs["max_latency"] is None
+ assert kwargs["timeout"] is None
+
+ def test_extract_fields_from_json_enabled_with_config(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_batcher"]["get_enabled_with_config"][
+ "response"
+ ]
+ json_copy = copy.deepcopy(json)
+
+ # Act
+ kwargs = inference_batcher.InferenceBatcher.extract_fields_from_json(json_copy)
+
+ # Assert
+ assert kwargs["enabled"] == json["enabled"]
+ assert kwargs["max_batch_size"] == json["max_batch_size"]
+ assert kwargs["max_latency"] == json["max_latency"]
+ assert kwargs["timeout"] == json["timeout"]
+
+ def test_extract_fields_from_json_enabled_nested(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_batcher"]["get_enabled"]["response_nested"]
+ json_copy = copy.deepcopy(json)
+
+ # Act
+ kwargs = inference_batcher.InferenceBatcher.extract_fields_from_json(json_copy)
+
+ # Assert
+ assert kwargs["enabled"] == json["batching_configuration"]["enabled"]
+ assert kwargs["max_batch_size"] is None
+ assert kwargs["max_latency"] is None
+ assert kwargs["timeout"] is None
+
+ def test_extract_fields_from_json_enabled_with_config_nested(
+ self, backend_fixtures
+ ):
+ # Arrange
+ json = backend_fixtures["inference_batcher"]["get_enabled_with_config"][
+ "response_nested"
+ ]
+ json_copy = copy.deepcopy(json)
+
+ # Act
+ kwargs = inference_batcher.InferenceBatcher.extract_fields_from_json(json_copy)
+
+ # Assert
+ assert kwargs["enabled"] == json["batching_configuration"]["enabled"]
+ assert (
+ kwargs["max_batch_size"] == json["batching_configuration"]["max_batch_size"]
+ )
+ assert kwargs["max_latency"] == json["batching_configuration"]["max_latency"]
+ assert kwargs["timeout"] == json["batching_configuration"]["timeout"]
diff --git a/hsml/python/tests/test_inference_endpoint.py b/hsml/python/tests/test_inference_endpoint.py
new file mode 100644
index 000000000..0f79a6ff3
--- /dev/null
+++ b/hsml/python/tests/test_inference_endpoint.py
@@ -0,0 +1,298 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import copy
+
+import humps
+from hsml import inference_endpoint
+
+
+class TestInferenceEndpoint:
+ # InferenceEndpointPort
+
+ # from response json
+
+ def test_from_response_json_port(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_endpoint"]["get_port"]["response"]
+ json_camelized = humps.camelize(json) # as returned by the backend
+ mock_ie_from_json = mocker.patch(
+ "hsml.inference_endpoint.InferenceEndpointPort.from_json"
+ )
+
+ # Act
+ _ = inference_endpoint.InferenceEndpointPort.from_response_json(json_camelized)
+
+ # Assert
+ mock_ie_from_json.assert_called_once_with(json)
+
+ # from json
+
+ def test_from_json_port(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_endpoint"]["get_port"]["response"]
+ mock_ie_extract_fields = mocker.patch(
+ "hsml.inference_endpoint.InferenceEndpointPort.extract_fields_from_json",
+ return_value=json,
+ )
+ mock_ie_init = mocker.patch(
+ "hsml.inference_endpoint.InferenceEndpointPort.__init__", return_value=None
+ )
+
+ # Act
+ _ = inference_endpoint.InferenceEndpointPort.from_json(json)
+
+ # Assert
+ mock_ie_extract_fields.assert_called_once_with(json)
+ mock_ie_init.assert_called_once_with(**json)
+
+ # constructor
+
+ def test_constructor_port(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_endpoint"]["get_port"]["response"]
+
+ # Act
+ ie_port = inference_endpoint.InferenceEndpointPort(
+ name=json["name"], number=json["number"]
+ )
+
+ # Assert
+ assert ie_port.name == json["name"]
+ assert ie_port.number == json["number"]
+
+ # extract fields from json
+
+ def test_extract_fields_from_json_port(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_endpoint"]["get_port"]["response"]
+ json_copy = copy.deepcopy(json)
+
+ # Act
+ kwargs = inference_endpoint.InferenceEndpointPort.extract_fields_from_json(
+ json_copy
+ )
+
+ # Assert
+ assert kwargs["name"] == json["name"]
+ assert kwargs["number"] == json["number"]
+
+ # InferenceEndpoint
+
+ # from response json
+
+ def test_from_response_json_empty(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_endpoint"]["get_empty"]["response"]
+ json_camelized = humps.camelize(json) # as returned by the backend
+ mock_ie_from_json = mocker.patch(
+ "hsml.inference_endpoint.InferenceEndpoint.from_json"
+ )
+
+ # Act
+ ie = inference_endpoint.InferenceEndpoint.from_response_json(json_camelized)
+
+ # Assert
+ assert isinstance(ie, list)
+ assert len(ie) == 0
+ mock_ie_from_json.assert_not_called()
+
+ def test_from_response_json_singleton(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_endpoint"]["get_singleton"]["response"]
+ json_camelized = humps.camelize(json) # as returned by the backend
+ mock_ie_from_json = mocker.patch(
+ "hsml.inference_endpoint.InferenceEndpoint.from_json"
+ )
+
+ # Act
+ ie = inference_endpoint.InferenceEndpoint.from_response_json(json_camelized)
+
+ # Assert
+ assert isinstance(ie, list)
+ assert len(ie) == 1
+ mock_ie_from_json.assert_called_once_with(json["items"][0])
+
+ def test_from_response_json_list(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_endpoint"]["get_list"]["response"]
+ json_camelized = humps.camelize(json) # as returned by the backend
+ mock_ie_from_json = mocker.patch(
+ "hsml.inference_endpoint.InferenceEndpoint.from_json"
+ )
+
+ # Act
+ ie = inference_endpoint.InferenceEndpoint.from_response_json(json_camelized)
+
+ # Assert
+ assert isinstance(ie, list)
+ assert len(ie) == json["count"]
+ assert mock_ie_from_json.call_count == json["count"]
+
+ def test_from_response_json_single(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_endpoint"]["get_singleton"]["response"][
+ "items"
+ ][0]
+ json_camelized = humps.camelize(json) # as returned by the backend
+ mock_ie_from_json = mocker.patch(
+ "hsml.inference_endpoint.InferenceEndpoint.from_json"
+ )
+
+ # Act
+ _ = inference_endpoint.InferenceEndpoint.from_response_json(json_camelized)
+
+ # Assert
+ mock_ie_from_json.assert_called_once_with(json)
+
+ # from json
+
+ def test_from_json(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_endpoint"]["get_singleton"]["response"][
+ "items"
+ ][0]
+ mock_ie_extract_fields = mocker.patch(
+ "hsml.inference_endpoint.InferenceEndpoint.extract_fields_from_json",
+ return_value=json,
+ )
+ mock_ie_init = mocker.patch(
+ "hsml.inference_endpoint.InferenceEndpoint.__init__", return_value=None
+ )
+
+ # Act
+ _ = inference_endpoint.InferenceEndpoint.from_json(json)
+
+ # Assert
+ mock_ie_extract_fields.assert_called_once_with(json)
+ mock_ie_init.assert_called_once_with(**json)
+
+ # constructor
+
+ def test_constructor(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_endpoint"]["get_singleton"]["response"][
+ "items"
+ ][0]
+
+ # Act
+ ie = inference_endpoint.InferenceEndpoint(
+ type=json["type"], hosts=json["hosts"], ports=json["ports"]
+ )
+
+ # Assert
+ assert isinstance(ie, inference_endpoint.InferenceEndpoint)
+ assert ie.type == json["type"]
+ assert ie.hosts == json["hosts"]
+ assert ie.ports == json["ports"]
+
+ # extract fields from json
+
+ def test_extract_fields_from_json(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_endpoint"]["get_singleton"]["response"][
+ "items"
+ ][0]
+ json_copy = copy.deepcopy(json)
+ mock_ie_port_from_json = mocker.patch(
+ "hsml.inference_endpoint.InferenceEndpointPort.from_json", return_value=None
+ )
+
+ # Act
+ kwargs = inference_endpoint.InferenceEndpoint.extract_fields_from_json(
+ json_copy
+ )
+
+ # Assert
+ assert kwargs["type"] == json["type"]
+ assert kwargs["hosts"] == json["hosts"]
+ mock_ie_port_from_json.assert_called_once_with(json["ports"][0])
+
+ # get any host
+
+ def test_get_any_host(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_endpoint"]["get_singleton"]["response"][
+ "items"
+ ][0]
+ ie = inference_endpoint.InferenceEndpoint(
+ type=None, hosts=json["hosts"], ports=None
+ )
+ mocker_random_choice = mocker.patch("random.choice", return_value=None)
+
+ # Act
+ _ = ie.get_any_host()
+
+ # Assert
+ mocker_random_choice.assert_called_once_with(ie.hosts)
+
+ def test_get_any_host_none(self, mocker, backend_fixtures):
+ # Arrange
+ ie = inference_endpoint.InferenceEndpoint(type=None, hosts=None, ports=None)
+ mocker_random_choice = mocker.patch("random.choice", return_value=None)
+
+ # Act
+ host = ie.get_any_host()
+
+ # Assert
+ assert host is None
+ mocker_random_choice.assert_not_called()
+
+ # get port
+
+ def test_get_port_existing(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_endpoint"]["get_list"]["response"]["items"][
+ 1
+ ]
+ ports = [
+ inference_endpoint.InferenceEndpointPort(p["name"], p["number"])
+ for p in json["ports"]
+ ]
+ ie = inference_endpoint.InferenceEndpoint(type=None, hosts=None, ports=ports)
+
+ # Act
+ port = ie.get_port(ports[0].name)
+
+ # Assert
+ assert port == ports[0]
+
+ def test_get_port_not_found(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_endpoint"]["get_list"]["response"]["items"][
+ 1
+ ]
+ ports = [
+ inference_endpoint.InferenceEndpointPort(p["name"], p["number"])
+ for p in json["ports"]
+ ]
+ ie = inference_endpoint.InferenceEndpoint(type=None, hosts=None, ports=ports)
+
+ # Act
+ port = ie.get_port("not_found")
+
+ # Assert
+ assert port is None
+
+ def test_get_port_none(self):
+ # Arrange
+ ie = inference_endpoint.InferenceEndpoint(type=None, hosts=None, ports=None)
+
+ # Act
+ port = ie.get_port("not_found")
+
+ # Assert
+ assert port is None
diff --git a/hsml/python/tests/test_inference_logger.py b/hsml/python/tests/test_inference_logger.py
new file mode 100644
index 000000000..1f137cefa
--- /dev/null
+++ b/hsml/python/tests/test_inference_logger.py
@@ -0,0 +1,413 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import copy
+
+import humps
+import pytest
+from hsml import inference_logger, kafka_topic
+from hsml.constants import DEFAULT, INFERENCE_LOGGER
+
+
+class TestInferenceLogger:
+ # from response json
+
+ def test_from_response_json_with_mode_only(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_logger"]["get_mode_all"]["response"]
+ json_camelized = humps.camelize(json) # as returned by the backend
+ mock_il_from_json = mocker.patch(
+ "hsml.inference_logger.InferenceLogger.from_json"
+ )
+
+ # Act
+ _ = inference_logger.InferenceLogger.from_response_json(json_camelized)
+
+ # Assert
+ mock_il_from_json.assert_called_once_with(json)
+
+ def test_from_response_json_with_mode_and_kafka_topic(
+ self, mocker, backend_fixtures
+ ):
+ # Arrange
+ json = backend_fixtures["inference_logger"]["get_mode_all_with_kafka_topic"][
+ "response"
+ ]
+ json_camelized = humps.camelize(json) # as returned by the backend
+ mock_il_from_json = mocker.patch(
+ "hsml.inference_logger.InferenceLogger.from_json"
+ )
+
+ # Act
+ _ = inference_logger.InferenceLogger.from_response_json(json_camelized)
+
+ # Assert
+ mock_il_from_json.assert_called_once_with(json)
+
+ # from json
+
+ def test_from_json_with_mode_all(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_logger"]["get_mode_all"]["response"]
+ mock_il_extract_fields = mocker.patch(
+ "hsml.inference_logger.InferenceLogger.extract_fields_from_json",
+ return_value=json,
+ )
+ mock_il_init = mocker.patch(
+ "hsml.inference_logger.InferenceLogger.__init__", return_value=None
+ )
+
+ # Act
+ _ = inference_logger.InferenceLogger.from_json(json)
+
+ # Assert
+ mock_il_extract_fields.assert_called_once_with(json)
+ mock_il_init.assert_called_once_with(**json)
+
+ def test_from_json_with_mode_all_and_kafka_topic(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_logger"]["get_mode_all_with_kafka_topic"][
+ "response"
+ ]
+ mock_il_extract_fields = mocker.patch(
+ "hsml.inference_logger.InferenceLogger.extract_fields_from_json",
+ return_value=json,
+ )
+ mock_il_init = mocker.patch(
+ "hsml.inference_logger.InferenceLogger.__init__", return_value=None
+ )
+
+ # Act
+ _ = inference_logger.InferenceLogger.from_json(json)
+
+ # Assert
+ mock_il_extract_fields.assert_called_once_with(json)
+ mock_il_init.assert_called_once_with(**json)
+
+ # constructor
+
+ def test_constructor_default(self, mocker):
+ # Arrange
+ mock_il_validate_mode = mocker.patch(
+ "hsml.inference_logger.InferenceLogger._validate_mode",
+ return_value=INFERENCE_LOGGER.MODE_ALL,
+ )
+ default_kt = kafka_topic.KafkaTopic()
+ mock_util_get_obj_from_json = mocker.patch(
+ "hsml.util.get_obj_from_json", return_value=default_kt
+ )
+
+ # Act
+ il = inference_logger.InferenceLogger()
+
+ # Assert
+ assert isinstance(il, inference_logger.InferenceLogger)
+ assert il.mode == INFERENCE_LOGGER.MODE_ALL
+ assert isinstance(il.kafka_topic, kafka_topic.KafkaTopic)
+ assert il.kafka_topic.name == default_kt.name
+ assert il.kafka_topic.num_replicas == default_kt.num_replicas
+ assert il.kafka_topic.num_partitions == default_kt.num_partitions
+ mock_util_get_obj_from_json.assert_called_once_with(
+ DEFAULT, kafka_topic.KafkaTopic
+ )
+ mock_il_validate_mode.assert_called_once_with(
+ INFERENCE_LOGGER.MODE_ALL, default_kt
+ )
+
+ def test_constructor_mode_all(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_logger"]["get_mode_all"]["init_args"]
+ mock_il_validate_mode = mocker.patch(
+ "hsml.inference_logger.InferenceLogger._validate_mode",
+ return_value=json["mode"],
+ )
+ default_kt = kafka_topic.KafkaTopic()
+ mock_util_get_obj_from_json = mocker.patch(
+ "hsml.util.get_obj_from_json", return_value=default_kt
+ )
+
+ # Act
+ il = inference_logger.InferenceLogger(**json)
+
+ # Assert
+ assert isinstance(il, inference_logger.InferenceLogger)
+ assert il.mode == json["mode"]
+ assert isinstance(il.kafka_topic, kafka_topic.KafkaTopic)
+ assert il.kafka_topic.name == default_kt.name
+ assert il.kafka_topic.num_replicas == default_kt.num_replicas
+ assert il.kafka_topic.num_partitions == default_kt.num_partitions
+ mock_util_get_obj_from_json.assert_called_once_with(
+ DEFAULT, kafka_topic.KafkaTopic
+ )
+ mock_il_validate_mode.assert_called_once_with(json["mode"], default_kt)
+
+ def test_constructor_mode_inputs(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_logger"]["get_mode_inputs"]["init_args"]
+ mock_il_validate_mode = mocker.patch(
+ "hsml.inference_logger.InferenceLogger._validate_mode",
+ return_value=json["mode"],
+ )
+ default_kt = kafka_topic.KafkaTopic()
+ mock_util_get_obj_from_json = mocker.patch(
+ "hsml.util.get_obj_from_json", return_value=default_kt
+ )
+
+ # Act
+ il = inference_logger.InferenceLogger(**json)
+
+ # Assert
+ assert isinstance(il, inference_logger.InferenceLogger)
+ assert il.mode == json["mode"]
+ assert isinstance(il.kafka_topic, kafka_topic.KafkaTopic)
+ assert il.kafka_topic.name == default_kt.name
+ assert il.kafka_topic.num_replicas == default_kt.num_replicas
+ assert il.kafka_topic.num_partitions == default_kt.num_partitions
+ mock_util_get_obj_from_json.assert_called_once_with(
+ DEFAULT, kafka_topic.KafkaTopic
+ )
+ mock_il_validate_mode.assert_called_once_with(json["mode"], default_kt)
+
+ def test_constructor_mode_outputs(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_logger"]["get_mode_outputs"]["init_args"]
+ mock_il_validate_mode = mocker.patch(
+ "hsml.inference_logger.InferenceLogger._validate_mode",
+ return_value=json["mode"],
+ )
+ default_kt = kafka_topic.KafkaTopic()
+ mock_util_get_obj_from_json = mocker.patch(
+ "hsml.util.get_obj_from_json", return_value=default_kt
+ )
+
+ # Act
+ il = inference_logger.InferenceLogger(**json)
+
+ # Assert
+ assert isinstance(il, inference_logger.InferenceLogger)
+ assert il.mode == json["mode"]
+ assert isinstance(il.kafka_topic, kafka_topic.KafkaTopic)
+ assert il.kafka_topic.name == default_kt.name
+ assert il.kafka_topic.num_replicas == default_kt.num_replicas
+ assert il.kafka_topic.num_partitions == default_kt.num_partitions
+ mock_util_get_obj_from_json.assert_called_once_with(
+ DEFAULT, kafka_topic.KafkaTopic
+ )
+ mock_il_validate_mode.assert_called_once_with(json["mode"], default_kt)
+
+ def test_constructor_mode_all_and_kafka_topic(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_logger"]["get_mode_all_with_kafka_topic"][
+ "init_args"
+ ]
+ json_copy = copy.deepcopy(json)
+ mock_il_validate_mode = mocker.patch(
+ "hsml.inference_logger.InferenceLogger._validate_mode",
+ return_value=json["mode"],
+ )
+ kt = kafka_topic.KafkaTopic(json["kafka_topic"]["name"])
+ mock_util_get_obj_from_json = mocker.patch(
+ "hsml.util.get_obj_from_json", return_value=kt
+ )
+
+ # Act
+ il = inference_logger.InferenceLogger(**json_copy)
+
+ # Assert
+ assert isinstance(il, inference_logger.InferenceLogger)
+ assert il.mode == json["mode"]
+ assert isinstance(il.kafka_topic, kafka_topic.KafkaTopic)
+ assert il.kafka_topic.name == kt.name
+ assert il.kafka_topic.num_replicas is None
+ assert il.kafka_topic.num_partitions is None
+ mock_util_get_obj_from_json.assert_called_once_with(
+ json["kafka_topic"], kafka_topic.KafkaTopic
+ )
+ mock_il_validate_mode.assert_called_once_with(json["mode"], kt)
+
+ def test_constructor_mode_inputs_and_kafka_topic(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_logger"]["get_mode_inputs_with_kafka_topic"][
+ "init_args"
+ ]
+ json_copy = copy.deepcopy(json)
+ mock_il_validate_mode = mocker.patch(
+ "hsml.inference_logger.InferenceLogger._validate_mode",
+ return_value=json["mode"],
+ )
+ kt = kafka_topic.KafkaTopic(json["kafka_topic"]["name"])
+ mock_util_get_obj_from_json = mocker.patch(
+ "hsml.util.get_obj_from_json", return_value=kt
+ )
+
+ # Act
+ il = inference_logger.InferenceLogger(**json_copy)
+
+ # Assert
+ assert isinstance(il, inference_logger.InferenceLogger)
+ assert il.mode == json["mode"]
+ assert isinstance(il.kafka_topic, kafka_topic.KafkaTopic)
+ assert il.kafka_topic.name == kt.name
+ assert il.kafka_topic.num_replicas is None
+ assert il.kafka_topic.num_partitions is None
+ mock_util_get_obj_from_json.assert_called_once_with(
+ json["kafka_topic"], kafka_topic.KafkaTopic
+ )
+ mock_il_validate_mode.assert_called_once_with(json["mode"], kt)
+
+ def test_constructor_mode_outputs_and_kafka_topic(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_logger"]["get_mode_all_with_kafka_topic"][
+ "init_args"
+ ]
+ json_copy = copy.deepcopy(json)
+ mock_il_validate_mode = mocker.patch(
+ "hsml.inference_logger.InferenceLogger._validate_mode",
+ return_value=json["mode"],
+ )
+ kt = kafka_topic.KafkaTopic(json["kafka_topic"]["name"])
+ mock_util_get_obj_from_json = mocker.patch(
+ "hsml.util.get_obj_from_json", return_value=kt
+ )
+
+ # Act
+ il = inference_logger.InferenceLogger(**json_copy)
+
+ # Assert
+ assert isinstance(il, inference_logger.InferenceLogger)
+ assert il.mode == json["mode"]
+ assert isinstance(il.kafka_topic, kafka_topic.KafkaTopic)
+ assert il.kafka_topic.name == kt.name
+ assert il.kafka_topic.num_replicas is None
+ assert il.kafka_topic.num_partitions is None
+ mock_util_get_obj_from_json.assert_called_once_with(
+ json["kafka_topic"], kafka_topic.KafkaTopic
+ )
+ mock_il_validate_mode.assert_called_once_with(json["mode"], kt)
+
+ def test_constructor_mode_none_and_kafka_topic(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_logger"]["get_mode_none_with_kafka_topic"][
+ "init_args"
+ ]
+ json_copy = copy.deepcopy(json)
+ mock_il_validate_mode = mocker.patch(
+ "hsml.inference_logger.InferenceLogger._validate_mode",
+ return_value=json["mode"],
+ )
+ kt = kafka_topic.KafkaTopic(json["kafka_topic"]["name"])
+ mock_util_get_obj_from_json = mocker.patch(
+ "hsml.util.get_obj_from_json", return_value=kt
+ )
+
+ # Act
+ il = inference_logger.InferenceLogger(**json_copy)
+
+ # Assert
+ assert isinstance(il, inference_logger.InferenceLogger)
+ assert il.mode == json["mode"]
+ assert isinstance(il.kafka_topic, kafka_topic.KafkaTopic)
+ assert il.kafka_topic.name == kt.name
+ assert il.kafka_topic.num_replicas is None
+ assert il.kafka_topic.num_partitions is None
+ mock_util_get_obj_from_json.assert_called_once_with(
+ json["kafka_topic"], kafka_topic.KafkaTopic
+ )
+ mock_il_validate_mode.assert_called_once_with(json["mode"], kt)
+
+ # validate mode
+
+ def test_validate_mode_none_and_kafka_topic_none(self):
+ # Act
+ mode = inference_logger.InferenceLogger._validate_mode(None, None)
+
+ # Assert
+ assert mode is None
+
+ def test_validate_mode_all_and_kafka_topic_none(self):
+ # Act
+ mode = inference_logger.InferenceLogger._validate_mode(
+ INFERENCE_LOGGER.MODE_ALL, None
+ )
+
+ # Assert
+ assert mode is None
+
+ def test_validate_mode_invalid_and_kafka_topic_none(self):
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ _ = inference_logger.InferenceLogger._validate_mode("invalid", None)
+
+ # Assert
+ assert "is not valid" in str(e_info.value)
+
+ def test_validate_mode_none_and_kafka_topic(self):
+ # Act
+ mode = inference_logger.InferenceLogger._validate_mode(
+ None, kafka_topic.KafkaTopic()
+ )
+
+ # Assert
+ assert mode == INFERENCE_LOGGER.MODE_NONE
+
+ def test_validate_mode_all_and_kafka_topic(self):
+ # Act
+ mode = inference_logger.InferenceLogger._validate_mode(
+ INFERENCE_LOGGER.MODE_ALL, kafka_topic.KafkaTopic()
+ )
+
+ # Assert
+ assert mode == INFERENCE_LOGGER.MODE_ALL
+
+ def test_validate_mode_invalid_and_kafka_topic(self):
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ _ = inference_logger.InferenceLogger._validate_mode(
+ "invalid", kafka_topic.KafkaTopic()
+ )
+
+ # Assert
+ assert "is not valid" in str(e_info.value)
+
+ # extract fields from json
+
+ def test_extract_fields_from_json(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_logger"]["get_mode_all_with_kafka_topic"][
+ "response"
+ ]
+ json_copy = copy.deepcopy(json)
+
+ # Act
+ kwargs = inference_logger.InferenceLogger.extract_fields_from_json(json_copy)
+
+ # Assert
+ assert kwargs["kafka_topic"] == json["kafka_topic_dto"]
+ assert kwargs["mode"] == json["inference_logging"]
+
+ def test_extract_fields_from_json_alternative(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["inference_logger"]["get_mode_all_with_kafka_topic"][
+ "init_args"
+ ]
+ json_copy = copy.deepcopy(json)
+
+ # Act
+ kwargs = inference_logger.InferenceLogger.extract_fields_from_json(json_copy)
+
+ # Assert
+ assert kwargs["kafka_topic"] == json["kafka_topic"]
+ assert kwargs["mode"] == json["mode"]
diff --git a/hsml/python/tests/test_kafka_topic.py b/hsml/python/tests/test_kafka_topic.py
new file mode 100644
index 000000000..b9ada2a91
--- /dev/null
+++ b/hsml/python/tests/test_kafka_topic.py
@@ -0,0 +1,289 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import copy
+
+import humps
+import pytest
+from hsml import kafka_topic
+from hsml.constants import KAFKA_TOPIC
+
+
+class TestKafkaTopic:
+ # from response json
+
+ def test_from_response_json_with_name_only(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["kafka_topic"]["get_existing_with_name_only"][
+ "response"
+ ]["kafka_topic_dto"]
+ json_camelized = humps.camelize(json) # as returned by the backend
+ mock_kt_from_json = mocker.patch("hsml.kafka_topic.KafkaTopic.from_json")
+
+ # Act
+ _ = kafka_topic.KafkaTopic.from_response_json(json_camelized)
+
+ # Assert
+ mock_kt_from_json.assert_called_once_with(json)
+
+ def test_from_response_json_with_name_and_config(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["kafka_topic"]["get_existing_with_name_and_config"][
+ "response"
+ ]["kafka_topic_dto"]
+ json_camelized = humps.camelize(json) # as returned by the backend
+ mock_kt_from_json = mocker.patch("hsml.kafka_topic.KafkaTopic.from_json")
+
+ # Act
+ _ = kafka_topic.KafkaTopic.from_response_json(json_camelized)
+
+ # Assert
+ mock_kt_from_json.assert_called_once_with(json)
+
+ # from json
+
+ def test_from_json_with_name_only(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["kafka_topic"]["get_existing_with_name_only"][
+ "response"
+ ]["kafka_topic_dto"]
+ mock_kt_extract_fields = mocker.patch(
+ "hsml.kafka_topic.KafkaTopic.extract_fields_from_json", return_value=json
+ )
+ mock_kt_init = mocker.patch(
+ "hsml.kafka_topic.KafkaTopic.__init__", return_value=None
+ )
+
+ # Act
+ _ = kafka_topic.KafkaTopic.from_json(json)
+
+ # Assert
+ mock_kt_extract_fields.assert_called_once_with(json)
+ mock_kt_init.assert_called_once_with(**json)
+
+ def test_from_json_with_name_and_config(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["kafka_topic"]["get_existing_with_name_and_config"][
+ "response"
+ ]["kafka_topic_dto"]
+ mock_kt_extract_fields = mocker.patch(
+ "hsml.kafka_topic.KafkaTopic.extract_fields_from_json", return_value=json
+ )
+ mock_kt_init = mocker.patch(
+ "hsml.kafka_topic.KafkaTopic.__init__", return_value=None
+ )
+
+ # Act
+ _ = kafka_topic.KafkaTopic.from_json(json)
+
+ # Assert
+ mock_kt_extract_fields.assert_called_once_with(json)
+ mock_kt_init.assert_called_once_with(**json)
+
+ # constructor
+
+ def test_constructor_existing_with_name_only(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["kafka_topic"]["get_existing_with_name_only"][
+ "response"
+ ]["kafka_topic_dto"]
+ mock_kt_validate_topic_config = mocker.patch(
+ "hsml.kafka_topic.KafkaTopic._validate_topic_config",
+ return_value=(KAFKA_TOPIC.NUM_REPLICAS, KAFKA_TOPIC.NUM_PARTITIONS),
+ )
+
+ # Act
+ kt = kafka_topic.KafkaTopic(**json)
+
+ # Assert
+ assert isinstance(kt, kafka_topic.KafkaTopic)
+ assert kt.name == json["name"]
+ assert kt.num_replicas == KAFKA_TOPIC.NUM_REPLICAS
+ assert kt.num_partitions == KAFKA_TOPIC.NUM_PARTITIONS
+ mock_kt_validate_topic_config.assert_called_once_with(json["name"], None, None)
+
+ def test_constructor_existing_with_name_and_config(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["kafka_topic"]["get_existing_with_name_and_config"][
+ "response"
+ ]["kafka_topic_dto"]
+ mock_kt_validate_topic_config = mocker.patch(
+ "hsml.kafka_topic.KafkaTopic._validate_topic_config",
+ return_value=(json["num_replicas"], json["num_partitions"]),
+ )
+
+ # Act
+ kt = kafka_topic.KafkaTopic(**json)
+
+ # Assert
+ assert isinstance(kt, kafka_topic.KafkaTopic)
+ assert kt.name == json["name"]
+ assert kt.num_replicas == json["num_replicas"]
+ assert kt.num_partitions == json["num_partitions"]
+ mock_kt_validate_topic_config.assert_called_once_with(
+ json["name"], json["num_replicas"], json["num_partitions"]
+ )
+
+ # validate topic config
+
+ def test_validate_topic_config_existing_with_name_only(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["kafka_topic"]["get_existing_with_name_only"][
+ "response"
+ ]["kafka_topic_dto"]
+
+ # Act
+ num_repl, num_part = kafka_topic.KafkaTopic._validate_topic_config(
+ json["name"], None, None
+ )
+
+ # Assert
+ assert num_repl is None
+ assert num_part is None
+
+ def test_validate_topic_config_existing_with_name_and_config(
+ self, backend_fixtures
+ ):
+ # Arrange
+ json = backend_fixtures["kafka_topic"]["get_existing_with_name_and_config"][
+ "response"
+ ]["kafka_topic_dto"]
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ num_repl, num_part = kafka_topic.KafkaTopic._validate_topic_config(
+ json["name"], json["num_replicas"], json["num_partitions"]
+ )
+
+ # Assert
+ assert "Number of replicas or partitions cannot be changed" in str(e_info.value)
+
+ def test_validate_topic_config_none(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["kafka_topic"]["get_none"]["response"][
+ "kafka_topic_dto"
+ ]
+
+ # Act
+ num_repl, num_part = kafka_topic.KafkaTopic._validate_topic_config(
+ json["name"], None, None
+ )
+
+ # Assert
+ assert num_repl is None
+ assert num_part is None
+
+ def test_validate_topic_config_none_with_config(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["kafka_topic"]["get_none_with_config"]["response"][
+ "kafka_topic_dto"
+ ]
+
+ # Act
+ num_repl, num_part = kafka_topic.KafkaTopic._validate_topic_config(
+ json["name"], json["num_replicas"], json["num_partitions"]
+ )
+
+ # Assert
+ assert num_repl is None
+ assert num_part is None
+
+ def test_validate_topic_config_none_value(self):
+ # Act
+ num_repl, num_part = kafka_topic.KafkaTopic._validate_topic_config(
+ None, None, None
+ )
+
+ # Assert
+ assert num_repl is None
+ assert num_part is None
+
+ def test_validate_topic_config_none_value_with_config(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["kafka_topic"]["get_none_with_config"]["response"][
+ "kafka_topic_dto"
+ ]
+
+ # Act
+ num_repl, num_part = kafka_topic.KafkaTopic._validate_topic_config(
+ None, json["num_replicas"], json["num_partitions"]
+ )
+
+ # Assert
+ assert num_repl is None
+ assert num_part is None
+
+ def test_validate_topic_config_create(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["kafka_topic"]["get_create_with_name_only"]["response"][
+ "kafka_topic_dto"
+ ]
+
+ # Act
+ num_repl, num_part = kafka_topic.KafkaTopic._validate_topic_config(
+ json["name"], None, None
+ )
+
+ # Assert
+ assert num_repl == KAFKA_TOPIC.NUM_REPLICAS
+ assert num_part == KAFKA_TOPIC.NUM_PARTITIONS
+
+ def test_validate_topic_config_create_with_config(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["kafka_topic"]["get_create_with_name_and_config"][
+ "response"
+ ]["kafka_topic_dto"]
+
+ # Act
+ num_repl, num_part = kafka_topic.KafkaTopic._validate_topic_config(
+ json["name"], json["num_replicas"], json["num_partitions"]
+ )
+
+ # Assert
+ assert num_repl == json["num_replicas"]
+ assert num_part == json["num_partitions"]
+
+ # extract fields from json
+
+ def test_extract_fields_from_json(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["kafka_topic"]["get_existing_with_name_and_config"][
+ "response"
+ ]["kafka_topic_dto"]
+ json_copy = copy.deepcopy(json)
+
+ # Act
+ kwargs = kafka_topic.KafkaTopic.extract_fields_from_json(json_copy)
+
+ # Assert
+ assert kwargs["name"] == json["name"]
+ assert kwargs["num_replicas"] == json["num_replicas"]
+ assert kwargs["num_partitions"] == json["num_partitions"]
+
+ def test_extract_fields_from_json_alternative(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["kafka_topic"][
+ "get_existing_with_name_and_config_alternative"
+ ]["response"]["kafka_topic_dto"]
+ json_copy = copy.deepcopy(json)
+
+ # Act
+ kwargs = kafka_topic.KafkaTopic.extract_fields_from_json(json_copy)
+
+ # Assert
+ assert kwargs["name"] == json["name"]
+ assert kwargs["num_replicas"] == json["num_of_replicas"]
+ assert kwargs["num_partitions"] == json["num_of_partitions"]
diff --git a/hsml/python/tests/test_model.py b/hsml/python/tests/test_model.py
new file mode 100644
index 000000000..31757c062
--- /dev/null
+++ b/hsml/python/tests/test_model.py
@@ -0,0 +1,470 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import copy
+import os
+
+import humps
+from hsml import model
+from hsml.constants import MODEL
+from hsml.core import explicit_provenance
+
+
+class TestModel:
+ # from response json
+
+ def test_from_response_json_empty(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["model"]["get_empty"]["response"]
+
+ # Act
+ m_lst = model.Model.from_response_json(json)
+
+ # Assert
+ assert isinstance(m_lst, list)
+ assert len(m_lst) == 0
+
+ def test_from_response_json_singleton(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["model"]["get_python"]["response"]
+ json_camelized = humps.camelize(json) # as returned by the backend
+
+ # Act
+ m = model.Model.from_response_json(copy.deepcopy(json_camelized))
+
+ # Assert
+ assert isinstance(m, list)
+ assert len(m) == 1
+
+ m = m[0]
+ m_json = json["items"][0]
+
+ self.assert_model(mocker, m, m_json, MODEL.FRAMEWORK_PYTHON)
+
+ def test_from_response_json_list(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["model"]["get_list"]["response"]
+ json_camelized = humps.camelize(json) # as returned by the backend
+
+ # Act
+ m_lst = model.Model.from_response_json(copy.deepcopy(json_camelized))
+
+ # Assert
+ assert isinstance(m_lst, list)
+ assert len(m_lst) == 2
+
+ for i in range(len(m_lst)):
+ m = m_lst[i]
+ m_json = json["items"][i]
+ self.assert_model(mocker, m, m_json, MODEL.FRAMEWORK_PYTHON)
+
+ # constructor
+
+ def test_constructor_base(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["model"]["get_base"]["response"]["items"][0]
+ m_json = copy.deepcopy(json)
+ id = m_json.pop("id")
+ name = m_json.pop("name")
+
+ # Act
+ m = model.Model(id=id, name=name, **m_json)
+
+ # Assert
+ self.assert_model(mocker, m, json, None)
+
+ def test_constructor_python(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["model"]["get_python"]["response"]["items"][0]
+ m_json = copy.deepcopy(json)
+ id = m_json.pop("id")
+ name = m_json.pop("name")
+
+ # Act
+ m = model.Model(id=id, name=name, **m_json)
+
+ # Assert
+ self.assert_model(mocker, m, json, MODEL.FRAMEWORK_PYTHON)
+
+ def test_constructor_sklearn(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["model"]["get_sklearn"]["response"]["items"][0]
+ m_json = copy.deepcopy(json)
+ id = m_json.pop("id")
+ name = m_json.pop("name")
+
+ # Act
+ m = model.Model(id=id, name=name, **m_json)
+
+ # Assert
+ self.assert_model(mocker, m, json, MODEL.FRAMEWORK_SKLEARN)
+
+ def test_constructor_tensorflow(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["model"]["get_tensorflow"]["response"]["items"][0]
+ m_json = copy.deepcopy(json)
+ id = m_json.pop("id")
+ name = m_json.pop("name")
+
+ # Act
+ m = model.Model(id=id, name=name, **m_json)
+
+ # Assert
+ self.assert_model(mocker, m, json, MODEL.FRAMEWORK_TENSORFLOW)
+
+ def test_constructor_torch(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["model"]["get_torch"]["response"]["items"][0]
+ m_json = copy.deepcopy(json)
+ id = m_json.pop("id")
+ name = m_json.pop("name")
+
+ # Act
+ m = model.Model(id=id, name=name, **m_json)
+
+ # Assert
+ self.assert_model(mocker, m, json, MODEL.FRAMEWORK_TORCH)
+
+ # save
+
+ def test_save(self, mocker, backend_fixtures):
+ # Arrange
+ m_json = backend_fixtures["model"]["get_python"]["response"]["items"][0]
+ mock_model_engine_save = mocker.patch(
+ "hsml.engine.model_engine.ModelEngine.save"
+ )
+ upload_configuration = {"config": "value"}
+
+ # Act
+ m = model.Model.from_response_json(m_json)
+ m.save(
+ model_path="model_path",
+ await_registration=1234,
+ keep_original_files=True,
+ upload_configuration=upload_configuration,
+ )
+
+ # Assert
+ mock_model_engine_save.assert_called_once_with(
+ model_instance=m,
+ model_path="model_path",
+ await_registration=1234,
+ keep_original_files=True,
+ upload_configuration=upload_configuration,
+ )
+
+ # deploy
+
+ def test_deploy(self, mocker, backend_fixtures):
+ # Arrange
+ m_json = backend_fixtures["model"]["get_python"]["response"]["items"][0]
+ p_json = backend_fixtures["predictor"]["get_deployments_singleton"]["response"][
+ "items"
+ ][0]
+ mock_predictor = mocker.Mock()
+ mock_predictor_for_model = mocker.patch(
+ "hsml.predictor.Predictor.for_model", return_value=mock_predictor
+ )
+ # params
+ resources = copy.deepcopy(p_json["predictor_resources"])
+ inference_logger = {
+ "mode": p_json["inference_logging"],
+ "kafka_topic": copy.deepcopy(p_json["kafka_topic_dto"]),
+ }
+ inference_batcher = copy.deepcopy(p_json["batching_configuration"])
+ transformer = {
+ "script_file": p_json["transformer"],
+ "resources": copy.deepcopy(p_json["transformer_resources"]),
+ }
+
+ # Act
+ m = model.Model.from_response_json(m_json)
+ m.deploy(
+ name=p_json["name"],
+ description=p_json["description"],
+ artifact_version=p_json["artifact_version"],
+ serving_tool=p_json["serving_tool"],
+ script_file=p_json["predictor"],
+ resources=resources,
+ inference_logger=inference_logger,
+ inference_batcher=inference_batcher,
+ transformer=transformer,
+ api_protocol=p_json["api_protocol"],
+ )
+
+ # Assert
+ mock_predictor_for_model.assert_called_once_with(
+ m,
+ name=p_json["name"],
+ description=p_json["description"],
+ artifact_version=p_json["artifact_version"],
+ serving_tool=p_json["serving_tool"],
+ script_file=p_json["predictor"],
+ resources=resources,
+ inference_logger=inference_logger,
+ inference_batcher=inference_batcher,
+ transformer=transformer,
+ api_protocol=p_json["api_protocol"],
+ )
+ mock_predictor.deploy.assert_called_once()
+
+ # delete
+
+ def test_delete(self, mocker, backend_fixtures):
+ # Arrange
+ m_json = backend_fixtures["model"]["get_python"]["response"]["items"][0]
+ mock_model_engine_delete = mocker.patch(
+ "hsml.engine.model_engine.ModelEngine.delete"
+ )
+
+ # Act
+ m = model.Model.from_response_json(m_json)
+ m.delete()
+
+ # Assert
+ mock_model_engine_delete.assert_called_once_with(model_instance=m)
+
+ # download
+
+ def test_download(self, mocker, backend_fixtures):
+ # Arrange
+ m_json = backend_fixtures["model"]["get_python"]["response"]["items"][0]
+ mock_model_engine_download = mocker.patch(
+ "hsml.engine.model_engine.ModelEngine.download"
+ )
+
+ # Act
+ m = model.Model.from_response_json(m_json)
+ m.download()
+
+ # Assert
+ mock_model_engine_download.assert_called_once_with(model_instance=m)
+
+ # tags
+
+ def test_get_tag(self, mocker, backend_fixtures):
+ # Arrange
+ m_json = backend_fixtures["model"]["get_python"]["response"]["items"][0]
+ mock_model_engine_get_tag = mocker.patch(
+ "hsml.engine.model_engine.ModelEngine.get_tag"
+ )
+
+ # Act
+ m = model.Model.from_response_json(m_json)
+ m.get_tag("tag_name")
+
+ # Assert
+ mock_model_engine_get_tag.assert_called_once_with(
+ model_instance=m, name="tag_name"
+ )
+
+ def test_get_tags(self, mocker, backend_fixtures):
+ # Arrange
+ m_json = backend_fixtures["model"]["get_python"]["response"]["items"][0]
+ mock_model_engine_get_tags = mocker.patch(
+ "hsml.engine.model_engine.ModelEngine.get_tags"
+ )
+
+ # Act
+ m = model.Model.from_response_json(m_json)
+ m.get_tags()
+
+ # Assert
+ mock_model_engine_get_tags.assert_called_once_with(model_instance=m)
+
+ def test_set_tag(self, mocker, backend_fixtures):
+ # Arrange
+ m_json = backend_fixtures["model"]["get_python"]["response"]["items"][0]
+ mock_model_engine_set_tag = mocker.patch(
+ "hsml.engine.model_engine.ModelEngine.set_tag"
+ )
+
+ # Act
+ m = model.Model.from_response_json(m_json)
+ m.set_tag("tag_name", "tag_value")
+
+ # Assert
+ mock_model_engine_set_tag.assert_called_once_with(
+ model_instance=m, name="tag_name", value="tag_value"
+ )
+
+ def test_delete_tag(self, mocker, backend_fixtures):
+ # Arrange
+ m_json = backend_fixtures["model"]["get_python"]["response"]["items"][0]
+ mock_model_engine_delete_tag = mocker.patch(
+ "hsml.engine.model_engine.ModelEngine.delete_tag"
+ )
+
+ # Act
+ m = model.Model.from_response_json(m_json)
+ m.delete_tag("tag_name")
+
+ # Assert
+ mock_model_engine_delete_tag.assert_called_once_with(
+ model_instance=m, name="tag_name"
+ )
+
+ # get url
+
+ def test_get_url(self, mocker, backend_fixtures):
+ # Arrange
+ m_json = backend_fixtures["model"]["get_python"]["response"]["items"][0]
+
+ class ClientMock:
+ _project_id = 1
+
+ mock_client_get_instance = mocker.patch(
+ "hsml.client.get_instance", return_value=ClientMock()
+ )
+ mock_util_get_hostname_replaced_url = mocker.patch(
+ "hsml.util.get_hostname_replaced_url", return_value="full_path"
+ )
+ path_arg = "/p/1/models/" + m_json["name"] + "/" + str(m_json["version"])
+
+ # Act
+ m = model.Model.from_response_json(m_json)
+ url = m.get_url()
+
+ # Assert
+ assert url == "full_path"
+ mock_client_get_instance.assert_called_once()
+ mock_util_get_hostname_replaced_url.assert_called_once_with(sub_path=path_arg)
+
+ # auxiliary methods
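+ # Checks every public Model attribute against the fixture JSON, including the lazily-read input example, model schema, program and environment files.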
+ def assert_model(self, mocker, m, m_json, model_framework):
+ assert isinstance(m, model.Model)
+ assert m.id == m_json["id"]
+ assert m.name == m_json["name"]
+ assert m.version == m_json["version"]
+ assert m.created == m_json["created"]
+ assert m.creator == m_json["creator"]
+ assert m.description == m_json["description"]
+ assert m.experiment_id == m_json["experiment_id"]
+ assert m.project_name == m_json["project_name"]
+ assert m.experiment_project_name == m_json["experiment_project_name"]
+ assert m.training_metrics == m_json["metrics"]
+ assert m._user_full_name == m_json["user_full_name"]
+ assert m.training_dataset == m_json["training_dataset"]
+ assert m.model_registry_id == m_json["model_registry_id"]
+
+ if model_framework is None:
+ assert m.framework is None
+ else:
+ assert m.framework == model_framework
+
+ mock_read_json = mocker.patch(
+ "hsml.engine.model_engine.ModelEngine.read_json",
+ return_value="input_example_content",
+ )
+ assert m.input_example == "input_example_content"
+ mock_read_json.assert_called_once_with(
+ model_instance=m, resource=m_json["input_example"]
+ )
+
+ mock_read_json = mocker.patch(
+ "hsml.engine.model_engine.ModelEngine.read_json",
+ return_value="model_schema_content",
+ )
+ assert m.model_schema == "model_schema_content"
+ mock_read_json.assert_called_once_with(
+ model_instance=m, resource=m_json["model_schema"]
+ )
+
+ mock_read_file = mocker.patch(
+ "hsml.engine.model_engine.ModelEngine.read_file",
+ return_value="program_file_content",
+ )
+ assert m.program == "program_file_content"
+ mock_read_file.assert_called_once_with(
+ model_instance=m, resource=m_json["program"]
+ )
+
+ mock_read_file = mocker.patch(
+ "hsml.engine.model_engine.ModelEngine.read_file",
+ return_value="env_file_content",
+ )
+ assert m.environment == "env_file_content"
+ mock_read_file.assert_called_once_with(
+ model_instance=m, resource=m_json["environment"]
+ )
+
+ def test_get_feature_view(self, mocker):
+ mock_fv = mocker.Mock()
+ links = explicit_provenance.Links(accessible=[mock_fv])
+ mock_fv_provenance = mocker.patch(
+ "hsml.model.Model.get_feature_view_provenance", return_value=links
+ )
+ mock_td_provenance = mocker.patch(
+ "hsml.model.Model.get_training_dataset_provenance", return_value=links
+ )
+ mocker.patch("os.environ", return_value={})
+ m = model.Model(1, "test")
+ m.get_feature_view()
+ mock_fv_provenance.assert_called_once()
+ mock_td_provenance.assert_called_once()
+ assert not mock_fv.init_serving.called
+ assert not mock_fv.init_batch_scoring.called
+
+ def test_get_feature_view_online(self, mocker):
+ mock_fv = mocker.Mock()
+ links = explicit_provenance.Links(accessible=[mock_fv])
+ mock_fv_provenance = mocker.patch(
+ "hsml.model.Model.get_feature_view_provenance", return_value=links
+ )
+ mock_td_provenance = mocker.patch(
+ "hsml.model.Model.get_training_dataset_provenance", return_value=links
+ )
+ mocker.patch("os.environ", return_value={})
+ m = model.Model(1, "test")
+ m.get_feature_view(online=True)
+ mock_fv_provenance.assert_called_once()
+ mock_td_provenance.assert_called_once()
+ assert mock_fv.init_serving.called
+ assert not mock_fv.init_batch_scoring.called
+
+ def test_get_feature_view_batch(self, mocker):
+ mock_fv = mocker.Mock()
+ links = explicit_provenance.Links(accessible=[mock_fv])
+ mock_fv_provenance = mocker.patch(
+ "hsml.model.Model.get_feature_view_provenance", return_value=links
+ )
+ mock_td_provenance = mocker.patch(
+ "hsml.model.Model.get_training_dataset_provenance", return_value=links
+ )
+ mocker.patch("os.environ", return_value={})
+ m = model.Model(1, "test")
+ m.get_feature_view(online=False)
+ mock_fv_provenance.assert_called_once()
+ mock_td_provenance.assert_called_once()
+ assert not mock_fv.init_serving.called
+ assert mock_fv.init_batch_scoring.called
+
+ def test_get_feature_view_deployment(self, mocker):
+ mock_fv = mocker.Mock()
+ links = explicit_provenance.Links(accessible=[mock_fv])
+ mock_fv_provenance = mocker.patch(
+ "hsml.model.Model.get_feature_view_provenance", return_value=links
+ )
+ mock_td_provenance = mocker.patch(
+ "hsml.model.Model.get_training_dataset_provenance", return_value=links
+ )
+ mocker.patch.dict(os.environ, {"DEPLOYMENT_NAME": "test"})
+ m = model.Model(1, "test")
+ m.get_feature_view()
+ mock_fv_provenance.assert_called_once()
+ mock_td_provenance.assert_called_once()
+ assert mock_fv.init_serving.called
+ assert not mock_fv.init_batch_scoring.called
diff --git a/hsml/python/tests/test_model_schema.py b/hsml/python/tests/test_model_schema.py
new file mode 100644
index 000000000..826975f9e
--- /dev/null
+++ b/hsml/python/tests/test_model_schema.py
@@ -0,0 +1,30 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+from hsml import model_schema
+
+
+class TestModelSchema:
+ # constructor
+
+ def test_constructor(self):
+ # Act
+ msch = model_schema.ModelSchema(input_schema="1234", output_schema="4321")
+
+ # Assert
+ assert msch.input_schema == "1234"
+ assert msch.output_schema == "4321"
diff --git a/hsml/python/tests/test_predictor.py b/hsml/python/tests/test_predictor.py
new file mode 100644
index 000000000..4cb29efdc
--- /dev/null
+++ b/hsml/python/tests/test_predictor.py
@@ -0,0 +1,703 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import copy
+
+import pytest
+from hsml import (
+ inference_batcher,
+ inference_logger,
+ predictor,
+ resources,
+ transformer,
+ util,
+)
+from hsml.constants import MODEL, PREDICTOR, RESOURCES
+
+
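+# Canned return values for the hsml.client functions patched by
+# _mock_serving_variables below: the serving resource limits and the
+# backend-reported instance-count limits ([-1] meaning no limit).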
+SERVING_RESOURCE_LIMITS = {"cores": 2, "memory": 1024, "gpus": 2}
+SERVING_NUM_INSTANCES_NO_LIMIT = [-1]
+SERVING_NUM_INSTANCES_SCALE_TO_ZERO = [0]
+SERVING_NUM_INSTANCES_ONE = [1]
+
+
+class TestPredictor:
+ # from response json
+
+ def test_from_response_json_empty(self, mocker, backend_fixtures):
+ # Arrange
+ self._mock_serving_variables(mocker, SERVING_NUM_INSTANCES_NO_LIMIT)
+ json = backend_fixtures["predictor"]["get_deployments_empty"]["response"]
+
+ # Act
+ pred = predictor.Predictor.from_response_json(json)
+
+ # Assert
+ assert isinstance(pred, list)
+ assert len(pred) == 0
+
+ def test_from_response_json_singleton(self, mocker, backend_fixtures):
+ # Arrange
+ self._mock_serving_variables(mocker, SERVING_NUM_INSTANCES_NO_LIMIT)
+ json = backend_fixtures["predictor"]["get_deployments_singleton"]["response"]
+
+ # Act
+ pred = predictor.Predictor.from_response_json(json)
+
+ # Assert
+ assert isinstance(pred, list)
+ assert len(pred) == 1
+
+ p = pred[0]
+ p_json = json["items"][0]
+
+ assert isinstance(p, predictor.Predictor)
+ assert p.id == p_json["id"]
+ assert p.name == p_json["name"]
+ assert p.description == p_json["description"]
+ assert p.created_at == p_json["created"]
+ assert p.creator == p_json["creator"]
+ assert p.model_path == p_json["model_path"]
+ assert p.model_name == p_json["model_name"]
+ assert p.model_version == p_json["model_version"]
+ assert p.model_framework == p_json["model_framework"]
+ assert p.model_server == p_json["model_server"]
+ assert p.serving_tool == p_json["serving_tool"]
+ assert p.api_protocol == p_json["api_protocol"]
+ assert p.artifact_version == p_json["artifact_version"]
+ assert p.script_file == p_json["predictor"]
+ assert isinstance(p.resources, resources.PredictorResources)
+ assert isinstance(p.transformer, transformer.Transformer)
+ assert p.transformer.script_file == p_json["transformer"]
+ assert isinstance(p.transformer.resources, resources.TransformerResources)
+ assert isinstance(p.inference_logger, inference_logger.InferenceLogger)
+ assert p.inference_logger.mode == p_json["inference_logging"]
+ assert isinstance(p.inference_batcher, inference_batcher.InferenceBatcher)
+ assert p.inference_batcher.enabled == bool(
+ p_json["batching_configuration"]["batching_enabled"]
+ )
+
+ def test_from_response_json_list(self, mocker, backend_fixtures):
+ # Arrange
+ self._mock_serving_variables(mocker, SERVING_NUM_INSTANCES_NO_LIMIT)
+ json = backend_fixtures["predictor"]["get_deployments_list"]["response"]
+
+ # Act
+ pred = predictor.Predictor.from_response_json(json)
+
+ # Assert
+ assert isinstance(pred, list)
+ assert len(pred) == 2
+
+ for i in range(len(pred)):
+ p = pred[i]
+ p_json = json["items"][i]
+
+ assert isinstance(p, predictor.Predictor)
+ assert p.id == p_json["id"]
+ assert p.name == p_json["name"]
+ assert p.description == p_json["description"]
+ assert p.created_at == p_json["created"]
+ assert p.creator == p_json["creator"]
+ assert p.model_path == p_json["model_path"]
+ assert p.model_name == p_json["model_name"]
+ assert p.model_version == p_json["model_version"]
+ assert p.model_framework == p_json["model_framework"]
+ assert p.model_server == p_json["model_server"]
+ assert p.serving_tool == p_json["serving_tool"]
+ assert p.api_protocol == p_json["api_protocol"]
+ assert p.artifact_version == p_json["artifact_version"]
+ assert p.script_file == p_json["predictor"]
+ assert isinstance(p.resources, resources.PredictorResources)
+ assert isinstance(p.transformer, transformer.Transformer)
+ assert p.transformer.script_file == p_json["transformer"]
+ assert isinstance(p.transformer.resources, resources.TransformerResources)
+ assert isinstance(p.inference_logger, inference_logger.InferenceLogger)
+ assert p.inference_logger.mode == p_json["inference_logging"]
+ assert isinstance(p.inference_batcher, inference_batcher.InferenceBatcher)
+ assert p.inference_batcher.enabled == bool(
+ p_json["batching_configuration"]["batching_enabled"]
+ )
+
+ def test_from_response_json_single(self, mocker, backend_fixtures):
+ # Arrange
+ self._mock_serving_variables(mocker, SERVING_NUM_INSTANCES_NO_LIMIT)
+ p_json = backend_fixtures["predictor"]["get_deployments_singleton"]["response"][
+ "items"
+ ][0]
+
+ # Act
+ p = predictor.Predictor.from_response_json(p_json)
+
+ # Assert
+ assert isinstance(p, predictor.Predictor)
+ assert p.id == p_json["id"]
+ assert p.name == p_json["name"]
+ assert p.description == p_json["description"]
+ assert p.created_at == p_json["created"]
+ assert p.creator == p_json["creator"]
+ assert p.model_path == p_json["model_path"]
+ assert p.model_version == p_json["model_version"]
+ assert p.model_name == p_json["model_name"]
+ assert p.model_framework == p_json["model_framework"]
+ assert p.model_server == p_json["model_server"]
+ assert p.serving_tool == p_json["serving_tool"]
+ assert p.api_protocol == p_json["api_protocol"]
+ assert p.artifact_version == p_json["artifact_version"]
+ assert p.script_file == p_json["predictor"]
+ assert isinstance(p.resources, resources.PredictorResources)
+ assert isinstance(p.transformer, transformer.Transformer)
+ assert p.transformer.script_file == p_json["transformer"]
+ assert isinstance(p.transformer.resources, resources.TransformerResources)
+ assert isinstance(p.inference_logger, inference_logger.InferenceLogger)
+ assert p.inference_logger.mode == p_json["inference_logging"]
+ assert isinstance(p.inference_batcher, inference_batcher.InferenceBatcher)
+ assert p.inference_batcher.enabled == bool(
+ p_json["batching_configuration"]["batching_enabled"]
+ )
+
+ # constructor
+
+ def test_constructor(self, mocker, backend_fixtures):
+ # Arrange
+ self._mock_serving_variables(mocker, SERVING_NUM_INSTANCES_NO_LIMIT)
+ p_json = backend_fixtures["predictor"]["get_deployments_singleton"]["response"][
+ "items"
+ ][0]
+ mock_validate_serving_tool = mocker.patch(
+ "hsml.predictor.Predictor._validate_serving_tool",
+ return_value=p_json["serving_tool"],
+ )
+ mock_resources = util.get_obj_from_json(
+ copy.deepcopy(p_json["predictor_resources"]), resources.PredictorResources
+ )
+ mock_validate_resources = mocker.patch(
+ "hsml.predictor.Predictor._validate_resources",
+ return_value=mock_resources,
+ )
+ mock_validate_script_file = mocker.patch(
+ "hsml.predictor.Predictor._validate_script_file",
+ return_value=p_json["predictor"],
+ )
+
+ # Act
+ p = predictor.Predictor(
+ id=p_json["id"],
+ name=p_json["name"],
+ description=p_json["description"],
+ created_at=p_json["created"],
+ creator=p_json["creator"],
+ model_path=p_json["model_path"],
+ model_version=p_json["model_version"],
+ model_name=p_json["model_name"],
+ model_framework=p_json["model_framework"],
+ model_server=p_json["model_server"],
+ serving_tool=p_json["serving_tool"],
+ api_protocol=p_json["api_protocol"],
+ artifact_version=p_json["artifact_version"],
+ script_file=p_json["predictor"],
+ resources=p_json["predictor_resources"],
+ transformer={
+ "script_file": p_json["transformer"],
+ "resources": copy.deepcopy(p_json["transformer_resources"]),
+ },
+ inference_logger={
+ "mode": p_json["inference_logging"],
+ "kafka_topic": copy.deepcopy(p_json["kafka_topic_dto"]),
+ },
+ inference_batcher=copy.deepcopy(p_json["batching_configuration"]),
+ )
+
+ # Assert
+ assert p.id == p_json["id"]
+ assert p.name == p_json["name"]
+ assert p.description == p_json["description"]
+ assert p.created_at == p_json["created"]
+ assert p.creator == p_json["creator"]
+ assert p.model_path == p_json["model_path"]
+ assert p.model_name == p_json["model_name"]
+ assert p.model_version == p_json["model_version"]
+ assert p.model_framework == p_json["model_framework"]
+ assert p.model_server == p_json["model_server"]
+ assert p.serving_tool == p_json["serving_tool"]
+ assert p.api_protocol == p_json["api_protocol"]
+ assert p.artifact_version == p_json["artifact_version"]
+ assert p.script_file == p_json["predictor"]
+ assert isinstance(p.resources, resources.PredictorResources)
+ assert isinstance(p.transformer, transformer.Transformer)
+ assert p.transformer.script_file == p_json["transformer"]
+ assert isinstance(p.transformer.resources, resources.TransformerResources)
+ assert isinstance(p.inference_logger, inference_logger.InferenceLogger)
+ assert p.inference_logger.mode == p_json["inference_logging"]
+ assert isinstance(p.inference_batcher, inference_batcher.InferenceBatcher)
+ assert p.inference_batcher.enabled == bool(
+ p_json["batching_configuration"]["batching_enabled"]
+ )
+ mock_validate_serving_tool.assert_called_once_with(p_json["serving_tool"])
+ assert mock_validate_resources.call_count == 1
+ mock_validate_script_file.assert_called_once_with(
+ p_json["model_framework"], p_json["predictor"]
+ )
+
+ # validate serving tool
+
+ def test_validate_serving_tool_none(self):
+ # Act
+ st = predictor.Predictor._validate_serving_tool(None)
+
+ # Assert
+ assert st is None
+
+ def test_validate_serving_tool_valid(self, mocker):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, is_saas_connection=False
+ )
+
+ # Act
+ st = predictor.Predictor._validate_serving_tool(PREDICTOR.SERVING_TOOL_DEFAULT)
+
+ # Assert
+ assert st == PREDICTOR.SERVING_TOOL_DEFAULT
+
+ def test_validate_serving_tool_invalid(self, mocker):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, is_saas_connection=False
+ )
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ _ = predictor.Predictor._validate_serving_tool("INVALID_NAME")
+
+ # Assert
+ assert "is not valid" in str(e_info.value)
+
+ def test_validate_serving_tool_valid_saas_connection(self, mocker):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, is_saas_connection=True
+ )
+
+ # Act
+ st = predictor.Predictor._validate_serving_tool(PREDICTOR.SERVING_TOOL_KSERVE)
+
+ # Assert
+ assert st == PREDICTOR.SERVING_TOOL_KSERVE
+
+ def test_validate_serving_tool_invalid_saas_connection(self, mocker):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, is_saas_connection=True
+ )
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ _ = predictor.Predictor._validate_serving_tool(
+ PREDICTOR.SERVING_TOOL_DEFAULT
+ )
+
+ # Assert
+ assert "KServe deployments are the only supported" in str(e_info.value)
+
+ # validate script file
+
+ def test_validate_script_file_tf_none(self):
+ # Act
+ predictor.Predictor._validate_script_file(MODEL.FRAMEWORK_TENSORFLOW, None)
+
+ def test_validate_script_file_sk_none(self):
+ # Act
+ predictor.Predictor._validate_script_file(MODEL.FRAMEWORK_SKLEARN, None)
+
+ def test_validate_script_file_th_none(self):
+ # Act
+ predictor.Predictor._validate_script_file(MODEL.FRAMEWORK_TORCH, None)
+
+ def test_validate_script_file_py_none(self):
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ _ = predictor.Predictor._validate_script_file(MODEL.FRAMEWORK_PYTHON, None)
+
+ # Assert
+ assert "Predictor scripts are required" in str(e_info.value)
+
+ def test_validate_script_file_tf_script_file(self):
+ # Act
+ predictor.Predictor._validate_script_file(
+ MODEL.FRAMEWORK_TENSORFLOW, "script_file"
+ )
+
+ def test_validate_script_file_sk_script_file(self):
+ # Act
+ predictor.Predictor._validate_script_file(
+ MODEL.FRAMEWORK_SKLEARN, "script_file"
+ )
+
+ def test_validate_script_file_th_script_file(self):
+ # Act
+ predictor.Predictor._validate_script_file(MODEL.FRAMEWORK_TORCH, "script_file")
+
+ def test_validate_script_file_py_script_file(self):
+ # Act
+ predictor.Predictor._validate_script_file(MODEL.FRAMEWORK_PYTHON, "script_file")
+
+ # infer model server
+
+ def test_infer_model_server_tf(self):
+ # Act
+ ms = predictor.Predictor._infer_model_server(MODEL.FRAMEWORK_TENSORFLOW)
+
+ # Assert
+ assert ms == PREDICTOR.MODEL_SERVER_TF_SERVING
+
+ def test_infer_model_server_sk(self):
+ # Act
+ ms = predictor.Predictor._infer_model_server(MODEL.FRAMEWORK_SKLEARN)
+
+ # Assert
+ assert ms == PREDICTOR.MODEL_SERVER_PYTHON
+
+ def test_infer_model_server_th(self):
+ # Act
+ ms = predictor.Predictor._infer_model_server(MODEL.FRAMEWORK_TORCH)
+
+ # Assert
+ assert ms == PREDICTOR.MODEL_SERVER_PYTHON
+
+ def test_infer_model_server_py(self):
+ # Act
+ ms = predictor.Predictor._infer_model_server(MODEL.FRAMEWORK_PYTHON)
+
+ # Assert
+ assert ms == PREDICTOR.MODEL_SERVER_PYTHON
+
+ # default serving tool
+
+ def test_get_default_serving_tool_kserve_installed(self, mocker):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, is_kserve_installed=True
+ )
+
+ # Act
+ st = predictor.Predictor._get_default_serving_tool()
+
+ # Assert
+ assert st == PREDICTOR.SERVING_TOOL_KSERVE
+
+ def test_get_default_serving_tool_kserve_not_installed(self, mocker):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, is_kserve_installed=False
+ )
+
+ # Act
+ st = predictor.Predictor._get_default_serving_tool()
+
+ # Assert
+ assert st == PREDICTOR.SERVING_TOOL_DEFAULT
+
+ # validate resources
+
+ def test_validate_resources_none_non_kserve(self):
+ # Act
+ res = predictor.Predictor._validate_resources(
+ None, PREDICTOR.SERVING_TOOL_DEFAULT
+ )
+
+ # Assert
+ assert res is None
+
+ def test_validate_resources_none_kserve(self):
+ # Act
+ res = predictor.Predictor._validate_resources(
+ None, PREDICTOR.SERVING_TOOL_KSERVE
+ )
+
+ # Assert
+ assert res is None
+
+ def test_validate_resources_num_instances_zero_non_kserve(self, mocker):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, force_scale_to_zero=False
+ )
+ pr = resources.PredictorResources(num_instances=0)
+
+ # Act
+ res = predictor.Predictor._validate_resources(
+ pr, PREDICTOR.SERVING_TOOL_DEFAULT
+ )
+
+ # Assert
+ assert res == pr
+
+ def test_validate_resources_num_instances_zero_kserve(self, mocker):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, force_scale_to_zero=False
+ )
+ pr = resources.PredictorResources(num_instances=0)
+
+ # Act
+ res = predictor.Predictor._validate_resources(pr, PREDICTOR.SERVING_TOOL_KSERVE)
+
+ # Assert
+ assert res == pr
+
+ def test_validate_resources_num_instances_one_without_scale_to_zero_non_kserve(
+ self, mocker
+ ):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, force_scale_to_zero=False
+ )
+ pr = resources.PredictorResources(num_instances=1)
+
+ # Act
+ res = predictor.Predictor._validate_resources(
+ pr, PREDICTOR.SERVING_TOOL_DEFAULT
+ )
+
+ # Assert
+ assert res == pr
+
+ def test_validate_resources_num_instances_one_without_scale_to_zero_kserve(
+ self, mocker
+ ):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, force_scale_to_zero=False
+ )
+ pr = resources.PredictorResources(num_instances=1)
+
+ # Act
+ res = predictor.Predictor._validate_resources(pr, PREDICTOR.SERVING_TOOL_KSERVE)
+
+ # Assert
+ assert res == pr
+
+ def test_validate_resources_num_instances_one_with_scale_to_zero_non_kserve(
+ self, mocker
+ ):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, force_scale_to_zero=True
+ )
+ pr = resources.PredictorResources(num_instances=1)
+
+ # Act
+ res = predictor.Predictor._validate_resources(
+ pr, PREDICTOR.SERVING_TOOL_DEFAULT
+ )
+
+ # Assert
+ assert res == pr
+
+ def test_validate_resources_num_instances_one_with_scale_to_zero_kserve(
+ self, mocker
+ ):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, force_scale_to_zero=True
+ )
+ pr = resources.PredictorResources(num_instances=1)
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ _ = predictor.Predictor._validate_resources(
+ pr, PREDICTOR.SERVING_TOOL_KSERVE
+ )
+
+ # Assert
+ assert "Scale-to-zero is required" in str(e_info.value)
+
+ # default resources
+
+ def test_get_default_resources_non_kserve_without_scale_to_zero(self, mocker):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, force_scale_to_zero=False
+ )
+
+ # Act
+ res = predictor.Predictor._get_default_resources(PREDICTOR.SERVING_TOOL_DEFAULT)
+
+ # Assert
+ assert isinstance(res, resources.PredictorResources)
+ assert res.num_instances == RESOURCES.MIN_NUM_INSTANCES
+
+ def test_get_default_resources_non_kserve_with_scale_to_zero(self, mocker):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, force_scale_to_zero=True
+ )
+
+ # Act
+ res = predictor.Predictor._get_default_resources(PREDICTOR.SERVING_TOOL_DEFAULT)
+
+ # Assert
+ assert isinstance(res, resources.PredictorResources)
+ assert res.num_instances == RESOURCES.MIN_NUM_INSTANCES
+
+ def test_get_default_resources_kserve_without_scale_to_zero(self, mocker):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, force_scale_to_zero=False
+ )
+
+ # Act
+ res = predictor.Predictor._get_default_resources(PREDICTOR.SERVING_TOOL_KSERVE)
+
+ # Assert
+ assert isinstance(res, resources.PredictorResources)
+ assert res.num_instances == RESOURCES.MIN_NUM_INSTANCES
+
+ def test_get_default_resources_kserve_with_scale_to_zero(self, mocker):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, force_scale_to_zero=True
+ )
+
+ # Act
+ res = predictor.Predictor._get_default_resources(PREDICTOR.SERVING_TOOL_KSERVE)
+
+ # Assert
+ assert isinstance(res, resources.PredictorResources)
+ assert res.num_instances == 0
+
+ # for model
+
+ def test_for_model(self, mocker):
+ # Arrange
+ def spec(model, model_name, model_version, model_path):
+ pass
+
+ mock_get_predictor_for_model = mocker.patch(
+ "hsml.util.get_predictor_for_model", return_value=True, spec=spec
+ )
+
+ class MockModel:
+ name = "model_name"
+ version = "model_version"
+ model_path = "model_path"
+
+ mock_model = MockModel()
+
+ # Act
+ predictor.Predictor.for_model(mock_model)
+
+ # Assert
+ mock_get_predictor_for_model.assert_called_once_with(
+ model=mock_model,
+ model_name=mock_model.name,
+ model_version=mock_model.version,
+ model_path=mock_model.model_path,
+ )
+
+ # extract fields from json
+
+ def test_extract_fields_from_json(self, mocker, backend_fixtures):
+ # Arrange
+ self._mock_serving_variables(mocker, SERVING_NUM_INSTANCES_NO_LIMIT)
+ p_json = backend_fixtures["predictor"]["get_deployments_singleton"]["response"][
+ "items"
+ ][0]
+
+ # Act
+ kwargs = predictor.Predictor.extract_fields_from_json(p_json)
+
+ # Assert
+ assert kwargs["id"] == p_json["id"]
+ assert kwargs["name"] == p_json["name"]
+ assert kwargs["description"] == p_json["description"]
+ assert kwargs["created_at"] == p_json["created"]
+ assert kwargs["creator"] == p_json["creator"]
+ assert kwargs["model_name"] == p_json["model_name"]
+ assert kwargs["model_path"] == p_json["model_path"]
+ assert kwargs["model_version"] == p_json["model_version"]
+ assert kwargs["model_framework"] == p_json["model_framework"]
+ assert kwargs["artifact_version"] == p_json["artifact_version"]
+ assert kwargs["model_server"] == p_json["model_server"]
+ assert kwargs["serving_tool"] == p_json["serving_tool"]
+ assert kwargs["script_file"] == p_json["predictor"]
+ assert isinstance(kwargs["resources"], resources.PredictorResources)
+ assert isinstance(kwargs["inference_logger"], inference_logger.InferenceLogger)
+ assert kwargs["inference_logger"].mode == p_json["inference_logging"]
+ assert isinstance(
+ kwargs["inference_batcher"], inference_batcher.InferenceBatcher
+ )
+ assert kwargs["inference_batcher"].enabled == bool(
+ p_json["batching_configuration"]["batching_enabled"]
+ )
+ assert kwargs["api_protocol"] == p_json["api_protocol"]
+ assert isinstance(kwargs["transformer"], transformer.Transformer)
+ assert kwargs["transformer"].script_file == p_json["transformer"]
+ assert isinstance(
+ kwargs["transformer"].resources, resources.TransformerResources
+ )
+
+ # deploy
+
+ def test_deploy(self, mocker, backend_fixtures):
+ # Arrange
+ self._mock_serving_variables(mocker, SERVING_NUM_INSTANCES_NO_LIMIT)
+ p_json = backend_fixtures["predictor"]["get_deployments_singleton"]["response"][
+ "items"
+ ][0]
+ mock_deployment_init = mocker.patch(
+ "hsml.deployment.Deployment.__init__", return_value=None
+ )
+ mock_deployment_save = mocker.patch("hsml.deployment.Deployment.save")
+
+ # Act
+ p = predictor.Predictor.from_response_json(p_json)
+ p.deploy()
+
+ # Assert
+ mock_deployment_init.assert_called_once_with(
+ predictor=p,
+ name=p.name,
+ description=p.description,
+ )
+ mock_deployment_save.assert_called_once()
+
+ # auxiliary methods
+
+ def _mock_serving_variables(
+ self,
+ mocker,
+ num_instances,
+ force_scale_to_zero=False,
+ is_saas_connection=False,
+ is_kserve_installed=True,
+ ):
+ mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=SERVING_RESOURCE_LIMITS,
+ )
+ mocker.patch(
+ "hsml.client.get_serving_num_instances_limits", return_value=num_instances
+ )
+ mocker.patch(
+ "hsml.client.is_scale_to_zero_required", return_value=force_scale_to_zero
+ )
+ mocker.patch("hsml.client.is_saas_connection", return_value=is_saas_connection)
+ mocker.patch(
+ "hsml.client.is_kserve_installed", return_value=is_kserve_installed
+ )
diff --git a/hsml/python/tests/test_predictor_state.py b/hsml/python/tests/test_predictor_state.py
new file mode 100644
index 000000000..c9feabcc5
--- /dev/null
+++ b/hsml/python/tests/test_predictor_state.py
@@ -0,0 +1,126 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import copy
+
+import humps
+from hsml import predictor_state, predictor_state_condition
+
+
+class TestPredictorState:
+ # from response json
+
+ def test_from_response_json(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["predictor"]["get_deployment_predictor_state"][
+ "response"
+ ]
+ json_camelized = humps.camelize(json) # as returned by the backend
+
+ # Act
+ ps = predictor_state.PredictorState.from_response_json(json_camelized)
+
+ # Assert
+ assert isinstance(ps, predictor_state.PredictorState)
+ assert ps.available_predictor_instances == json["available_instances"]
+ assert (
+ ps.available_transformer_instances
+ == json["available_transformer_instances"]
+ )
+ assert ps.hopsworks_inference_path == json["hopsworks_inference_path"]
+ assert ps.model_server_inference_path == json["model_server_inference_path"]
+ assert ps.internal_port == json["internal_port"]
+ assert ps.revision == json["revision"]
+ assert ps.deployed == json["deployed"]
+ assert isinstance(
+ ps.condition, predictor_state_condition.PredictorStateCondition
+ )
+ assert ps.condition.status == json["condition"]["status"]
+ assert ps.status == json["status"]
+
+ # constructor
+
+ def test_constructor(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["predictor"]["get_deployment_predictor_state"][
+ "response"
+ ]
+
+ # Act
+ ps = predictor_state.PredictorState(
+ available_predictor_instances=json["available_instances"],
+ available_transformer_instances=json["available_transformer_instances"],
+ hopsworks_inference_path=json["hopsworks_inference_path"],
+ model_server_inference_path=json["model_server_inference_path"],
+ internal_port=json["internal_port"],
+ revision=json["revision"],
+ deployed=json["deployed"],
+ condition=predictor_state_condition.PredictorStateCondition(
+ **copy.deepcopy(json["condition"])
+ ),
+ status=json["status"],
+ )
+
+ # Assert
+ assert isinstance(ps, predictor_state.PredictorState)
+ assert ps.available_predictor_instances == json["available_instances"]
+ assert (
+ ps.available_transformer_instances
+ == json["available_transformer_instances"]
+ )
+ assert ps.hopsworks_inference_path == json["hopsworks_inference_path"]
+ assert ps.model_server_inference_path == json["model_server_inference_path"]
+ assert ps.internal_port == json["internal_port"]
+ assert ps.revision == json["revision"]
+ assert ps.deployed == json["deployed"]
+ assert isinstance(
+ ps.condition, predictor_state_condition.PredictorStateCondition
+ )
+ assert ps.condition.status == json["condition"]["status"]
+ assert ps.status == json["status"]
+
+ # extract fields from json
+
+ def test_extract_fields_from_json(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["predictor"]["get_deployment_predictor_state"][
+ "response"
+ ]
+
+ # Act
+ (
+ ai,
+ ati,
+ hip,
+ msip,
+ ipt,
+ r,
+ d,
+ c,
+ s,
+ ) = predictor_state.PredictorState.extract_fields_from_json(copy.deepcopy(json))
+
+ # Assert
+ assert ai == json["available_instances"]
+ assert ati == json["available_transformer_instances"]
+ assert hip == json["hopsworks_inference_path"]
+ assert msip == json["model_server_inference_path"]
+ assert ipt == json["internal_port"]
+ assert r == json["revision"]
+ assert d == json["deployed"]
+ assert isinstance(c, predictor_state_condition.PredictorStateCondition)
+ assert c.status == json["condition"]["status"]
+ assert s == json["status"]
diff --git a/hsml/python/tests/test_predictor_state_condition.py b/hsml/python/tests/test_predictor_state_condition.py
new file mode 100644
index 000000000..5b0387f97
--- /dev/null
+++ b/hsml/python/tests/test_predictor_state_condition.py
@@ -0,0 +1,81 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import copy
+
+import humps
+from hsml import predictor_state_condition
+
+
+class TestPredictorStateCondition:
+ # from response json
+
+ def test_from_response_json(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["predictor"]["get_deployment_predictor_state"][
+ "response"
+ ]["condition"]
+ json_camelized = humps.camelize(json) # as returned by the backend
+
+ # Act
+ psc = predictor_state_condition.PredictorStateCondition.from_response_json(
+ json_camelized
+ )
+
+ # Assert
+ assert isinstance(psc, predictor_state_condition.PredictorStateCondition)
+ assert psc.type == json["type"]
+ assert psc.status == json["status"]
+ assert psc.reason == json["reason"]
+
+ # constructor
+
+ def test_constructor(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["predictor"]["get_deployment_predictor_state"][
+ "response"
+ ]["condition"]
+
+ # Act
+ psc = predictor_state_condition.PredictorStateCondition(
+ type=json["type"], status=json["status"], reason=json["reason"]
+ )
+
+ # Assert
+ assert isinstance(psc, predictor_state_condition.PredictorStateCondition)
+ assert psc.type == json["type"]
+ assert psc.status == json["status"]
+ assert psc.reason == json["reason"]
+
+ # extract fields from json
+
+ def test_extract_fields_from_json(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["predictor"]["get_deployment_predictor_state"][
+ "response"
+ ]["condition"]
+
+ # Act
+ kwargs = (
+ predictor_state_condition.PredictorStateCondition.extract_fields_from_json(
+ copy.deepcopy(json)
+ )
+ )
+
+ # Assert
+ assert kwargs["type"] == json["type"]
+ assert kwargs["status"] == json["status"]
+ assert kwargs["reason"] == json["reason"]
diff --git a/hsml/python/tests/test_resources.py b/hsml/python/tests/test_resources.py
new file mode 100644
index 000000000..f77863b38
--- /dev/null
+++ b/hsml/python/tests/test_resources.py
@@ -0,0 +1,928 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import copy
+
+import pytest
+from hsml import resources
+from hsml.constants import RESOURCES
+from mock import call
+
+
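+# Hard serving limits returned by the mocked hsml.client.get_serving_resource_limits
+# in the PredictorResources and TransformerResources tests below.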
+SERVING_RESOURCE_LIMITS = {"cores": 2, "memory": 516, "gpus": 2}
+
+
+class TestResources:
+ # Resources
+
+ def test_from_response_json_cpus(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["resources"]["get_only_cores"]["response"]
+
+ # Act
+ r = resources.Resources.from_response_json(json)
+
+ # Assert
+ assert r.cores == json["cores"]
+ assert r.memory is None
+ assert r.gpus is None
+
+ def test_from_response_json_memory(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["resources"]["get_only_memory"]["response"]
+
+ # Act
+ r = resources.Resources.from_response_json(json)
+
+ # Assert
+ assert r.cores is None
+ assert r.memory == json["memory"]
+ assert r.gpus is None
+
+ def test_from_response_json_gpus(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["resources"]["get_only_gpus"]["response"]
+
+ # Act
+ r = resources.Resources.from_response_json(json)
+
+ # Assert
+ assert r.cores is None
+ assert r.memory is None
+ assert r.gpus == json["gpus"]
+
+ def test_from_response_json_cores_and_memory(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["resources"]["get_cores_and_memory"]["response"]
+
+ # Act
+ r = resources.Resources.from_response_json(json)
+
+ # Assert
+ assert r.cores == json["cores"]
+ assert r.memory == json["memory"]
+ assert r.gpus is None
+
+ def test_from_response_json_cores_and_gpus(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["resources"]["get_cores_and_gpus"]["response"]
+
+ # Act
+ r = resources.Resources.from_response_json(json)
+
+ # Assert
+ assert r.cores == json["cores"]
+ assert r.memory is None
+ assert r.gpus == json["gpus"]
+
+ def test_from_response_json_memory_and_gpus(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["resources"]["get_memory_and_gpus"]["response"]
+
+ # Act
+ r = resources.Resources.from_response_json(json)
+
+ # Assert
+ assert r.cores is None
+ assert r.memory == json["memory"]
+ assert r.gpus == json["gpus"]
+
+ def test_from_response_json_cores_memory_and_gpus(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["resources"]["get_cores_memory_and_gpus"]["response"]
+
+ # Act
+ r = resources.Resources.from_response_json(json)
+
+ # Assert
+ assert r.cores == json["cores"]
+ assert r.memory == json["memory"]
+ assert r.gpus == json["gpus"]
+
+ # ComponentResources
+
+ # - from response json
+
+ def test_from_response_json_component_resources(self, mocker):
+ # Arrange
+ res = {"something": "here"}
+ json_decamelized = {"key": "value"}
+ mock_humps_decamelize = mocker.patch(
+ "humps.decamelize", return_value=json_decamelized
+ )
+ mock_from_json = mocker.patch(
+ "hsml.resources.ComponentResources.from_json",
+ return_value="from_json_result",
+ )
+
+ # Act
+ result = resources.ComponentResources.from_response_json(res)
+
+ # Assert
+ assert result == "from_json_result"
+ mock_humps_decamelize.assert_called_once_with(res)
+ mock_from_json.assert_called_once_with(json_decamelized)
+
+ # - constructor
+
+ def test_constructor_component_resources_default(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["resources"][
+ "get_component_resources_num_instances_requests_and_limits"
+ ]["response"]
+ mock_default_resource_limits = mocker.patch(
+ "hsml.resources.ComponentResources._get_default_resource_limits",
+ return_value=(0, 1, 2),
+ )
+ mock_fill_missing_resources = mocker.patch(
+ "hsml.resources.ComponentResources._fill_missing_resources"
+ )
+ mock_validate_resources = mocker.patch(
+ "hsml.resources.ComponentResources._validate_resources"
+ )
+ mock_resources_init = mocker.patch(
+ "hsml.resources.Resources.__init__", return_value=None
+ )
+
+ # Act
+ pr = resources.PredictorResources(num_instances=json["num_instances"])
+
+ # Assert
+ assert pr.num_instances == json["num_instances"]
+ assert mock_default_resource_limits.call_count == 2
+ assert mock_fill_missing_resources.call_count == 2
+ assert (
+ mock_fill_missing_resources.call_args_list[0][0][1] == RESOURCES.MIN_CORES
+ )
+ assert (
+ mock_fill_missing_resources.call_args_list[0][0][2] == RESOURCES.MIN_MEMORY
+ )
+ assert mock_fill_missing_resources.call_args_list[0][0][3] == RESOURCES.MIN_GPUS
+ assert mock_fill_missing_resources.call_args_list[1][0][1] == 0
+ assert mock_fill_missing_resources.call_args_list[1][0][2] == 1
+ assert mock_fill_missing_resources.call_args_list[1][0][3] == 2
+ mock_validate_resources.assert_called_once_with(pr._requests, pr._limits)
+ expected_calls = [
+ call(RESOURCES.MIN_CORES, RESOURCES.MIN_MEMORY, RESOURCES.MIN_GPUS),
+ call(0, 1, 2),
+ ]
+ mock_resources_init.assert_has_calls(expected_calls)
+
+ def test_constructor_component_resources(self, mocker, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["resources"][
+ "get_component_resources_num_instances_requests_and_limits"
+ ]["response"]
+ mock_util_get_obj_from_json = mocker.patch(
+ "hsml.util.get_obj_from_json",
+ side_effect=[json["requests"], json["limits"]],
+ )
+ mock_default_resource_limits = mocker.patch(
+ "hsml.resources.ComponentResources._get_default_resource_limits",
+ return_value=(0, 1, 2),
+ )
+ mock_fill_missing_resources = mocker.patch(
+ "hsml.resources.ComponentResources._fill_missing_resources"
+ )
+ mock_validate_resources = mocker.patch(
+ "hsml.resources.ComponentResources._validate_resources"
+ )
+
+ # Act
+ pr = resources.PredictorResources(
+ num_instances=json["num_instances"],
+ requests=json["requests"],
+ limits=json["limits"],
+ )
+
+ # Assert
+ assert pr.num_instances == json["num_instances"]
+ assert pr.requests == json["requests"]
+ assert pr.limits == json["limits"]
+ mock_default_resource_limits.assert_called_once()
+ assert mock_fill_missing_resources.call_count == 2
+ assert (
+ mock_fill_missing_resources.call_args_list[0][0][1] == RESOURCES.MIN_CORES
+ )
+ assert (
+ mock_fill_missing_resources.call_args_list[0][0][2] == RESOURCES.MIN_MEMORY
+ )
+ assert mock_fill_missing_resources.call_args_list[0][0][3] == RESOURCES.MIN_GPUS
+ assert mock_fill_missing_resources.call_args_list[1][0][1] == 0
+ assert mock_fill_missing_resources.call_args_list[1][0][2] == 1
+ assert mock_fill_missing_resources.call_args_list[1][0][3] == 2
+ mock_validate_resources.assert_called_once_with(pr._requests, pr._limits)
+ assert mock_util_get_obj_from_json.call_count == 2
+ expected_calls = [
+ call(json["requests"], resources.Resources),
+ call(json["limits"], resources.Resources),
+ ]
+ mock_util_get_obj_from_json.assert_has_calls(expected_calls)
+
+ # - extract fields from json
+
+ def test_extract_fields_from_json_component_resources_with_key(
+ self, backend_fixtures
+ ):
+ # Arrange
+ json = backend_fixtures["resources"][
+ "get_component_resources_requested_instances_and_predictor_resources"
+ ]["response"]
+ copy_json = copy.deepcopy(json)
+ resources.ComponentResources.RESOURCES_CONFIG_KEY = "predictor_resources"
+ resources.ComponentResources.NUM_INSTANCES_KEY = "requested_instances"
+
+ # Act
+ kwargs = resources.ComponentResources.extract_fields_from_json(copy_json)
+
+ # Assert
+ assert kwargs["num_instances"] == json["requested_instances"]
+ assert isinstance(kwargs["requests"], resources.Resources)
+ assert (
+ kwargs["requests"].cores == json["predictor_resources"]["requests"]["cores"]
+ )
+ assert (
+ kwargs["requests"].memory
+ == json["predictor_resources"]["requests"]["memory"]
+ )
+ assert (
+ kwargs["requests"].gpus == json["predictor_resources"]["requests"]["gpus"]
+ )
+ assert isinstance(kwargs["limits"], resources.Resources)
+ assert kwargs["limits"].cores == json["predictor_resources"]["limits"]["cores"]
+ assert (
+ kwargs["limits"].memory == json["predictor_resources"]["limits"]["memory"]
+ )
+ assert kwargs["limits"].gpus == json["predictor_resources"]["limits"]["gpus"]
+
+ def test_extract_fields_from_json_component_resources(
+ self, mocker, backend_fixtures
+ ):
+ # Arrange
+ json = backend_fixtures["resources"][
+ "get_component_resources_requested_instances_and_predictor_resources_alternative"
+ ]["response"]
+ copy_json = copy.deepcopy(json)
+ resources.ComponentResources.RESOURCES_CONFIG_KEY = "predictor_resources"
+ resources.ComponentResources.NUM_INSTANCES_KEY = "requested_instances"
+
+ # Act
+ kwargs = resources.ComponentResources.extract_fields_from_json(copy_json)
+
+ # Assert
+ assert kwargs["num_instances"] == json["num_instances"]
+ assert isinstance(kwargs["requests"], resources.Resources)
+ assert kwargs["requests"].cores == json["resources"]["requests"]["cores"]
+ assert kwargs["requests"].memory == json["resources"]["requests"]["memory"]
+ assert kwargs["requests"].gpus == json["resources"]["requests"]["gpus"]
+ assert isinstance(kwargs["limits"], resources.Resources)
+ assert kwargs["limits"].cores == json["resources"]["limits"]["cores"]
+ assert kwargs["limits"].memory == json["resources"]["limits"]["memory"]
+ assert kwargs["limits"].gpus == json["resources"]["limits"]["gpus"]
+
+ def test_extract_fields_from_json_component_resources_flatten(
+ self, backend_fixtures
+ ):
+ # Arrange
+ json = backend_fixtures["resources"][
+ "get_component_resources_num_instances_requests_and_limits"
+ ]["response"]
+ copy_json = copy.deepcopy(json)
+ resources.ComponentResources.RESOURCES_CONFIG_KEY = "predictor_resources"
+ resources.ComponentResources.NUM_INSTANCES_KEY = "requested_instances"
+
+ # Act
+ kwargs = resources.ComponentResources.extract_fields_from_json(copy_json)
+
+ # Assert
+ assert kwargs["num_instances"] == json["num_instances"]
+ assert isinstance(kwargs["requests"], resources.Resources)
+ assert kwargs["requests"].cores == json["requests"]["cores"]
+ assert kwargs["requests"].memory == json["requests"]["memory"]
+ assert kwargs["requests"].gpus == json["requests"]["gpus"]
+ assert isinstance(kwargs["limits"], resources.Resources)
+ assert kwargs["limits"].cores == json["limits"]["cores"]
+ assert kwargs["limits"].memory == json["limits"]["memory"]
+ assert kwargs["limits"].gpus == json["limits"]["gpus"]
+
+ # - fill missing dependencies
+
+ def test_fill_missing_dependencies_none(self, mocker):
+ # Arrange
+ class MockResources:
+ cores = None
+ memory = None
+ gpus = None
+
+ mock_resource = MockResources()
+
+ # Act
+ resources.ComponentResources._fill_missing_resources(mock_resource, 10, 11, 12)
+
+ # Assert
+ assert mock_resource.cores == 10
+ assert mock_resource.memory == 11
+ assert mock_resource.gpus == 12
+
+ def test_fill_missing_dependencies_all(self, mocker):
+ # Arrange
+ class MockResources:
+ cores = 1
+ memory = 2
+ gpus = 3
+
+ mock_resource = MockResources()
+
+ # Act
+ resources.ComponentResources._fill_missing_resources(mock_resource, 10, 11, 12)
+
+ # Assert
+ assert mock_resource.cores == 1
+ assert mock_resource.memory == 2
+ assert mock_resource.gpus == 3
+
+ def test_fill_missing_dependencies_some(self, mocker):
+ # Arrange
+ class MockResources:
+ cores = 1
+ memory = None
+ gpus = None
+
+ mock_resource = MockResources()
+
+ # Act
+ resources.ComponentResources._fill_missing_resources(mock_resource, 10, 11, 12)
+
+ # Assert
+ assert mock_resource.cores == 1
+ assert mock_resource.memory == 11
+ assert mock_resource.gpus == 12
+
+ # - get default resource limits
+
+ def test_get_default_resource_limits_no_hard_limit_and_lower_than_default(
+ self, mocker
+ ):
+ # Arrange
+ no_limit_res = {"cores": -1, "memory": -1, "gpus": -1}
+ mock_comp_res = mocker.MagicMock()
+ mock_comp_res._requests = resources.Resources(cores=0.2, memory=516, gpus=0)
+ mock_comp_res._default_resource_limits = (
+ resources.ComponentResources._get_default_resource_limits
+ )
+ mock_get_serving_res_limits = mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=no_limit_res, # no upper limit
+ )
+
+ # Act
+ cores, memory, gpus = mock_comp_res._default_resource_limits(mock_comp_res)
+
+ # Assert
+ assert cores == RESOURCES.MAX_CORES
+ assert memory == RESOURCES.MAX_MEMORY
+ assert gpus == RESOURCES.MAX_GPUS
+ mock_get_serving_res_limits.assert_called_once()
+
+ def test_get_default_resource_limits_no_hard_limit_and_higher_than_default(
+ self, mocker
+ ):
+ # Arrange
+ no_limit_res = {"cores": -1, "memory": -1, "gpus": -1}
+ mock_comp_res = mocker.MagicMock()
+ mock_comp_res._requests = resources.Resources(cores=4, memory=2048, gpus=2)
+ mock_comp_res._default_resource_limits = (
+ resources.ComponentResources._get_default_resource_limits
+ )
+ mock_get_serving_res_limits = mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=no_limit_res, # no upper limit
+ )
+
+ # Act
+ cores, memory, gpus = mock_comp_res._default_resource_limits(mock_comp_res)
+
+ # Assert
+ assert cores == mock_comp_res._requests.cores
+ assert memory == mock_comp_res._requests.memory
+ assert gpus == mock_comp_res._requests.gpus
+ mock_get_serving_res_limits.assert_called_once()
+
+ def test_get_default_resource_limits_with_higher_hard_limit_and_lower_than_default(
+ self, mocker
+ ):
+ # Arrange
+ hard_limit_res = {"cores": 3, "memory": 3072, "gpus": 3}
+ mock_comp_res = mocker.MagicMock()
+ mock_comp_res._requests = resources.Resources(cores=1, memory=516, gpus=0)
+ mock_comp_res._default_resource_limits = (
+ resources.ComponentResources._get_default_resource_limits
+ )
+ mock_get_serving_res_limits = mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=hard_limit_res, # upper limit
+ )
+
+ # Act
+ cores, memory, gpus = mock_comp_res._default_resource_limits(mock_comp_res)
+
+ # Assert
+ assert cores == RESOURCES.MAX_CORES
+ assert memory == RESOURCES.MAX_MEMORY
+ assert gpus == RESOURCES.MAX_GPUS
+ mock_get_serving_res_limits.assert_called_once()
+
+ def test_get_default_resource_limits_with_higher_hard_limit_and_higher_than_default(
+ self, mocker
+ ):
+ # Arrange
+ hard_limit_res = {"cores": 3, "memory": 3072, "gpus": 3}
+ mock_comp_res = mocker.MagicMock()
+ mock_comp_res._requests = resources.Resources(cores=3, memory=2048, gpus=1)
+ mock_comp_res._default_resource_limits = (
+ resources.ComponentResources._get_default_resource_limits
+ )
+ mock_get_serving_res_limits = mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=hard_limit_res, # upper limit
+ )
+
+ # Act
+ cores, memory, gpus = mock_comp_res._default_resource_limits(mock_comp_res)
+
+ # Assert
+ assert cores == hard_limit_res["cores"]
+ assert memory == hard_limit_res["memory"]
+ assert gpus == hard_limit_res["gpus"]
+ mock_get_serving_res_limits.assert_called_once()
+
+ def test_get_default_resource_limits_with_lower_hard_limit_and_lower_than_default(
+ self, mocker
+ ):
+ # Arrange
+ RESOURCES.MAX_GPUS = 1 # override default
+ hard_limit_res = {"cores": 1, "memory": 516, "gpus": 0}
+ mock_comp_res = mocker.MagicMock()
+ mock_comp_res._requests = resources.Resources(cores=0.5, memory=256, gpus=0.5)
+ mock_comp_res._default_resource_limits = (
+ resources.ComponentResources._get_default_resource_limits
+ )
+ mock_get_serving_res_limits = mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=hard_limit_res, # upper limit
+ )
+
+ # Act
+ cores, memory, gpus = mock_comp_res._default_resource_limits(mock_comp_res)
+
+ # Assert
+ assert cores == hard_limit_res["cores"]
+ assert memory == hard_limit_res["memory"]
+ assert gpus == hard_limit_res["gpus"]
+ mock_get_serving_res_limits.assert_called_once()
+
+ def test_get_default_resource_limits_with_lower_hard_limit_and_higher_than_default(
+ self, mocker
+ ):
+ # Arrange
+ RESOURCES.MAX_GPUS = 1 # override default
+ hard_limit_res = {"cores": 1, "memory": 516, "gpus": 0}
+ mock_comp_res = mocker.MagicMock()
+ mock_comp_res._requests = resources.Resources(cores=4, memory=4080, gpus=4)
+ mock_comp_res._default_resource_limits = (
+ resources.ComponentResources._get_default_resource_limits
+ )
+ mock_get_serving_res_limits = mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=hard_limit_res, # upper limit
+ )
+
+ # Act
+ cores, memory, gpus = mock_comp_res._default_resource_limits(mock_comp_res)
+
+ # Assert
+ assert cores == hard_limit_res["cores"]
+ assert memory == hard_limit_res["memory"]
+ assert gpus == hard_limit_res["gpus"]
+ mock_get_serving_res_limits.assert_called_once()
+
+ # - validate resources
+
+ def test_validate_resources_no_hard_limits_valid_resources(self, mocker):
+ # Arrange
+ no_limit_res = {"cores": -1, "memory": -1, "gpus": -1}
+ requests = resources.Resources(cores=1, memory=1024, gpus=0)
+ limits = resources.Resources(cores=2, memory=2048, gpus=1)
+ mock_get_serving_res_limits = mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=no_limit_res, # no upper limit
+ )
+
+ # Act
+ resources.ComponentResources._validate_resources(requests, limits)
+
+ # Assert
+ mock_get_serving_res_limits.assert_called_once()
+
+ def test_validate_resources_no_hard_limit_invalid_cores_request(self, mocker):
+ # Arrange
+ no_limit_res = {"cores": -1, "memory": -1, "gpus": -1}
+ requests = resources.Resources(cores=0, memory=1024, gpus=0)
+ limits = resources.Resources(cores=2, memory=2048, gpus=1)
+ mock_get_serving_res_limits = mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=no_limit_res, # no upper limit
+ )
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ resources.ComponentResources._validate_resources(requests, limits)
+
+ # Assert
+ mock_get_serving_res_limits.assert_called_once()
+ assert "Requested number of cores must be greater than 0 cores." in str(
+ e_info.value
+ )
+
+ def test_validate_resources_no_hard_limit_invalid_memory_request(self, mocker):
+ # Arrange
+ no_limit_res = {"cores": -1, "memory": -1, "gpus": -1}
+ requests = resources.Resources(cores=1, memory=0, gpus=0)
+ limits = resources.Resources(cores=2, memory=2048, gpus=1)
+ mock_get_serving_res_limits = mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=no_limit_res, # no upper limit
+ )
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ resources.ComponentResources._validate_resources(requests, limits)
+
+ # Assert
+ mock_get_serving_res_limits.assert_called_once()
+ assert "Requested memory resources must be greater than 0 MB." in str(
+ e_info.value
+ )
+
+ def test_validate_resources_no_hard_limit_invalid_gpus_request(self, mocker):
+ # Arrange
+ no_limit_res = {"cores": -1, "memory": -1, "gpus": -1}
+ requests = resources.Resources(
+ cores=1, memory=1024, gpus=-1
+ ) # 0 gpus is accepted
+ limits = resources.Resources(cores=2, memory=2048, gpus=1)
+ mock_get_serving_res_limits = mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=no_limit_res, # no upper limit
+ )
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ resources.ComponentResources._validate_resources(requests, limits)
+
+ # Assert
+ mock_get_serving_res_limits.assert_called_once()
+ assert (
+ "Requested number of gpus must be greater than or equal to 0 gpus."
+ in str(e_info.value)
+ )
+
+ def test_validate_resources_no_hard_limit_cores_request_out_of_range(self, mocker):
+ # Arrange
+ no_limit_res = {"cores": -1, "memory": -1, "gpus": -1}
+ requests = resources.Resources(cores=2, memory=1024, gpus=0)
+ limits = resources.Resources(cores=1, memory=2048, gpus=1)
+ mock_get_serving_res_limits = mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=no_limit_res, # no upper limit
+ )
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ resources.ComponentResources._validate_resources(requests, limits)
+
+ # Assert
+ mock_get_serving_res_limits.assert_called_once()
+ assert (
+ f"Requested number of cores cannot exceed the limit of {str(limits.cores)} cores."
+ in str(e_info.value)
+ )
+
+ def test_validate_resources_no_hard_limit_invalid_memory_request_out_of_range(
+ self, mocker
+ ):
+ # Arrange
+ no_limit_res = {"cores": -1, "memory": -1, "gpus": -1}
+ requests = resources.Resources(cores=1, memory=2048, gpus=0)
+ limits = resources.Resources(cores=2, memory=1024, gpus=1)
+ mock_get_serving_res_limits = mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=no_limit_res, # no upper limit
+ )
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ resources.ComponentResources._validate_resources(requests, limits)
+
+ # Assert
+ mock_get_serving_res_limits.assert_called_once()
+ assert (
+ f"Requested memory resources cannot exceed the limit of {str(limits.memory)} MB."
+ in str(e_info.value)
+ )
+
+ def test_validate_resources_no_hard_limit_invalid_gpus_request_out_of_range(
+ self, mocker
+ ):
+ # Arrange
+ no_limit_res = {"cores": -1, "memory": -1, "gpus": -1}
+ requests = resources.Resources(cores=1, memory=1024, gpus=2)
+ limits = resources.Resources(cores=2, memory=2048, gpus=1)
+ mock_get_serving_res_limits = mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=no_limit_res, # no upper limit
+ )
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ resources.ComponentResources._validate_resources(requests, limits)
+
+ # Assert
+ mock_get_serving_res_limits.assert_called_once()
+ assert (
+ f"Requested number of gpus cannot exceed the limit of {str(limits.gpus)} gpus."
+ in str(e_info.value)
+ )
+
+ def test_validate_resources_with_hard_limit_valid_resources(self, mocker):
+ # Arrange
+ hard_limit_res = {"cores": 3, "memory": 3072, "gpus": 3}
+ requests = resources.Resources(cores=1, memory=1024, gpus=0)
+ limits = resources.Resources(cores=2, memory=2048, gpus=1)
+ mock_get_serving_res_limits = mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=hard_limit_res, # upper limit
+ )
+
+ # Act
+ resources.ComponentResources._validate_resources(requests, limits)
+
+ # Assert
+ mock_get_serving_res_limits.assert_called_once()
+
+ def test_validate_resources_with_hard_limit_invalid_cores_limit(self, mocker):
+ # Arrange
+ hard_limit_res = {"cores": 3, "memory": 3072, "gpus": 3}
+ requests = resources.Resources(cores=2, memory=1024, gpus=0)
+ limits = resources.Resources(cores=0, memory=2048, gpus=1)
+ mock_get_serving_res_limits = mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=hard_limit_res, # upper limit
+ )
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ resources.ComponentResources._validate_resources(requests, limits)
+
+ # Assert
+ mock_get_serving_res_limits.assert_called_once()
+ assert "Limit number of cores must be greater than 0 cores." in str(
+ e_info.value
+ )
+
+ def test_validate_resources_with_hard_limit_invalid_memory_limit(self, mocker):
+ # Arrange
+ hard_limit_res = {"cores": 3, "memory": 3072, "gpus": 3}
+ requests = resources.Resources(cores=2, memory=1024, gpus=0)
+ limits = resources.Resources(cores=1, memory=0, gpus=1)
+ mock_get_serving_res_limits = mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=hard_limit_res, # upper limit
+ )
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ resources.ComponentResources._validate_resources(requests, limits)
+
+ # Assert
+ mock_get_serving_res_limits.assert_called_once()
+ assert "Limit memory resources must be greater than 0 MB." in str(e_info.value)
+
+ def test_validate_resources_with_hard_limit_invalid_gpus_limit(self, mocker):
+ # Arrange
+ hard_limit_res = {"cores": 3, "memory": 3072, "gpus": 3}
+ requests = resources.Resources(cores=2, memory=1024, gpus=0)
+ limits = resources.Resources(cores=1, memory=2048, gpus=-1)
+ mock_get_serving_res_limits = mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=hard_limit_res, # upper limit
+ )
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ resources.ComponentResources._validate_resources(requests, limits)
+
+ # Assert
+ mock_get_serving_res_limits.assert_called_once()
+ assert "Limit number of gpus must be greater than or equal to 0 gpus." in str(
+ e_info.value
+ )
+
+ def test_validate_resources_with_hard_limit_invalid_cores_request(self, mocker):
+ # Arrange
+ hard_limit_res = {"cores": 3, "memory": 3072, "gpus": 3}
+ requests = resources.Resources(cores=2, memory=1024, gpus=0)
+ limits = resources.Resources(cores=4, memory=2048, gpus=1)
+ mock_get_serving_res_limits = mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=hard_limit_res, # upper limit
+ )
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ resources.ComponentResources._validate_resources(requests, limits)
+
+ # Assert
+ mock_get_serving_res_limits.assert_called_once()
+ assert (
+ f"Limit number of cores cannot exceed the maximum of {hard_limit_res['cores']} cores."
+ in str(e_info.value)
+ )
+
+ def test_validate_resources_with_hard_limit_invalid_memory_request(self, mocker):
+ # Arrange
+ hard_limit_res = {"cores": 3, "memory": 3072, "gpus": 3}
+ requests = resources.Resources(cores=2, memory=1024, gpus=0)
+ limits = resources.Resources(cores=3, memory=4076, gpus=1)
+ mock_get_serving_res_limits = mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=hard_limit_res, # upper limit
+ )
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ resources.ComponentResources._validate_resources(requests, limits)
+
+ # Assert
+ mock_get_serving_res_limits.assert_called_once()
+ assert (
+ f"Limit memory resources cannot exceed the maximum of {hard_limit_res['memory']} MB."
+ in str(e_info.value)
+ )
+
+ def test_validate_resources_with_hard_limit_invalid_gpus_request(self, mocker):
+ # Arrange
+ hard_limit_res = {"cores": 3, "memory": 3072, "gpus": 3}
+ requests = resources.Resources(cores=2, memory=1024, gpus=0)
+ limits = resources.Resources(cores=3, memory=2048, gpus=4)
+ mock_get_serving_res_limits = mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=hard_limit_res, # upper limit
+ )
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ resources.ComponentResources._validate_resources(requests, limits)
+
+ # Assert
+ mock_get_serving_res_limits.assert_called_once()
+ assert (
+ f"Limit number of gpus cannot exceed the maximum of {hard_limit_res['gpus']} gpus."
+ in str(e_info.value)
+ )
+
+ # PredictorResources
+
+ def test_from_response_json_predictor_resources(self, mocker, backend_fixtures):
+ mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=SERVING_RESOURCE_LIMITS,
+ )
+ json = backend_fixtures["resources"][
+ "get_component_resources_num_instances_requests_and_limits"
+ ]["response"]
+
+ # Act
+ r = resources.PredictorResources.from_response_json(json)
+
+ # Assert
+ assert r.num_instances == json["num_instances"]
+ assert r.requests.cores == json["requests"]["cores"]
+ assert r.requests.memory == json["requests"]["memory"]
+ assert r.requests.gpus == json["requests"]["gpus"]
+ assert r.limits.cores == json["limits"]["cores"]
+ assert r.limits.memory == json["limits"]["memory"]
+ assert r.limits.gpus == json["limits"]["gpus"]
+
+ def test_from_response_json_predictor_resources_specific_keys(
+ self, mocker, backend_fixtures
+ ):
+ mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=SERVING_RESOURCE_LIMITS,
+ )
+ json = backend_fixtures["resources"][
+ "get_component_resources_requested_instances_and_predictor_resources"
+ ]["response"]
+
+ # Act
+ r = resources.PredictorResources.from_response_json(json)
+
+ # Assert
+ assert r.num_instances == json["requested_instances"]
+ assert r.requests.cores == json["predictor_resources"]["requests"]["cores"]
+ assert r.requests.memory == json["predictor_resources"]["requests"]["memory"]
+ assert r.requests.gpus == json["predictor_resources"]["requests"]["gpus"]
+ assert r.limits.cores == json["predictor_resources"]["limits"]["cores"]
+ assert r.limits.memory == json["predictor_resources"]["limits"]["memory"]
+ assert r.limits.gpus == json["predictor_resources"]["limits"]["gpus"]
+
+ # TransformerResources
+
+ def test_from_response_json_transformer_resources(self, mocker, backend_fixtures):
+ mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=SERVING_RESOURCE_LIMITS,
+ )
+ json = backend_fixtures["resources"][
+ "get_component_resources_num_instances_requests_and_limits"
+ ]["response"]
+
+ # Act
+ r = resources.TransformerResources.from_response_json(json)
+
+ # Assert
+ assert r.num_instances == json["num_instances"]
+ assert r.requests.cores == json["requests"]["cores"]
+ assert r.requests.memory == json["requests"]["memory"]
+ assert r.requests.gpus == json["requests"]["gpus"]
+ assert r.limits.cores == json["limits"]["cores"]
+ assert r.limits.memory == json["limits"]["memory"]
+ assert r.limits.gpus == json["limits"]["gpus"]
+
+ def test_from_response_json_transformer_resources_specific_keys(
+ self, mocker, backend_fixtures
+ ):
+ mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=SERVING_RESOURCE_LIMITS,
+ )
+ json = backend_fixtures["resources"][
+ "get_component_resources_requested_instances_and_transformer_resources"
+ ]["response"]
+
+ # Act
+ r = resources.TransformerResources.from_response_json(json)
+
+ # Assert
+ assert r.num_instances == json["requested_transformer_instances"]
+ assert r.requests.cores == json["transformer_resources"]["requests"]["cores"]
+ assert r.requests.memory == json["transformer_resources"]["requests"]["memory"]
+ assert r.requests.gpus == json["transformer_resources"]["requests"]["gpus"]
+ assert r.limits.cores == json["transformer_resources"]["limits"]["cores"]
+ assert r.limits.memory == json["transformer_resources"]["limits"]["memory"]
+ assert r.limits.gpus == json["transformer_resources"]["limits"]["gpus"]
+
+ def test_from_response_json_transformer_resources_default_limits(
+ self, mocker, backend_fixtures
+ ):
+ mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=SERVING_RESOURCE_LIMITS,
+ )
+ mocker.patch(
+ "hsml.resources.ComponentResources._get_default_resource_limits",
+ return_value=(
+ SERVING_RESOURCE_LIMITS["cores"],
+ SERVING_RESOURCE_LIMITS["memory"],
+ SERVING_RESOURCE_LIMITS["gpus"],
+ ),
+ )
+ json = backend_fixtures["resources"][
+ "get_component_resources_num_instances_and_requests"
+ ]["response"]
+
+ # Act
+ r = resources.TransformerResources.from_response_json(json)
+
+ # Assert
+ assert r.num_instances == json["num_instances"]
+ assert r.requests.cores == json["requests"]["cores"]
+ assert r.requests.memory == json["requests"]["memory"]
+ assert r.requests.gpus == json["requests"]["gpus"]
+ assert r.limits.cores == SERVING_RESOURCE_LIMITS["cores"]
+ assert r.limits.memory == SERVING_RESOURCE_LIMITS["memory"]
+ assert r.limits.gpus == SERVING_RESOURCE_LIMITS["gpus"]
diff --git a/hsml/python/tests/test_schema.py b/hsml/python/tests/test_schema.py
new file mode 100644
index 000000000..69ddd0782
--- /dev/null
+++ b/hsml/python/tests/test_schema.py
@@ -0,0 +1,199 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import numpy as np
+from hsml import schema
+
+
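+# Unit tests for hsml.schema.Schema: constructor dispatch between tensor and
+# columnar schemas, the conversion helpers, and schema type resolution.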
+class TestSchema:
+ # constructor
+
+ def test_constructor_default(self, mocker):
+ # Arrange
+ mock_tensor = mocker.MagicMock()
+ mock_tensor.tensors = mocker.MagicMock(return_value="tensor_schema")
+ mock_columnar = mocker.MagicMock()
+ mock_columnar.columns = mocker.MagicMock(return_value="columnar_schema")
+ mock_convert_tensor_to_schema = mocker.patch(
+ "hsml.schema.Schema._convert_tensor_to_schema", return_value=mock_tensor
+ )
+ mock_convert_columnar_to_schema = mocker.patch(
+ "hsml.schema.Schema._convert_columnar_to_schema",
+ return_value=mock_columnar,
+ )
+
+ # Act
+ s = schema.Schema()
+
+ # Assert
+ assert s.columnar_schema == mock_columnar.columns
+ assert not hasattr(s, "tensor_schema")
+ mock_convert_tensor_to_schema.assert_not_called()
+ mock_convert_columnar_to_schema.assert_called_once_with(None)
+
+ def test_constructor_numpy(self, mocker):
+ # Arrange
+ obj = np.array([])
+ mock_tensor = mocker.MagicMock()
+ mock_tensor.tensors = mocker.MagicMock(return_value="tensor_schema")
+ mock_columnar = mocker.MagicMock()
+ mock_columnar.columns = mocker.MagicMock(return_value="columnar_schema")
+ mock_convert_tensor_to_schema = mocker.patch(
+ "hsml.schema.Schema._convert_tensor_to_schema", return_value=mock_tensor
+ )
+ mock_convert_columnar_to_schema = mocker.patch(
+ "hsml.schema.Schema._convert_columnar_to_schema",
+ return_value=mock_columnar,
+ )
+
+ # Act
+ s = schema.Schema(obj)
+
+ # Assert
+ assert s.tensor_schema == mock_tensor.tensors
+ assert not hasattr(s, "columnar_schema")
+ mock_convert_columnar_to_schema.assert_not_called()
+ mock_convert_tensor_to_schema.assert_called_once_with(obj)
+
+ def test_constructor_tensor_list(self, mocker):
+ # Arrange
+ obj = [{"shape": "some_shape"}]
+ mock_tensor = mocker.MagicMock()
+ mock_tensor.tensors = mocker.MagicMock(return_value="tensor_schema")
+ mock_columnar = mocker.MagicMock()
+ mock_columnar.columns = mocker.MagicMock(return_value="columnar_schema")
+ mock_convert_tensor_to_schema = mocker.patch(
+ "hsml.schema.Schema._convert_tensor_to_schema", return_value=mock_tensor
+ )
+ mock_convert_columnar_to_schema = mocker.patch(
+ "hsml.schema.Schema._convert_columnar_to_schema",
+ return_value=mock_columnar,
+ )
+
+ # Act
+ s = schema.Schema(obj)
+
+ # Assert
+ assert s.tensor_schema == mock_tensor.tensors
+ assert not hasattr(s, "columnar_schema")
+ mock_convert_columnar_to_schema.assert_not_called()
+ mock_convert_tensor_to_schema.assert_called_once_with(obj)
+
+ def test_constructor_column_list(self, mocker):
+ # Arrange
+ obj = [{"no_shape": "nothing"}]
+ mock_tensor = mocker.MagicMock()
+ mock_tensor.tensors = mocker.MagicMock(return_value="tensor_schema")
+ mock_columnar = mocker.MagicMock()
+ mock_columnar.columns = mocker.MagicMock(return_value="columnar_schema")
+ mock_convert_tensor_to_schema = mocker.patch(
+ "hsml.schema.Schema._convert_tensor_to_schema", return_value=mock_tensor
+ )
+ mock_convert_columnar_to_schema = mocker.patch(
+ "hsml.schema.Schema._convert_columnar_to_schema",
+ return_value=mock_columnar,
+ )
+
+ # Act
+ s = schema.Schema(obj)
+
+ # Assert
+ assert s.columnar_schema == mock_columnar.columns
+ assert not hasattr(s, "tensor_schema")
+ mock_convert_tensor_to_schema.assert_not_called()
+ mock_convert_columnar_to_schema.assert_called_once_with(obj)
+
+ # convert to schema
+
+ def test_convert_columnar_to_schema(self, mocker):
+ # Arrange
+ obj = {"key": "value"}
+ mock_columnar_schema_init = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema.__init__",
+ return_value=None,
+ )
+ mock_schema = mocker.MagicMock()
+ mock_schema._convert_columnar_to_schema = (
+ schema.Schema._convert_columnar_to_schema
+ )
+
+ # Act
+ ret = mock_schema._convert_columnar_to_schema(mock_schema, obj)
+
+ # Assert
+ assert isinstance(ret, schema.ColumnarSchema)
+ mock_columnar_schema_init.assert_called_once_with(obj)
+
+ def test_convert_tensor_to_schema(self, mocker):
+ # Arrange
+ obj = {"key": "value"}
+ mock_tensor_schema_init = mocker.patch(
+ "hsml.utils.schema.tensor_schema.TensorSchema.__init__",
+ return_value=None,
+ )
+ mock_schema = mocker.MagicMock()
+ mock_schema._convert_tensor_to_schema = schema.Schema._convert_tensor_to_schema
+
+ # Act
+ ret = mock_schema._convert_tensor_to_schema(mock_schema, obj)
+
+ # Assert
+ assert isinstance(ret, schema.TensorSchema)
+ mock_tensor_schema_init.assert_called_once_with(obj)
+
+ # get type
+
+ def test_get_type_none(self, mocker):
+ # Arrange
+ class MockSchema:
+ pass
+
+ mock_schema = MockSchema()
+ mock_schema._get_type = schema.Schema._get_type
+
+ # Act
+ t = mock_schema._get_type(mock_schema)
+
+ # Assert
+ assert t is None
+
+ def test_get_type_tensor(self, mocker):
+ # Arrange
+ class MockSchema:
+ tensor_schema = {}
+
+ mock_schema = MockSchema()
+ mock_schema._get_type = schema.Schema._get_type
+
+ # Act
+ t = mock_schema._get_type(mock_schema)
+
+ # Assert
+ assert t == "tensor"
+
+ def test_get_type_columnar(self, mocker):
+ # Arrange
+ class MockSchema:
+ columnar_schema = {}
+
+ mock_schema = MockSchema()
+ mock_schema._get_type = schema.Schema._get_type
+
+ # Act
+ t = mock_schema._get_type(mock_schema)
+
+ # Assert
+ assert t == "columnar"
diff --git a/hsml/python/tests/test_tag.py b/hsml/python/tests/test_tag.py
new file mode 100644
index 000000000..7a12955ac
--- /dev/null
+++ b/hsml/python/tests/test_tag.py
@@ -0,0 +1,62 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import humps
+from hsml import tag
+
+
+class TestTag:
+ # from response json
+
+ def test_from_response_json(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["tag"]["get"]["response"]
+ json_camelized = humps.camelize(json)
+
+ # Act
+ t_list = tag.Tag.from_response_json(json_camelized)
+
+ # Assert
+ assert len(t_list) == 1
+ t = t_list[0]
+ assert t.name == "test_name"
+ assert t.value == "test_value"
+
+ def test_from_response_json_empty(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["tag"]["get_empty"]["response"]
+ json_camelized = humps.camelize(json)
+
+ # Act
+ t_list = tag.Tag.from_response_json(json_camelized)
+
+ # Assert
+ assert len(t_list) == 0
+
+ # constructor
+
+ def test_constructor(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["tag"]["get"]["response"]["items"][0]
+ tag_name = json.pop("name")
+ tag_value = json.pop("value")
+
+ # Act
+ t = tag.Tag(name=tag_name, value=tag_value, **json)
+
+ # Assert
+ assert t.name == "test_name"
+ assert t.value == "test_value"
diff --git a/hsml/python/tests/test_transformer.py b/hsml/python/tests/test_transformer.py
new file mode 100644
index 000000000..7df302bd6
--- /dev/null
+++ b/hsml/python/tests/test_transformer.py
@@ -0,0 +1,309 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import copy
+
+import pytest
+from hsml import resources, transformer
+from hsml.constants import RESOURCES
+
+
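+# Mocked backend serving limits shared by the tests below; the values are
+# arbitrary test fixtures, not real serving defaults.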
+SERVING_RESOURCE_LIMITS = {"cores": 2, "memory": 1024, "gpus": 2}
+SERVING_NUM_INSTANCES_NO_LIMIT = [-1]
+SERVING_NUM_INSTANCES_SCALE_TO_ZERO = [0]
+SERVING_NUM_INSTANCES_ONE = [1]
+
+
+class TestTransformer:
+ # from response json
+
+ def test_from_response_json_with_transformer_field(self, mocker, backend_fixtures):
+ # Arrange
+ self._mock_serving_variables(mocker, SERVING_NUM_INSTANCES_NO_LIMIT)
+ json = backend_fixtures["transformer"]["get_deployment_with_transformer"][
+ "response"
+ ]
+
+ # Act
+ t = transformer.Transformer.from_response_json(json)
+
+ # Assert
+ assert isinstance(t, transformer.Transformer)
+ assert t.script_file == json["transformer"]
+
+ tr_resources = json["transformer_resources"]
+ assert (
+ t.resources.num_instances == tr_resources["requested_transformer_instances"]
+ )
+ assert t.resources.requests.cores == tr_resources["requests"]["cores"]
+ assert t.resources.requests.memory == tr_resources["requests"]["memory"]
+ assert t.resources.requests.gpus == tr_resources["requests"]["gpus"]
+ assert t.resources.limits.cores == tr_resources["limits"]["cores"]
+ assert t.resources.limits.memory == tr_resources["limits"]["memory"]
+ assert t.resources.limits.gpus == tr_resources["limits"]["gpus"]
+
+ def test_from_response_json_with_script_file_field(self, mocker, backend_fixtures):
+ # Arrange
+ self._mock_serving_variables(mocker, SERVING_NUM_INSTANCES_NO_LIMIT)
+ json = backend_fixtures["transformer"]["get_transformer_with_resources"][
+ "response"
+ ]
+
+ # Act
+ t = transformer.Transformer.from_response_json(json)
+
+ # Assert
+ assert isinstance(t, transformer.Transformer)
+ assert t.script_file == json["script_file"]
+
+ tr_resources = json["resources"]
+ assert t.resources.num_instances == tr_resources["num_instances"]
+ assert t.resources.requests.cores == tr_resources["requests"]["cores"]
+ assert t.resources.requests.memory == tr_resources["requests"]["memory"]
+ assert t.resources.requests.gpus == tr_resources["requests"]["gpus"]
+ assert t.resources.limits.cores == tr_resources["limits"]["cores"]
+ assert t.resources.limits.memory == tr_resources["limits"]["memory"]
+ assert t.resources.limits.gpus == tr_resources["limits"]["gpus"]
+
+ def test_from_response_json_empty(self, mocker, backend_fixtures):
+ # Arrange
+ self._mock_serving_variables(mocker, SERVING_NUM_INSTANCES_NO_LIMIT)
+ json = backend_fixtures["transformer"]["get_deployment_without_transformer"][
+ "response"
+ ]
+
+ # Act
+ t = transformer.Transformer.from_response_json(json)
+
+ # Assert
+ assert t is None
+
+ def test_from_response_json_default_resources(self, mocker, backend_fixtures):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, force_scale_to_zero=False
+ )
+ json = backend_fixtures["transformer"]["get_transformer_without_resources"][
+ "response"
+ ]
+
+ # Act
+ t = transformer.Transformer.from_response_json(json)
+
+ # Assert
+ assert isinstance(t, transformer.Transformer)
+ assert t.script_file == json["script_file"]
+
+ assert t.resources.num_instances == RESOURCES.MIN_NUM_INSTANCES
+ assert t.resources.requests.cores == RESOURCES.MIN_CORES
+ assert t.resources.requests.memory == RESOURCES.MIN_MEMORY
+ assert t.resources.requests.gpus == RESOURCES.MIN_GPUS
+ assert t.resources.limits.cores == RESOURCES.MAX_CORES
+ assert t.resources.limits.memory == RESOURCES.MAX_MEMORY
+ assert t.resources.limits.gpus == RESOURCES.MAX_GPUS
+
+ # constructor
+
+ def test_constructor_default_resources(self, mocker, backend_fixtures):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, force_scale_to_zero=False
+ )
+ json = backend_fixtures["transformer"]["get_transformer_without_resources"][
+ "response"
+ ]
+
+ # Act
+ t = transformer.Transformer(json["script_file"], resources=None)
+
+ # Assert
+ assert t.script_file == json["script_file"]
+
+ assert t.resources.num_instances == RESOURCES.MIN_NUM_INSTANCES
+ assert t.resources.requests.cores == RESOURCES.MIN_CORES
+ assert t.resources.requests.memory == RESOURCES.MIN_MEMORY
+ assert t.resources.requests.gpus == RESOURCES.MIN_GPUS
+ assert t.resources.limits.cores == RESOURCES.MAX_CORES
+ assert t.resources.limits.memory == RESOURCES.MAX_MEMORY
+ assert t.resources.limits.gpus == RESOURCES.MAX_GPUS
+
+ def test_constructor_default_resources_when_scale_to_zero_is_required(
+ self, mocker, backend_fixtures
+ ):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, force_scale_to_zero=True
+ )
+ json = backend_fixtures["transformer"]["get_transformer_without_resources"][
+ "response"
+ ]
+
+ # Act
+ t = transformer.Transformer(json["script_file"], resources=None)
+
+ # Assert
+ assert t.script_file == json["script_file"]
+
+ assert t.resources.num_instances == 0
+ assert t.resources.requests.cores == RESOURCES.MIN_CORES
+ assert t.resources.requests.memory == RESOURCES.MIN_MEMORY
+ assert t.resources.requests.gpus == RESOURCES.MIN_GPUS
+ assert t.resources.limits.cores == RESOURCES.MAX_CORES
+ assert t.resources.limits.memory == RESOURCES.MAX_MEMORY
+ assert t.resources.limits.gpus == RESOURCES.MAX_GPUS
+
+ # validate resources
+
+ def test_validate_resources_none(self):
+ # Act
+ res = transformer.Transformer._validate_resources(None)
+
+ # Assert
+ assert res is None
+
+ def test_validate_resources_num_instances_zero(self, mocker):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, force_scale_to_zero=False
+ )
+ tr = resources.TransformerResources(num_instances=0)
+
+ # Act
+ res = transformer.Transformer._validate_resources(tr)
+
+ # Assert
+ assert res == tr
+
+ def test_validate_resources_num_instances_one_without_scale_to_zero(self, mocker):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, force_scale_to_zero=False
+ )
+ tr = resources.TransformerResources(num_instances=1)
+
+ # Act
+ res = transformer.Transformer._validate_resources(tr)
+
+ # Assert
+ assert res == tr
+
+ def test_validate_resources_num_instances_one_with_scale_to_zero(self, mocker):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, force_scale_to_zero=True
+ )
+ tr = resources.TransformerResources(num_instances=1)
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ _ = transformer.Transformer._validate_resources(tr)
+
+ # Assert
+ assert "Scale-to-zero is required" in str(e_info.value)
+
+ # default num instances
+
+ def test_get_default_num_instances_without_scale_to_zero(self, mocker):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, force_scale_to_zero=False
+ )
+
+ # Act
+ num_instances = transformer.Transformer._get_default_num_instances()
+
+ # Assert
+ assert num_instances == RESOURCES.MIN_NUM_INSTANCES
+
+ def test_get_default_num_instances_with_scale_to_zero(self, mocker):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, force_scale_to_zero=True
+ )
+
+ # Act
+ num_instances = transformer.Transformer._get_default_num_instances()
+
+ # Assert
+ assert num_instances == 0
+
+ # default resources
+
+ def test_get_default_resources_without_scale_to_zero(self, mocker):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, force_scale_to_zero=False
+ )
+
+ # Act
+ res = transformer.Transformer._get_default_resources()
+
+ # Assert
+ assert isinstance(res, resources.TransformerResources)
+ assert res.num_instances == RESOURCES.MIN_NUM_INSTANCES
+
+ def test_get_default_resources_with_scale_to_zero(self, mocker):
+ # Arrange
+ self._mock_serving_variables(
+ mocker, SERVING_NUM_INSTANCES_NO_LIMIT, force_scale_to_zero=True
+ )
+
+ # Act
+ res = transformer.Transformer._get_default_resources()
+
+ # Assert
+ assert isinstance(res, resources.TransformerResources)
+ assert res.num_instances == 0
+
+ # extract fields from json
+
+ def test_extract_fields_from_json(self, mocker, backend_fixtures):
+ # Arrange
+ self._mock_serving_variables(mocker, SERVING_NUM_INSTANCES_NO_LIMIT)
+ json = backend_fixtures["transformer"]["get_deployment_with_transformer"][
+ "response"
+ ]
+ json_copy = copy.deepcopy(json)
+
+ # Act
+ sf, rc = transformer.Transformer.extract_fields_from_json(json_copy)
+
+ # Assert
+ assert sf == json["transformer"]
+ assert isinstance(rc, resources.TransformerResources)
+
+ tr_resources = json["transformer_resources"]
+ assert rc.num_instances == tr_resources["requested_transformer_instances"]
+ assert rc.requests.cores == tr_resources["requests"]["cores"]
+ assert rc.requests.memory == tr_resources["requests"]["memory"]
+ assert rc.requests.gpus == tr_resources["requests"]["gpus"]
+ assert rc.limits.cores == tr_resources["limits"]["cores"]
+ assert rc.limits.memory == tr_resources["limits"]["memory"]
+ assert rc.limits.gpus == tr_resources["limits"]["gpus"]
+
+ # auxiliary methods
+
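+ # Patch the hsml.client lookups consulted when resolving transformer
+ # defaults: hard resource limits, allowed instance counts, and whether
+ # scale-to-zero is enforced.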
+ def _mock_serving_variables(self, mocker, num_instances, force_scale_to_zero=False):
+ mocker.patch(
+ "hsml.client.get_serving_resource_limits",
+ return_value=SERVING_RESOURCE_LIMITS,
+ )
+ mocker.patch(
+ "hsml.client.get_serving_num_instances_limits", return_value=num_instances
+ )
+ mocker.patch(
+ "hsml.client.is_scale_to_zero_required", return_value=force_scale_to_zero
+ )
diff --git a/hsml/python/tests/test_util.py b/hsml/python/tests/test_util.py
new file mode 100644
index 000000000..3e7d18166
--- /dev/null
+++ b/hsml/python/tests/test_util.py
@@ -0,0 +1,645 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import os
+from urllib.parse import ParseResult
+
+import pytest
+from hsml import util
+from hsml.constants import MODEL
+from hsml.model import Model as BaseModel
+from hsml.predictor import Predictor as BasePredictor
+from hsml.python.model import Model as PythonModel
+from hsml.python.predictor import Predictor as PyPredictor
+from hsml.sklearn.model import Model as SklearnModel
+from hsml.sklearn.predictor import Predictor as SkLearnPredictor
+from hsml.tensorflow.model import Model as TensorflowModel
+from hsml.tensorflow.predictor import Predictor as TFPredictor
+from hsml.torch.model import Model as TorchModel
+from hsml.torch.predictor import Predictor as TorchPredictor
+
+
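+# Unit tests for hsml.util helpers: model class resolution, input-example
+# serialization, archive compression, metric validation, predictor selection
+# and generic JSON helpers.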
+class TestUtil:
+ # schema and types
+
+ # - set_model_class
+
+ def test_set_model_class_base(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["model"]["get_base"]["response"]["items"][0]
+
+ # Act
+ model = util.set_model_class(json)
+
+ # Assert
+ assert isinstance(model, BaseModel)
+ assert model.framework is None
+
+ def test_set_model_class_python(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["model"]["get_python"]["response"]["items"][0]
+
+ # Act
+ model = util.set_model_class(json)
+
+ # Assert
+ assert isinstance(model, PythonModel)
+ assert model.framework == MODEL.FRAMEWORK_PYTHON
+
+ def test_set_model_class_sklearn(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["model"]["get_sklearn"]["response"]["items"][0]
+
+ # Act
+ model = util.set_model_class(json)
+
+ # Assert
+ assert isinstance(model, SklearnModel)
+ assert model.framework == MODEL.FRAMEWORK_SKLEARN
+
+ def test_set_model_class_tensorflow(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["model"]["get_tensorflow"]["response"]["items"][0]
+
+ # Act
+ model = util.set_model_class(json)
+
+ # Assert
+ assert isinstance(model, TensorflowModel)
+ assert model.framework == MODEL.FRAMEWORK_TENSORFLOW
+
+ def test_set_model_class_torch(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["model"]["get_torch"]["response"]["items"][0]
+
+ # Act
+ model = util.set_model_class(json)
+
+ # Assert
+ assert isinstance(model, TorchModel)
+ assert model.framework == MODEL.FRAMEWORK_TORCH
+
+ def test_set_model_class_unsupported(self, backend_fixtures):
+ # Arrange
+ json = backend_fixtures["model"]["get_base"]["response"]["items"][0]
+ json["framework"] = "UNSUPPORTED"
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ util.set_model_class(json)
+
+ # Assert
+ assert "is not a supported framework" in str(e_info.value)
+
+ # - input_example_to_json
+
+ def test_input_example_to_json_from_numpy(self, mocker, input_example_numpy):
+ # Arrange
+ mock_handle_tensor_input = mocker.patch("hsml.util._handle_tensor_input")
+ mock_handle_dataframe_input = mocker.patch("hsml.util._handle_dataframe_input")
+ mock_handle_dict_input = mocker.patch("hsml.util._handle_dict_input")
+
+ # Act
+ util.input_example_to_json(input_example_numpy)
+
+ # Assert
+ mock_handle_tensor_input.assert_called_once()
+ mock_handle_dict_input.assert_not_called()
+ mock_handle_dataframe_input.assert_not_called()
+
+ def test_input_example_to_json_from_dict(self, mocker, input_example_dict):
+ # Arrange
+ mock_handle_tensor_input = mocker.patch("hsml.util._handle_tensor_input")
+ mock_handle_dataframe_input = mocker.patch("hsml.util._handle_dataframe_input")
+ mock_handle_dict_input = mocker.patch("hsml.util._handle_dict_input")
+
+ # Act
+ util.input_example_to_json(input_example_dict)
+
+ # Assert
+ mock_handle_tensor_input.assert_not_called()
+ mock_handle_dict_input.assert_called_once()
+ mock_handle_dataframe_input.assert_not_called()
+
+ def test_input_example_to_json_from_dataframe(
+ self, mocker, input_example_dataframe_pandas_dataframe
+ ):
+ # Arrange
+ mock_handle_tensor_input = mocker.patch("hsml.util._handle_tensor_input")
+ mock_handle_dataframe_input = mocker.patch("hsml.util._handle_dataframe_input")
+ mock_handle_dict_input = mocker.patch("hsml.util._handle_dict_input")
+
+ # Act
+ util.input_example_to_json(input_example_dataframe_pandas_dataframe)
+
+ # Assert
+ mock_handle_tensor_input.assert_not_called()
+ mock_handle_dict_input.assert_not_called()
+ mock_handle_dataframe_input.assert_called_once() # default
+
+ def test_input_example_to_json_unsupported(self, mocker):
+ # Arrange
+ mock_handle_tensor_input = mocker.patch("hsml.util._handle_tensor_input")
+ mock_handle_dataframe_input = mocker.patch("hsml.util._handle_dataframe_input")
+ mock_handle_dict_input = mocker.patch("hsml.util._handle_dict_input")
+
+ # Act
+ util.input_example_to_json(lambda unsupported_type: None)
+
+ # Assert
+ mock_handle_tensor_input.assert_not_called()
+ mock_handle_dict_input.assert_not_called()
+ mock_handle_dataframe_input.assert_called_once() # default
+
+ # - handle input examples
+
+ def test_handle_dataframe_input_pandas_dataframe(
+ self,
+ input_example_dataframe_pandas_dataframe,
+ input_example_dataframe_pandas_dataframe_empty,
+ input_example_dataframe_list,
+ ):
+ # Act
+ json = util._handle_dataframe_input(input_example_dataframe_pandas_dataframe)
+ with pytest.raises(ValueError) as e_info:
+ util._handle_dataframe_input(input_example_dataframe_pandas_dataframe_empty)
+
+ # Assert
+ assert isinstance(json, list)
+ assert json == input_example_dataframe_list
+ assert "can not be empty" in str(e_info.value)
+
+ def test_handle_dataframe_input_pandas_dataframe_series(
+ self,
+ input_example_dataframe_pandas_series,
+ input_example_dataframe_pandas_series_empty,
+ input_example_dataframe_list,
+ ):
+ # Act
+ json = util._handle_dataframe_input(input_example_dataframe_pandas_series)
+ with pytest.raises(ValueError) as e_info:
+ util._handle_dataframe_input(input_example_dataframe_pandas_series_empty)
+
+ # Assert
+ assert isinstance(json, list)
+ assert json == input_example_dataframe_list
+ assert "can not be empty" in str(e_info.value)
+
+ def test_handle_dataframe_input_list(self, input_example_dataframe_list):
+ # Act
+ json = util._handle_dataframe_input(input_example_dataframe_list)
+
+ # Assert
+ assert isinstance(json, list)
+ assert json == input_example_dataframe_list
+
+ def test_handle_dataframe_input_unsupported(self):
+ # Act
+ with pytest.raises(TypeError) as e_info:
+ util._handle_dataframe_input(lambda unsupported: None)
+
+ # Assert
+ assert "is not a supported input example type" in str(e_info.value)
+
+ def test_handle_tensor_input(
+ self, input_example_numpy, input_example_dataframe_list
+ ):
+ # Act
+ json = util._handle_tensor_input(input_example_numpy)
+
+ # Assert
+ assert isinstance(json, list)
+ assert json == input_example_dataframe_list
+
+ def test_handle_dict_input(self, input_example_dict):
+ # Act
+ json = util._handle_dict_input(input_example_dict)
+
+ # Assert
+ assert isinstance(json, dict)
+ assert json == input_example_dict
+
+ # artifacts
+
+ def test_compress_dir(self, mocker):
+ # Arrange
+ archive_name = "archive_name"
+ path_to_archive = os.path.join("this", "is", "the", "path", "to", "archive")
+ archive_out_path = os.path.join(
+ "this", "is", "the", "output", "path", "to", "archive"
+ )
+ full_archive_out_path = os.path.join(archive_out_path, archive_name)
+ mock_isdir = mocker.patch("os.path.isdir", return_value=True)
+ mock_shutil_make_archive = mocker.patch(
+ "shutil.make_archive", return_value="resulting_path"
+ )
+
+ # Act
+ path = util.compress(archive_out_path, archive_name, path_to_archive)
+
+ # Assert
+ assert path == "resulting_path"
+ mock_isdir.assert_called_once_with(path_to_archive)
+ mock_shutil_make_archive.assert_called_once_with(
+ full_archive_out_path, "gztar", path_to_archive
+ )
+
+ def test_compress_file(self, mocker):
+ # Arrange
+ archive_name = "archive_name"
+ path_to_archive = os.path.join("path", "to", "archive")
+ archive_out_path = os.path.join("output", "path", "to", "archive")
+ full_archive_out_path = os.path.join(archive_out_path, archive_name)
+ archive_path_dirname = os.path.join("path", "to")
+ archive_path_basename = "archive"
+ mock_isdir = mocker.patch("os.path.isdir", return_value=False)
+ mock_shutil_make_archive = mocker.patch(
+ "shutil.make_archive", return_value="resulting_path"
+ )
+
+ # Act
+ path = util.compress(archive_out_path, archive_name, path_to_archive)
+
+ # Assert
+ assert path == "resulting_path"
+ mock_isdir.assert_called_once_with(path_to_archive)
+ mock_shutil_make_archive.assert_called_once_with(
+ full_archive_out_path, "gztar", archive_path_dirname, archive_path_basename
+ )
+
+ def test_decompress(self, mocker):
+ # Arrange
+ archive_file_path = os.path.join("path", "to", "archive", "file")
+ extract_dir = False
+ mock_shutil_unpack_archive = mocker.patch(
+ "shutil.unpack_archive", return_value="resulting_path"
+ )
+
+ # Act
+ path = util.decompress(archive_file_path, extract_dir)
+
+ # Assert
+ assert path == "resulting_path"
+ mock_shutil_unpack_archive.assert_called_once_with(
+ archive_file_path, extract_dir=extract_dir
+ )
+
+ # export models
+
+ def test_validate_metrics(self, model_metrics):
+ # Act
+ util.validate_metrics(model_metrics)
+
+ # Assert
+ # noop
+
+ def test_validate_metrics_unsupported_type(self, model_metrics_wrong_type):
+ # Act
+ with pytest.raises(TypeError) as e_info:
+ util.validate_metrics(model_metrics_wrong_type)
+
+ # Assert
+ assert "expected a dict" in str(e_info.value)
+
+ def test_validate_metrics_unsupported_metric_type(
+ self, model_metrics_wrong_metric_type
+ ):
+ # Act
+ with pytest.raises(TypeError) as e_info:
+ util.validate_metrics(model_metrics_wrong_metric_type)
+
+ # Assert
+ assert "expected a string" in str(e_info.value)
+
+ def test_validate_metrics_unsupported_metric_value(
+ self, model_metrics_wrong_metric_value
+ ):
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ util.validate_metrics(model_metrics_wrong_metric_value)
+
+ # Assert
+ assert "is not a number" in str(e_info.value)
+
+ # model serving
+
+ def test_get_predictor_for_model_base(self, mocker, model_base):
+ # Arrange
+ def pred_base_spec(model_framework, model_server):
+ pass
+
+ pred_base = mocker.patch(
+ "hsml.predictor.Predictor.__init__", return_value=None, spec=pred_base_spec
+ )
+ pred_python = mocker.patch("hsml.python.predictor.Predictor.__init__")
+ pred_sklearn = mocker.patch("hsml.sklearn.predictor.Predictor.__init__")
+ pred_tensorflow = mocker.patch("hsml.tensorflow.predictor.Predictor.__init__")
+ pred_torch = mocker.patch("hsml.torch.predictor.Predictor.__init__")
+
+ # Act
+ predictor = util.get_predictor_for_model(model_base)
+
+ # Assert
+ assert isinstance(predictor, BasePredictor)
+ pred_base.assert_called_once_with(
+ model_framework=MODEL.FRAMEWORK_PYTHON, model_server=MODEL.FRAMEWORK_PYTHON
+ )
+ pred_python.assert_not_called()
+ pred_sklearn.assert_not_called()
+ pred_tensorflow.assert_not_called()
+ pred_torch.assert_not_called()
+
+ def test_get_predictor_for_model_python(self, mocker, model_python):
+ # Arrange
+ pred_base = mocker.patch("hsml.predictor.Predictor.__init__")
+ pred_python = mocker.patch(
+ "hsml.python.predictor.Predictor.__init__", return_value=None
+ )
+ pred_sklearn = mocker.patch("hsml.sklearn.predictor.Predictor.__init__")
+ pred_tensorflow = mocker.patch("hsml.tensorflow.predictor.Predictor.__init__")
+ pred_torch = mocker.patch("hsml.torch.predictor.Predictor.__init__")
+
+ # Act
+ predictor = util.get_predictor_for_model(model_python)
+
+ # Assert
+ assert isinstance(predictor, PyPredictor)
+ pred_base.assert_not_called()
+ pred_python.assert_called_once()
+ pred_sklearn.assert_not_called()
+ pred_tensorflow.assert_not_called()
+ pred_torch.assert_not_called()
+
+ def test_get_predictor_for_model_sklearn(self, mocker, model_sklearn):
+ # Arrange
+ pred_base = mocker.patch("hsml.predictor.Predictor.__init__")
+ pred_python = mocker.patch("hsml.python.predictor.Predictor.__init__")
+ pred_sklearn = mocker.patch(
+ "hsml.sklearn.predictor.Predictor.__init__", return_value=None
+ )
+ pred_tensorflow = mocker.patch("hsml.tensorflow.predictor.Predictor.__init__")
+ pred_torch = mocker.patch("hsml.torch.predictor.Predictor.__init__")
+
+ # Act
+ predictor = util.get_predictor_for_model(model_sklearn)
+
+ # Assert
+ assert isinstance(predictor, SkLearnPredictor)
+ pred_base.assert_not_called()
+ pred_python.assert_not_called()
+ pred_sklearn.assert_called_once()
+ pred_tensorflow.assert_not_called()
+ pred_torch.assert_not_called()
+
+ def test_get_predictor_for_model_tensorflow(self, mocker, model_tensorflow):
+ # Arrange
+ pred_base = mocker.patch("hsml.predictor.Predictor.__init__")
+ pred_python = mocker.patch("hsml.python.predictor.Predictor.__init__")
+ pred_sklearn = mocker.patch("hsml.sklearn.predictor.Predictor.__init__")
+ pred_tensorflow = mocker.patch(
+ "hsml.tensorflow.predictor.Predictor.__init__", return_value=None
+ )
+ pred_torch = mocker.patch("hsml.torch.predictor.Predictor.__init__")
+
+ # Act
+ predictor = util.get_predictor_for_model(model_tensorflow)
+
+ # Assert
+ assert isinstance(predictor, TFPredictor)
+ pred_base.assert_not_called()
+ pred_python.assert_not_called()
+ pred_sklearn.assert_not_called()
+ pred_tensorflow.assert_called_once()
+ pred_torch.assert_not_called()
+
+ def test_get_predictor_for_model_torch(self, mocker, model_torch):
+ # Arrange
+ pred_base = mocker.patch("hsml.predictor.Predictor.__init__")
+ pred_python = mocker.patch("hsml.python.predictor.Predictor.__init__")
+ pred_sklearn = mocker.patch("hsml.sklearn.predictor.Predictor.__init__")
+ pred_tensorflow = mocker.patch("hsml.tensorflow.predictor.Predictor.__init__")
+ pred_torch = mocker.patch(
+ "hsml.torch.predictor.Predictor.__init__", return_value=None
+ )
+
+ # Act
+ predictor = util.get_predictor_for_model(model_torch)
+
+ # Assert
+ assert isinstance(predictor, TorchPredictor)
+ pred_base.assert_not_called()
+ pred_python.assert_not_called()
+ pred_sklearn.assert_not_called()
+ pred_tensorflow.assert_not_called()
+ pred_torch.assert_called_once()
+
+ def test_get_predictor_for_model_non_base(self, mocker):
+ # Arrange
+ pred_base = mocker.patch("hsml.predictor.Predictor.__init__")
+ pred_python = mocker.patch("hsml.python.predictor.Predictor.__init__")
+ pred_sklearn = mocker.patch("hsml.sklearn.predictor.Predictor.__init__")
+ pred_tensorflow = mocker.patch("hsml.tensorflow.predictor.Predictor.__init__")
+ pred_torch = mocker.patch("hsml.torch.predictor.Predictor.__init__")
+
+ class NonBaseModel:
+ pass
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ util.get_predictor_for_model(NonBaseModel())
+
+ # Assert
+ assert "an instance of {} class is expected".format(BaseModel) in str(
+ e_info.value
+ )
+ pred_base.assert_not_called()
+ pred_python.assert_not_called()
+ pred_sklearn.assert_not_called()
+ pred_tensorflow.assert_not_called()
+ pred_torch.assert_not_called()
+
+ def test_get_hostname_replaced_url(self, mocker):
+ # Arrange
+ sub_path = "this/is/a/sub_path"
+ base_url = "/hopsworks/api/base/"
+ urlparse_href_arg = ParseResult(
+ scheme="",
+ netloc="",
+ path=base_url + sub_path,
+ params="",
+ query="",
+ fragment="",
+ )
+ geturl_return = "final_url"
+ mock_url_parsed = mocker.MagicMock()
+ mock_url_parsed.geturl = mocker.MagicMock(return_value=geturl_return)
+ mock_client = mocker.MagicMock()
+ mock_client._base_url = base_url + "url"
+ mock_client._replace_public_host = mocker.MagicMock(
+ return_value=mock_url_parsed
+ )
+ mocker.patch("hsml.client.get_instance", return_value=mock_client)
+
+ # Act
+ url = util.get_hostname_replaced_url(sub_path)
+
+ # Assert
+ mock_client._replace_public_host.assert_called_once_with(urlparse_href_arg)
+ mock_url_parsed.geturl.assert_called_once()
+ assert url == geturl_return
+
+ # general
+
+ def test_get_members(self):
+ # Arrange
+ class TEST:
+ TEST_1 = 1
+ TEST_2 = "two"
+ TEST_3 = "3"
+
+ # Act
+ members = list(util.get_members(TEST))
+
+ # Assert
+ assert members == [1, "two", "3"]
+
+ def test_get_members_with_prefix(self):
+ # Arrange
+ class TEST:
+ TEST_1 = 1
+ TEST_2 = "two"
+ RES_3 = "3"
+ NONE = None
+
+ # Act
+ members = list(util.get_members(TEST, prefix="TEST"))
+
+ # Assert
+ assert members == [1, "two"]
+
+ # json
+
+ def test_extract_field_from_json(self, mocker):
+ # Arrange
+ json = {"a": "1", "b": "2"}
+ get_obj_from_json = mocker.patch("hsml.util.get_obj_from_json")
+
+ # Act
+ b = util.extract_field_from_json(json, "b")
+
+ # Assert
+ assert b == "2"
+ assert get_obj_from_json.call_count == 0
+
+ def test_extract_field_from_json_fields(self, mocker):
+ # Arrange
+ json = {"a": "1", "b": "2"}
+ get_obj_from_json = mocker.patch("hsml.util.get_obj_from_json")
+
+ # Act
+ b = util.extract_field_from_json(json, ["B", "b"]) # alternative fields
+
+ # Assert
+ assert b == "2"
+ assert get_obj_from_json.call_count == 0
+
+ def test_extract_field_from_json_as_instance_of_str(self, mocker):
+ # Arrange
+ json = {"a": "1", "b": "2"}
+ get_obj_from_json = mocker.patch(
+ "hsml.util.get_obj_from_json", return_value="2"
+ )
+
+ # Act
+ b = util.extract_field_from_json(json, "b", as_instance_of=str)
+
+ # Assert
+ assert b == "2"
+ get_obj_from_json.assert_called_once_with(obj="2", cls=str)
+
+ def test_extract_field_from_json_as_instance_of_list_str(self, mocker):
+ # Arrange
+ json = {"a": "1", "b": ["2", "2", "2"]}
+ get_obj_from_json = mocker.patch(
+ "hsml.util.get_obj_from_json", return_value="2"
+ )
+
+ # Act
+ b = util.extract_field_from_json(json, "b", as_instance_of=str)
+
+ # Assert
+ assert b == ["2", "2", "2"]
+ assert get_obj_from_json.call_count == 3
+ assert get_obj_from_json.call_args[1]["obj"] == "2"
+ assert get_obj_from_json.call_args[1]["cls"] == str
+
+ def test_get_obj_from_json_cls(self, mocker):
+ # Arrange
+ class Test:
+ def __init__(self):
+ self.a = "1"
+
+ # Act
+ obj = util.get_obj_from_json(Test(), Test)
+
+ # Assert
+ assert isinstance(obj, Test)
+ assert obj.a == "1"
+
+ def test_get_obj_from_json_dict(self, mocker):
+ # Arrange
+ class Test:
+ def __init__(self, a):
+ self.a = a
+
+ @classmethod
+ def from_json(cls, json):
+ return cls(**json)
+
+ # Act
+ obj = util.get_obj_from_json({"a": "1"}, Test)
+
+ # Assert
+ assert isinstance(obj, Test)
+ assert obj.a == "1"
+
+ def test_get_obj_from_json_dict_default(self, mocker):
+ # Arrange
+ class Test:
+ def __init__(self, a="11"):
+ self.a = a
+
+ @classmethod
+ def from_json(cls, json):
+ return cls(**json)
+
+ # Act
+ obj = util.get_obj_from_json({}, Test)
+
+ # Assert
+ assert isinstance(obj, Test)
+ assert obj.a == "11"
+
+ def test_get_obj_from_json_unsupported(self, mocker):
+ # Arrange
+ class Test:
+ pass
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ util.get_obj_from_json("UNSUPPORTED", Test)
+
+ # Assert
+ assert "cannot be converted to class" in str(e_info.value)
diff --git a/hsml/python/tests/utils/__init__.py b/hsml/python/tests/utils/__init__.py
new file mode 100644
index 000000000..ff8055b9b
--- /dev/null
+++ b/hsml/python/tests/utils/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
diff --git a/hsml/python/tests/utils/schema/test_column.py b/hsml/python/tests/utils/schema/test_column.py
new file mode 100644
index 000000000..0a41ef205
--- /dev/null
+++ b/hsml/python/tests/utils/schema/test_column.py
@@ -0,0 +1,44 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from hsml.utils.schema import column
+
+
+class TestColumn:
+ def test_constructor_default(self):
+ # Arrange
+ _type = 1234
+
+ # Act
+ t = column.Column(_type)
+
+ # Assert
+ assert t.type == str(_type)
+ assert not hasattr(t, "name")
+ assert not hasattr(t, "description")
+
+ def test_constructor(self):
+ # Arrange
+ _type = 1234
+ name = 1111111
+ description = 2222222
+
+ # Act
+ t = column.Column(_type, name, description)
+
+ # Assert
+ assert t.type == str(_type)
+ assert t.name == str(name)
+ assert t.description == str(description)
diff --git a/hsml/python/tests/utils/schema/test_columnar_schema.py b/hsml/python/tests/utils/schema/test_columnar_schema.py
new file mode 100644
index 000000000..c01c3c33d
--- /dev/null
+++ b/hsml/python/tests/utils/schema/test_columnar_schema.py
@@ -0,0 +1,461 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pandas as pd
+import pytest
+from hsml.utils.schema import column, columnar_schema
+from mock import call
+
+
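+# The ColumnarSchema constructor tests mock out the individual conversion
+# helpers, which are then exercised one by one further below.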
+class TestColumnarSchema:
+ # constructor
+
+ def test_constructor_default(self, mocker):
+ # Arrange
+ mock_convert_list_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_list_to_schema",
+ return_value="convert_list_to_schema",
+ )
+ mock_convert_pandas_df_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_pandas_df_to_schema",
+ return_value="convert_pandas_df_to_schema",
+ )
+ mock_convert_pandas_series_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_pandas_series_to_schema",
+ return_value="convert_pandas_series_to_schema",
+ )
+ mock_convert_spark_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_spark_to_schema",
+ return_value="convert_spark_to_schema",
+ )
+ mock_convert_td_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_td_to_schema",
+ return_value="convert_td_to_schema",
+ )
+ mock_find_spec = mocker.patch("importlib.util.find_spec", return_value=None)
+
+ # Act
+ with pytest.raises(TypeError) as e_info:
+ _ = columnar_schema.ColumnarSchema()
+
+ # Assert
+ assert "is not supported in a columnar schema" in str(e_info.value)
+ mock_convert_list_to_schema.assert_not_called()
+ mock_convert_pandas_df_to_schema.assert_not_called()
+ mock_convert_pandas_series_to_schema.assert_not_called()
+ mock_convert_spark_to_schema.assert_not_called()
+ mock_convert_td_to_schema.assert_not_called()
+ assert mock_find_spec.call_count == 2
+
+ def test_constructor_list(self, mocker):
+ # Arrange
+ columnar_obj = [1, 2, 3, 4]
+ mock_convert_list_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_list_to_schema",
+ return_value="convert_list_to_schema",
+ )
+ mock_convert_pandas_df_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_pandas_df_to_schema",
+ return_value="convert_pandas_df_to_schema",
+ )
+ mock_convert_pandas_series_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_pandas_series_to_schema",
+ return_value="convert_pandas_series_to_schema",
+ )
+ mock_convert_spark_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_spark_to_schema",
+ return_value="convert_spark_to_schema",
+ )
+ mock_convert_td_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_td_to_schema",
+ return_value="convert_td_to_schema",
+ )
+ mock_find_spec = mocker.patch("importlib.util.find_spec", return_value=None)
+
+ # Act
+ cs = columnar_schema.ColumnarSchema(columnar_obj)
+
+ # Assert
+ assert cs.columns == "convert_list_to_schema"
+ mock_convert_list_to_schema.assert_called_once_with(columnar_obj)
+ mock_convert_pandas_df_to_schema.assert_not_called()
+ mock_convert_pandas_series_to_schema.assert_not_called()
+ mock_convert_spark_to_schema.assert_not_called()
+ mock_convert_td_to_schema.assert_not_called()
+ mock_find_spec.assert_not_called()
+
+ def test_constructor_pd_dataframe(self, mocker):
+ # Arrange
+ columnar_obj = pd.DataFrame([1, 2, 3, 4])
+ mock_convert_list_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_list_to_schema",
+ return_value="convert_list_to_schema",
+ )
+ mock_convert_pandas_df_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_pandas_df_to_schema",
+ return_value="convert_pandas_df_to_schema",
+ )
+ mock_convert_pandas_series_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_pandas_series_to_schema",
+ return_value="convert_pandas_series_to_schema",
+ )
+ mock_convert_spark_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_spark_to_schema",
+ return_value="convert_spark_to_schema",
+ )
+ mock_convert_td_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_td_to_schema",
+ return_value="convert_td_to_schema",
+ )
+ mock_find_spec = mocker.patch("importlib.util.find_spec", return_value=None)
+
+ # Act
+ cs = columnar_schema.ColumnarSchema(columnar_obj)
+
+ # Assert
+ assert cs.columns == "convert_pandas_df_to_schema"
+ mock_convert_list_to_schema.assert_not_called()
+ mock_convert_pandas_df_to_schema.assert_called_once_with(columnar_obj)
+ mock_convert_pandas_series_to_schema.assert_not_called()
+ mock_convert_spark_to_schema.assert_not_called()
+ mock_convert_td_to_schema.assert_not_called()
+ mock_find_spec.assert_not_called()
+
+ def test_constructor_pd_series(self, mocker):
+ # Arrange
+ columnar_obj = pd.Series([1, 2, 3, 4])
+ mock_convert_list_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_list_to_schema",
+ return_value="convert_list_to_schema",
+ )
+ mock_convert_pandas_df_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_pandas_df_to_schema",
+ return_value="convert_pandas_df_to_schema",
+ )
+ mock_convert_pandas_series_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_pandas_series_to_schema",
+ return_value="convert_pandas_series_to_schema",
+ )
+ mock_convert_spark_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_spark_to_schema",
+ return_value="convert_spark_to_schema",
+ )
+ mock_convert_td_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_td_to_schema",
+ return_value="convert_td_to_schema",
+ )
+ mock_find_spec = mocker.patch("importlib.util.find_spec", return_value=None)
+
+ # Act
+ cs = columnar_schema.ColumnarSchema(columnar_obj)
+
+ # Assert
+ assert cs.columns == "convert_pandas_series_to_schema"
+ mock_convert_list_to_schema.assert_not_called()
+ mock_convert_pandas_df_to_schema.assert_not_called()
+ mock_convert_pandas_series_to_schema.assert_called_once_with(columnar_obj)
+ mock_convert_spark_to_schema.assert_not_called()
+ mock_convert_td_to_schema.assert_not_called()
+ mock_find_spec.assert_not_called()
+
+ def test_constructor_pyspark_dataframe(self, mocker):
+ try:
+ import pyspark
+ except ImportError:
+ pytest.skip("pyspark not available")
+
+ # Arrange
+ columnar_obj = mocker.MagicMock(spec=pyspark.sql.dataframe.DataFrame)
+ mock_convert_list_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_list_to_schema",
+ return_value="convert_list_to_schema",
+ )
+ mock_convert_pandas_df_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_pandas_df_to_schema",
+ return_value="convert_pandas_df_to_schema",
+ )
+ mock_convert_pandas_series_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_pandas_series_to_schema",
+ return_value="convert_pandas_series_to_schema",
+ )
+ mock_convert_spark_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_spark_to_schema",
+ return_value="convert_spark_to_schema",
+ )
+ mock_convert_td_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_td_to_schema",
+ return_value="convert_td_to_schema",
+ )
+ mock_find_spec = mocker.patch(
+ "importlib.util.find_spec", return_value="Not None"
+ )
+
+ # Act
+ cs = columnar_schema.ColumnarSchema(columnar_obj)
+
+ # Assert
+ assert cs.columns == "convert_spark_to_schema"
+ mock_convert_list_to_schema.assert_not_called()
+ mock_convert_pandas_df_to_schema.assert_not_called()
+ mock_convert_pandas_series_to_schema.assert_not_called()
+ mock_convert_spark_to_schema.assert_called_once_with(columnar_obj)
+ mock_convert_td_to_schema.assert_not_called()
+ mock_find_spec.assert_called_once_with("pyspark")
+
+ def test_constructor_hsfs_td(self, mocker):
+ # Arrange
+ try:
+ import hsfs
+ except ImportError:
+ pytest.skip("hsfs not available")
+
+ # Arrange
+ columnar_obj = mocker.MagicMock(spec=hsfs.training_dataset.TrainingDataset)
+ mock_convert_list_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_list_to_schema",
+ return_value="convert_list_to_schema",
+ )
+ mock_convert_pandas_df_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_pandas_df_to_schema",
+ return_value="convert_pandas_df_to_schema",
+ )
+ mock_convert_pandas_series_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_pandas_series_to_schema",
+ return_value="convert_pandas_series_to_schema",
+ )
+ mock_convert_spark_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_spark_to_schema",
+ return_value="convert_spark_to_schema",
+ )
+ mock_convert_td_to_schema = mocker.patch(
+ "hsml.utils.schema.columnar_schema.ColumnarSchema._convert_td_to_schema",
+ return_value="convert_td_to_schema",
+ )
+ mock_find_spec = mocker.patch(
+ "importlib.util.find_spec", return_value="Not None"
+ )
+
+ # Act
+ cs = columnar_schema.ColumnarSchema(columnar_obj)
+
+ # Assert
+ assert cs.columns == "convert_td_to_schema"
+ mock_convert_list_to_schema.assert_not_called()
+ mock_convert_pandas_df_to_schema.assert_not_called()
+ mock_convert_pandas_series_to_schema.assert_not_called()
+ mock_convert_spark_to_schema.assert_not_called()
+ mock_convert_td_to_schema.assert_called_once_with(columnar_obj)
+ assert mock_find_spec.call_count == 2
+
+ # convert list to schema
+
+ def test_convert_list_to_schema(self, mocker):
+ # Arrange
+ columnar_obj = [1, 2, 3, 4]
+ mock_columnar_schema = mocker.MagicMock()
+ mock_columnar_schema._convert_list_to_schema = (
+ columnar_schema.ColumnarSchema._convert_list_to_schema
+ )
+ mock_columnar_schema._build_column.side_effect = columnar_obj
+
+ # Act
+ c = mock_columnar_schema._convert_list_to_schema(
+ mock_columnar_schema, columnar_obj
+ )
+
+ # Assert
+ expected_calls = [call(cv) for cv in columnar_obj]
+ mock_columnar_schema._build_column.assert_has_calls(expected_calls)
+ assert mock_columnar_schema._build_column.call_count == len(columnar_obj)
+ assert c == columnar_obj
+
+ # convert pandas df to schema
+
+ def test_convert_pd_dataframe_to_schema(self, mocker):
+ # Arrange
+ columnar_obj = pd.DataFrame([[1, 2], [3, 4], [1, 2], [3, 4]])
+ mock_column_init = mocker.patch(
+ "hsml.utils.schema.column.Column.__init__", return_value=None
+ )
+ mock_columnar_schema = mocker.MagicMock()
+ mock_columnar_schema._convert_pandas_df_to_schema = (
+ columnar_schema.ColumnarSchema._convert_pandas_df_to_schema
+ )
+
+ # Act
+ c = mock_columnar_schema._convert_pandas_df_to_schema(
+ mock_columnar_schema, columnar_obj
+ )
+
+ # Assert
+ cols = columnar_obj.columns
+ dtypes = columnar_obj.dtypes
+ expected_calls = [call(dtypes[col], name=col) for col in cols]
+ mock_column_init.assert_has_calls(expected_calls)
+ assert mock_column_init.call_count == 2
+ assert len(c) == 2
+
+ # convert pandas series to schema
+
+ def test_convert_pd_series_to_schema(self, mocker):
+ # Arrange
+ columnar_obj = pd.Series([1, 2, 3, 4])
+ mock_column_init = mocker.patch(
+ "hsml.utils.schema.column.Column.__init__", return_value=None
+ )
+ mock_columnar_schema = mocker.MagicMock()
+ mock_columnar_schema._convert_pandas_series_to_schema = (
+ columnar_schema.ColumnarSchema._convert_pandas_series_to_schema
+ )
+
+ # Act
+ c = mock_columnar_schema._convert_pandas_series_to_schema(
+ mock_columnar_schema, columnar_obj
+ )
+
+ # Assert
+ expected_call = call(columnar_obj.dtype, name=columnar_obj.name)
+ mock_column_init.assert_has_calls([expected_call])
+ assert mock_column_init.call_count == 1
+ assert len(c) == 1
+
+ # convert spark to schema
+
+ def test_convert_spark_to_schema(self, mocker):
+ # Skip if pyspark is not installed
+ try:
+ import pyspark
+ except ImportError:
+ pytest.skip("pyspark not available")
+
+ # Arrange
+ columnar_obj = mocker.MagicMock(spec=pyspark.sql.dataframe.DataFrame)
+ columnar_obj.dtypes = [("name_1", "type_1"), ("name_2", "type_2")]
+ mock_column_init = mocker.patch(
+ "hsml.utils.schema.column.Column.__init__", return_value=None
+ )
+ mock_columnar_schema = mocker.MagicMock()
+ mock_columnar_schema._convert_spark_to_schema = (
+ columnar_schema.ColumnarSchema._convert_spark_to_schema
+ )
+
+ # Act
+ c = mock_columnar_schema._convert_spark_to_schema(
+ mock_columnar_schema, columnar_obj
+ )
+
+ # Assert
+ expected_calls = [call(dtype, name=name) for name, dtype in columnar_obj.dtypes]
+ mock_column_init.assert_has_calls(expected_calls)
+ assert mock_column_init.call_count == len(columnar_obj.dtypes)
+ assert len(c) == len(columnar_obj.dtypes)
+
+ # convert td to schema
+
+ def test_convert_td_to_schema(self, mocker):
+ # Arrange
+ class MockFeature:
+ def __init__(self, fname, ftype):
+ self.name = fname
+ self.type = ftype
+
+ columnar_obj = mocker.MagicMock()
+ columnar_obj.schema = [
+ MockFeature("name_1", "type_1"),
+ MockFeature("name_2", "type_2"),
+ ]
+ mock_column_init = mocker.patch(
+ "hsml.utils.schema.column.Column.__init__", return_value=None
+ )
+ mock_columnar_schema = mocker.MagicMock()
+ mock_columnar_schema._convert_td_to_schema = (
+ columnar_schema.ColumnarSchema._convert_td_to_schema
+ )
+
+ # Act
+ c = mock_columnar_schema._convert_td_to_schema(
+ mock_columnar_schema, columnar_obj
+ )
+
+ # Assert
+ expected_calls = [
+ call(feat.type, name=feat.name) for feat in columnar_obj.schema
+ ]
+ mock_column_init.assert_has_calls(expected_calls)
+ assert mock_column_init.call_count == len(columnar_obj.schema)
+ assert len(c) == len(columnar_obj.schema)
+
+ # build column
+
+ def test_build_column_type_only(self, mocker):
+ # Arrange
+ columnar_obj = {"type": "tensor_type"}
+ mock_column_init = mocker.patch(
+ "hsml.utils.schema.column.Column.__init__", return_value=None
+ )
+ mock_columnar_schema = mocker.MagicMock()
+ mock_columnar_schema._build_column = (
+ columnar_schema.ColumnarSchema._build_column
+ )
+
+ # Act
+ c = mock_columnar_schema._build_column(mock_columnar_schema, columnar_obj)
+
+ # Assert
+ assert isinstance(c, column.Column)
+ mock_column_init.assert_called_once_with(
+ columnar_obj["type"], name=None, description=None
+ )
+
+ def test_build_column_invalid_missing_type(self, mocker):
+ # Arrange
+ columnar_obj = {}
+ mock_columnar_schema = mocker.MagicMock()
+ mock_columnar_schema._build_column = (
+ columnar_schema.ColumnarSchema._build_column
+ )
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ _ = mock_columnar_schema._build_column(mock_columnar_schema, columnar_obj)
+
+ # Assert
+ assert "Mandatory 'type' key missing from entry" in str(e_info.value)
+
+ def test_build_column_type_name_and_description(self, mocker):
+ # Arrange
+ columnar_obj = {
+ "type": "tensor_type",
+ "name": "tensor_name",
+ "description": "tensor_description",
+ }
+ mock_column_init = mocker.patch(
+ "hsml.utils.schema.column.Column.__init__", return_value=None
+ )
+ mock_columnar_schema = mocker.MagicMock()
+ mock_columnar_schema._build_column = (
+ columnar_schema.ColumnarSchema._build_column
+ )
+
+ # Act
+ c = mock_columnar_schema._build_column(mock_columnar_schema, columnar_obj)
+
+ # Assert
+ assert isinstance(c, column.Column)
+ mock_column_init.assert_called_once_with(
+ columnar_obj["type"],
+ name=columnar_obj["name"],
+ description=columnar_obj["description"],
+ )
diff --git a/hsml/python/tests/utils/schema/test_tensor.py b/hsml/python/tests/utils/schema/test_tensor.py
new file mode 100644
index 000000000..22c2ab360
--- /dev/null
+++ b/hsml/python/tests/utils/schema/test_tensor.py
@@ -0,0 +1,48 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from hsml.utils.schema import tensor
+
+
+class TestTensor:
+ def test_constructor_default(self):
+ # Arrange
+ _type = 1234
+ shape = 4321
+
+ # Act
+ t = tensor.Tensor(_type, shape)
+
+ # Assert
+ assert t.type == str(_type)
+ assert t.shape == str(shape)
+ assert not hasattr(t, "name")
+ assert not hasattr(t, "description")
+
+ def test_constructor(self):
+ # Arrange
+ _type = 1234
+ shape = 4321
+ name = 1111111
+ description = 2222222
+
+ # Act
+ t = tensor.Tensor(_type, shape, name, description)
+
+ # Assert
+ assert t.type == str(_type)
+ assert t.shape == str(shape)
+ assert t.name == str(name)
+ assert t.description == str(description)
diff --git a/hsml/python/tests/utils/schema/test_tensor_schema.py b/hsml/python/tests/utils/schema/test_tensor_schema.py
new file mode 100644
index 000000000..18afb3fdc
--- /dev/null
+++ b/hsml/python/tests/utils/schema/test_tensor_schema.py
@@ -0,0 +1,204 @@
+#
+# Copyright 2024 Hopsworks AB
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pytest
+from hsml.utils.schema import tensor, tensor_schema
+
+
+class TestTensorSchema:
+ # constructor
+
+ def test_constructor_default(self):
+ # Act
+ with pytest.raises(TypeError) as e_info:
+ _ = tensor_schema.TensorSchema()
+
+ # Assert
+ assert "is not supported in a tensor schema" in str(e_info.value)
+
+ def test_constructor_invalid(self):
+ # Act
+ with pytest.raises(TypeError) as e_info:
+ _ = tensor_schema.TensorSchema("invalid")
+
+ # Assert
+ assert "is not supported in a tensor schema" in str(e_info.value)
+
+ def test_constructor_list(self, mocker):
+ # Arrange
+ tensor_obj = [1234, 4321, 1111111, 2222222]
+ mock_convert_list_to_schema = mocker.patch(
+ "hsml.utils.schema.tensor_schema.TensorSchema._convert_list_to_schema",
+ return_value="list_to_schema",
+ )
+ mock_convert_tensor_to_schema = mocker.patch(
+ "hsml.utils.schema.tensor_schema.TensorSchema._convert_tensor_to_schema",
+ return_value="tensor_to_schema",
+ )
+
+ # Act
+ ts = tensor_schema.TensorSchema(tensor_obj)
+
+ # Assert
+ assert ts.tensors == "list_to_schema"
+ mock_convert_list_to_schema.assert_called_once_with(tensor_obj)
+ mock_convert_tensor_to_schema.assert_not_called()
+
+ def test_constructor_ndarray(self, mocker):
+ # Arrange
+ tensor_obj = np.array([1234, 4321, 1111111, 2222222])
+ mock_convert_list_to_schema = mocker.patch(
+ "hsml.utils.schema.tensor_schema.TensorSchema._convert_list_to_schema",
+ return_value="list_to_schema",
+ )
+ mock_convert_tensor_to_schema = mocker.patch(
+ "hsml.utils.schema.tensor_schema.TensorSchema._convert_tensor_to_schema",
+ return_value="tensor_to_schema",
+ )
+
+ # Act
+ ts = tensor_schema.TensorSchema(tensor_obj)
+
+ # Assert
+ assert ts.tensors == "tensor_to_schema"
+ mock_convert_tensor_to_schema.assert_called_once_with(tensor_obj)
+ mock_convert_list_to_schema.assert_not_called()
+
+ # convert tensor to schema
+
+ def test_convert_tensor_to_schema(self, mocker):
+ # Arrange
+ tensor_obj = mocker.MagicMock()
+ mock_tensor_schema = mocker.MagicMock()
+ mock_tensor_schema._convert_tensor_to_schema = (
+ tensor_schema.TensorSchema._convert_tensor_to_schema
+ )
+ mock_tensor_init = mocker.patch(
+ "hsml.utils.schema.tensor.Tensor.__init__", return_value=None
+ )
+
+ # Act
+ t = mock_tensor_schema._convert_tensor_to_schema(mock_tensor_schema, tensor_obj)
+
+ # Assert
+ assert isinstance(t, tensor.Tensor)
+ mock_tensor_init.assert_called_once_with(tensor_obj.dtype, tensor_obj.shape)
+
+ # convert list to schema
+
+ def test_convert_list_to_schema_singleton(self, mocker):
+ # Arrange
+ tensor_obj = [1234]
+ mock_tensor_schema = mocker.MagicMock()
+ mock_tensor_schema._convert_list_to_schema = (
+ tensor_schema.TensorSchema._convert_list_to_schema
+ )
+
+ # Act
+ t = mock_tensor_schema._convert_list_to_schema(mock_tensor_schema, tensor_obj)
+
+ # Assert
+ assert isinstance(t, list)
+ assert len(t) == len(tensor_obj)
+ mock_tensor_schema._build_tensor.assert_called_once_with(1234)
+
+ def test_convert_list_to_schema_list(self, mocker):
+ # Arrange
+ tensor_obj = [1234, 4321, 1111111, 2222222]
+ mock_tensor_schema = mocker.MagicMock()
+ mock_tensor_schema._convert_list_to_schema = (
+ tensor_schema.TensorSchema._convert_list_to_schema
+ )
+
+ # Act
+ t = mock_tensor_schema._convert_list_to_schema(mock_tensor_schema, tensor_obj)
+
+ # Assert
+ assert isinstance(t, list)
+ assert len(t) == len(tensor_obj)
+ assert mock_tensor_schema._build_tensor.call_count == len(tensor_obj)
+
+ # build tensor
+
+ def test_build_tensor_type_and_shape_only(self, mocker):
+ # Arrange
+ tensor_obj = {"type": "tensor_type", "shape": "tensor_shape"}
+ mock_tensor_init = mocker.patch(
+ "hsml.utils.schema.tensor.Tensor.__init__", return_value=None
+ )
+ mock_tensor_schema = mocker.MagicMock()
+ mock_tensor_schema._build_tensor = tensor_schema.TensorSchema._build_tensor
+
+ # Act
+ t = mock_tensor_schema._build_tensor(mock_tensor_schema, tensor_obj)
+
+ # Assert
+ assert isinstance(t, tensor.Tensor)
+ mock_tensor_init.assert_called_once_with(
+ tensor_obj["type"], tensor_obj["shape"], name=None, description=None
+ )
+
+ def test_build_tensor_invalid_missing_type(self, mocker):
+ # Arrange
+ tensor_obj = {"shape": "tensor_shape"}
+ mock_tensor_schema = mocker.MagicMock()
+ mock_tensor_schema._build_tensor = tensor_schema.TensorSchema._build_tensor
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ _ = mock_tensor_schema._build_tensor(mock_tensor_schema, tensor_obj)
+
+ # Assert
+ assert "Mandatory 'type' key missing from entry" in str(e_info.value)
+
+ def test_build_tensor_invalid_missing_shape(self, mocker):
+ # Arrange
+ tensor_obj = {"type": "tensor_type"}
+ mock_tensor_schema = mocker.MagicMock()
+ mock_tensor_schema._build_tensor = tensor_schema.TensorSchema._build_tensor
+
+ # Act
+ with pytest.raises(ValueError) as e_info:
+ _ = mock_tensor_schema._build_tensor(mock_tensor_schema, tensor_obj)
+
+ # Assert
+ assert "Mandatory 'shape' key missing from entry" in str(e_info.value)
+
+ def test_build_tensor_type_shape_name_and_description(self, mocker):
+ # Arrange
+ tensor_obj = {
+ "type": "tensor_type",
+ "shape": "tensor_shape",
+ "name": "tensor_name",
+ "description": "tensor_description",
+ }
+ mock_tensor_init = mocker.patch(
+ "hsml.utils.schema.tensor.Tensor.__init__", return_value=None
+ )
+ mock_tensor_schema = mocker.MagicMock()
+ mock_tensor_schema._build_tensor = tensor_schema.TensorSchema._build_tensor
+
+ # Act
+ t = mock_tensor_schema._build_tensor(mock_tensor_schema, tensor_obj)
+
+ # Assert
+ assert isinstance(t, tensor.Tensor)
+ mock_tensor_init.assert_called_once_with(
+ tensor_obj["type"],
+ tensor_obj["shape"],
+ name=tensor_obj["name"],
+ description=tensor_obj["description"],
+ )
diff --git a/hsml/requirements-docs.txt b/hsml/requirements-docs.txt
new file mode 100644
index 000000000..d1499a262
--- /dev/null
+++ b/hsml/requirements-docs.txt
@@ -0,0 +1,11 @@
+mkdocs==1.5.3
+mkdocs-material==9.5.17
+mike==2.0.0
+sphinx==7.2.6
+keras_autodoc @ git+https://git@github.com/logicalclocks/keras-autodoc
+markdown-include==0.8.1
+mkdocs-jupyter==0.24.3
+markdown==3.6
+pymdown-extensions==10.7.1
+mkdocs-macros-plugin==1.0.4
+mkdocs-minify-plugin>=0.2.0