diff --git a/.github/actions/setup_environment/action.yml b/.github/actions/setup_environment/action.yml index 5ea5d222928..05c3b518d2f 100644 --- a/.github/actions/setup_environment/action.yml +++ b/.github/actions/setup_environment/action.yml @@ -123,4 +123,4 @@ runs: run: |- zenml integration list uv pip list - pip check || true + uv pip check || true diff --git a/.github/workflows/ci-slow.yml b/.github/workflows/ci-slow.yml index 470a4f21046..28f8683681d 100644 --- a/.github/workflows/ci-slow.yml +++ b/.github/workflows/ci-slow.yml @@ -83,15 +83,22 @@ jobs: uses: actions/setup-python@v5.0.0 with: python-version: '3.8' - - name: Install current package as editable + - name: Install uv run: | curl -LsSf https://astral.sh/uv/install.sh | sh source $HOME/.cargo/env - uv pip install --system -e . - - name: Install mlstacks package - run: uv pip install --system mlstacks + - name: Create virtual environment + run: | + uv venv + - name: Check mlstacks compatibility + run: | + source .venv/bin/activate + uv pip install -e . + uv pip install mlstacks - name: Check for broken dependencies - run: pip check + run: | + source .venv/bin/activate + uv pip check - name: Markdown link check uses: gaurav-nelson/github-action-markdown-link-check@1.0.15 with: @@ -104,14 +111,18 @@ jobs: continue-on-error: true - name: Security check run: | - uv pip install --system bandit + source .venv/bin/activate + uv pip install bandit bash scripts/check-security.sh - name: Check for alembic branch divergence env: ZENML_DEBUG: 0 run: | - uv pip install --system alembic + source .venv/bin/activate + uv pip install alembic bash scripts/check-alembic-branches.sh + - name: Install latest dashboard (test gitignore) + run: bash scripts/install-dashboard.sh custom-ubuntu-unit-test: if: github.event.pull_request.draft == false needs: run-slow-ci-label-is-set diff --git a/.github/workflows/generate-test-duration.yml b/.github/workflows/generate-test-duration.yml index 2710de2ffb9..a4c441c1bf9 100644 --- a/.github/workflows/generate-test-duration.yml +++ b/.github/workflows/generate-test-duration.yml @@ -55,4 +55,4 @@ jobs: run: |- zenml integration list uv pip list - pip check || true + uv pip check || true diff --git a/.github/workflows/integration-test-fast.yml b/.github/workflows/integration-test-fast.yml index 0db71af4d83..d3e8e239902 100644 --- a/.github/workflows/integration-test-fast.yml +++ b/.github/workflows/integration-test-fast.yml @@ -239,4 +239,4 @@ jobs: run: |- zenml integration list uv pip list - pip check || true + uv pip check || true diff --git a/.github/workflows/integration-test-slow.yml b/.github/workflows/integration-test-slow.yml index d742a50b18f..494233aa43f 100644 --- a/.github/workflows/integration-test-slow.yml +++ b/.github/workflows/integration-test-slow.yml @@ -247,4 +247,4 @@ jobs: run: |- zenml integration list uv pip list - pip check || true + uv pip check || true diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9c7a9b78809..53e717efa8b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,7 +1,5 @@ --- -# This is a basic workflow to help you get started with Actions name: Release Package & Docker Image -# Controls when the action will run. Triggers the workflow on push of a tag on: push: tags: ['*'] @@ -22,14 +20,25 @@ jobs: uses: actions/setup-python@v5.0.0 with: python-version: '3.8' - - name: Install current package as editable + - name: Install uv run: | - pip install -U uv - uv pip install --system -e . - - name: Install mlstacks package - run: uv pip install --system mlstacks + curl -LsSf https://astral.sh/uv/install.sh | sh + source $HOME/.cargo/env + - name: Create virtual environment + run: | + source $HOME/.cargo/env + uv venv + - name: Check mlstacks compatibility + run: | + source .venv/bin/activate + source $HOME/.cargo/env + uv pip install -e . + uv pip install mlstacks - name: Check for broken dependencies - run: pip check + run: | + source .venv/bin/activate + source $HOME/.cargo/env + uv pip check mysql-db-migration-testing: runs-on: arc-runner-set env: diff --git a/.github/workflows/unit-test.yml b/.github/workflows/unit-test.yml index 24c85f2a369..4e4560e40a4 100644 --- a/.github/workflows/unit-test.yml +++ b/.github/workflows/unit-test.yml @@ -116,4 +116,4 @@ jobs: run: |- zenml integration list uv pip list - pip check || true + uv pip check || true diff --git a/.github/workflows/update-templates-to-examples.yml b/.github/workflows/update-templates-to-examples.yml index ee4683c29dd..0d02152757a 100644 --- a/.github/workflows/update-templates-to-examples.yml +++ b/.github/workflows/update-templates-to-examples.yml @@ -246,3 +246,75 @@ jobs: repo: context.repo.repo, body: 'Quickstart template updates in `examples/quickstart` have been pushed.' }) + update-llm-finetuning-template-to-examples: + name: update-llm-finetuning-template-to-examples + runs-on: ${{ inputs.os }} + env: + ZENML_DEBUG: 1 + ZENML_ANALYTICS_OPT_IN: false + PYTHONIOENCODING: utf-8 + OBJC_DISABLE_INITIALIZE_FORK_SAFETY: 'YES' + if: github.event_name == 'pull_request' && ! startsWith(github.event.head_commit.message, + 'GitBook:') + defaults: + run: + shell: bash + steps: + - name: Run template tests for zenml-io/template-llm-finetuning + uses: zenml-io/template-llm-finetuning/.github/actions/llm_finetuning_template_test@main + with: + python-version: ${{ inputs.python-version }} + stack-name: local + ref-zenml: ${{ github.ref }} + ref-template: 2024.03.18 # Make sure it is aligned with ZENML_PROJECT_TEMPLATES from src/zenml/cli/base.py + - name: Clean-up + run: | + rm -rf ./local_checkout + - name: message-on-error + if: failure() + run: | + echo "::error title=zenml-io/template-llm-finetuning project template testing failed with new version of ZenML core!::\ + Breaking changes affecting templates have been introduced. To mitigate this issue,\ + please make the code in zenml-io/template-llm-finetuning compatible with new version of\ + ZenML core, release it and update release tag in zenml.cli.base.ZENML_PROJECT_TEMPLATES" + - uses: actions/checkout@v4.1.1 + with: + ref: ${{ github.event.pull_request.head.ref }} + - name: Check-out fresh LLM Finetuning template + run: | + rm -rf examples/llm_finetuning + mkdir -p examples/llm_finetuning + printf 'info@zenml.io' | zenml init --path examples/llm_finetuning --template llm_finetuning --template-with-defaults + pip install yamlfix + bash scripts/format.sh examples/llm_finetuning + - name: Check for changes + id: check_changes + run: | + if git diff --quiet "origin/${{ github.event.pull_request.head.ref }}"; then + echo "No active Git changes found." + echo "changes=false" >> $GITHUB_OUTPUT + else + echo "vvv Active Git changes found vvv" + echo "changes=true" >> $GITHUB_OUTPUT + git diff "origin/${{ github.event.pull_request.head.ref }}" + fi + - name: Commit and push template + if: steps.check_changes.outputs.changes == 'true' + run: | + git config --global user.name "GitHub Actions" + git config --global user.email "actions@github.com" + git add . + git commit -am "Auto-update of LLM Finetuning template" + git push origin HEAD:${{ github.event.pull_request.head.ref }} + - name: Create PR comment + if: steps.check_changes.outputs.changes == 'true' + uses: actions/github-script@v7.0.1 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: |- + github.rest.issues.createComment({ + issue_number: ${{ github.event.pull_request.number }}, + owner: context.repo.owner, + repo: context.repo.repo, + body: 'LLM Finetuning template updates in `examples/llm_finetuning` have been pushed.' + }) diff --git a/.typos.toml b/.typos.toml index 39266b19df2..fba9255c023 100644 --- a/.typos.toml +++ b/.typos.toml @@ -7,6 +7,7 @@ extend-exclude = [ "tests/unit/materializers/test_built_in_materializer.py", "tests/integration/functional/cli/test_pipeline.py", "src/zenml/zen_server/dashboard/", + "examples/llm_finetuning/lit_gpt/" ] [default.extend-identifiers] diff --git a/README.md b/README.md index d699f50859e..d511f6e24eb 100644 --- a/README.md +++ b/README.md @@ -92,7 +92,7 @@ Projects Showcase

- 🎉 Version 0.55.5 is out. Check out the release notes + 🎉 Version 0.56.2 is out. Check out the release notes here.

diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index 50c69d796ca..673a1a6ca7f 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -1,4 +1,195 @@ +# 0.56.2 + +This release replaces 0.56.0 and 0.56.1, and fixes the major migration bugs that were in +that yanked release. Please upgrade directly to 0.56.2 and avoid upgrading to +0.56.0 to avoid unexpected migration issues. + +Note that 0.56.0 and 0.56.1 were removed from PyPI due to an issue with the +alembic versions + migration which could affect the database state. This release +fixes that issue. +This release introduces introduces a wide array of new features, enhancements, and bug fixes, with a strong emphasis on elevating the user experience and streamlining machine +learning workflows. Most notably, you can now deploy models using Hugging Face inference endpoints thanks for an open-source community contribution of this model deployer stack component! + +This release also comes with a breaking change to the services +architecture. + +## Breaking Change + +A significant change in this release is the migration of the `Service` (ZenML's technical term for deployment) +registration and deployment from local or remote environments to the ZenML server. +This change will be reflected in an upcoming tab in the dashboard which will +allow users to explore and see the deployed models in the dashboard with their live +status and metadata. This architectural shift also simplifies the model deployer +abstraction and streamlines the model deployment process for users by moving from +limited built-in steps to a more documented and flexible approach. + +Important note: If you have models that you previously deployed with ZenML, you might +want to redeploy them to have them stored in the ZenML server and tracked by ZenML, +ensuring they appear in the dashboard. + +Additionally, the `find_model_server` method now retrieves models (services) from the +ZenML server instead of local or remote deployment environments. As a result, any +usage of `find_model_server` will only return newly deployed models stored in the server. + +It is also no longer recommended to call service functions like `service.start()`. +Instead, use `model_deployer.start_model_server(service_id)`, which will allow ZenML +to update the changed status of the service in the server. + +### Starting a service +**Old syntax:** +```python +from zenml import pipeline, +from zenml.integrations.bentoml.services.bentoml_deployment import BentoMLDeploymentService + +@step +def predictor( + service: BentoMLDeploymentService, +) -> None: + # starting the service + service.start(timeout=10) +``` + +**New syntax:** +```python +from zenml import pipeline +from zenml.integrations.bentoml.model_deployers import BentoMLModelDeployer +from zenml.integrations.bentoml.services.bentoml_deployment import BentoMLDeploymentService + +@step +def predictor( + service: BentoMLDeploymentService, +) -> None: + # starting the service + model_deployer = BentoMLModelDeployer.get_active_model_deployer() + model_deployer.start_model_server(service_id=service.service_id, timeout=10) +``` + +### Enabling continuous deployment + +Instead of replacing the parameter that was used in the `deploy_model` method to replace the +existing service (if it matches the exact same pipeline name and step name without +taking into accounts other parameters or configurations), we now have a new parameter, +`continuous_deployment_mode`, that allows you to enable continuous deployment for +the service. This will ensure that the service is updated with the latest version +if it's on the same pipeline and step and the service is not already running. Otherwise, +any new deployment with different configurations will create a new service. + +```python +from zenml import pipeline, step, get_step_context +from zenml.client import Client + +@step +def deploy_model() -> Optional[MLFlowDeploymentService]: + # Deploy a model using the MLflow Model Deployer + zenml_client = Client() + model_deployer = zenml_client.active_stack.model_deployer + mlflow_deployment_config = MLFlowDeploymentConfig( + name: str = "mlflow-model-deployment-example", + description: str = "An example of deploying a model using the MLflow Model Deployer", + pipeline_name: str = get_step_context().pipeline_name, + pipeline_step_name: str = get_step_context().step_name, + model_uri: str = "runs://model" or "models://", + model_name: str = "model", + workers: int = 1 + mlserver: bool = False + timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT + ) + service = model_deployer.deploy_model(mlflow_deployment_config, continuous_deployment_mode=True) + logger.info(f"The deployed service info: {model_deployer.get_model_server_info(service)}") + return service +``` + + +## Major Features and Enhancements: + +* A new `Huggingface Model Deployer` has been introduced, allowing you to seamlessly +deploy your Huggingface models using ZenML. (Thank you so much @dudeperf3ct for the contribution!) +* Faster Integration and Dependency Management ZenML now leverages the `uv` library, +significantly improving the speed of integration installations and dependency management, +resulting in a more streamlined and efficient workflow. +* Enhanced Logging and Status Tracking Logging have been improved, providing better +visibility into the state of your ZenML services. +* Improved Artifact Store Isolation: ZenML now prevents unsafe operations that access +data outside the scope of the artifact store, ensuring better isolation and security. +* Adding admin user notion for the user accounts and added protection to certain operations +performed via the REST interface to ADMIN-allowed only. +* Rate limiting for login API to prevent abuse and protect the server from potential +security threats. +* The LLM template is now supported in ZenML, allowing you to use the LLM template +for your pipelines. + + +## 🥳 Community Contributions 🥳 + +We'd like to give a special thanks to @dudeperf3ct he contributed to this release +by introducing the Huggingface Model Deployer. We'd also like to thank @moesio-f +for their contribution to this release by adding a new attribute to the `Kaniko` image builder. +Additionally, we'd like to thank @christianversloot for his contributions to this release. + + +## What's Changed +* Upgrading SQLModel to the latest version by @bcdurak in https://github.com/zenml-io/zenml/pull/2452 +* Remove KServe integration by @safoinme in https://github.com/zenml-io/zenml/pull/2495 +* Upgrade migration testing with 0.55.5 by @avishniakov in https://github.com/zenml-io/zenml/pull/2501 +* Relax azure, gcfs and s3 dependencies by @strickvl in https://github.com/zenml-io/zenml/pull/2498 +* Use HTTP forwarded headers to detect the real origin of client devices by @stefannica in https://github.com/zenml-io/zenml/pull/2499 +* Update README.md for quickstart colab link by @strickvl in https://github.com/zenml-io/zenml/pull/2505 +* Add sequential migration tests for MariaDB and MySQL by @strickvl in https://github.com/zenml-io/zenml/pull/2502 +* Huggingface Model Deployer by @dudeperf3ct in https://github.com/zenml-io/zenml/pull/2376 +* Use `uv` to speed up pip installs & the CI in general by @strickvl in https://github.com/zenml-io/zenml/pull/2442 +* Handle corrupted or empty global configuration file by @stefannica in https://github.com/zenml-io/zenml/pull/2508 +* Add admin users notion by @avishniakov in https://github.com/zenml-io/zenml/pull/2494 +* Remove dashboard from gitignore by @safoinme in https://github.com/zenml-io/zenml/pull/2517 +* Colima / Homebrew fix by @strickvl in https://github.com/zenml-io/zenml/pull/2512 +* [HELM] Remove extra environment variable assignment by @wjayesh in https://github.com/zenml-io/zenml/pull/2518 +* Allow installing packages using UV by @schustmi in https://github.com/zenml-io/zenml/pull/2510 +* Additional fields for track events by @bcdurak in https://github.com/zenml-io/zenml/pull/2507 +* Check if environment key is set before deleting in HyperAI orchestrator by @christianversloot in https://github.com/zenml-io/zenml/pull/2511 +* Fix the pagination in the database backup by @stefannica in https://github.com/zenml-io/zenml/pull/2522 +* Bump mlflow to version 2.11.1 by @christianversloot in https://github.com/zenml-io/zenml/pull/2524 +* Add docs for uv installation by @schustmi in https://github.com/zenml-io/zenml/pull/2527 +* Fix bug in HyperAI orchestrator depends_on parallelism by @christianversloot in https://github.com/zenml-io/zenml/pull/2523 +* Upgrade pip in docker images by @schustmi in https://github.com/zenml-io/zenml/pull/2528 +* Fix node selector and other fields for DB job in helm chart by @stefannica in https://github.com/zenml-io/zenml/pull/2531 +* Revert "Upgrading SQLModel to the latest version" by @bcdurak in https://github.com/zenml-io/zenml/pull/2515 +* Add `pod_running_timeout` attribute to `Kaniko` image builder by @moesio-f in https://github.com/zenml-io/zenml/pull/2509 +* Add test to install dashboard script by @strickvl in https://github.com/zenml-io/zenml/pull/2521 +* Sort pipeline namespaces by last run by @schustmi in https://github.com/zenml-io/zenml/pull/2514 +* Add support for LLM template by @schustmi in https://github.com/zenml-io/zenml/pull/2519 +* Rate limiting for login API by @avishniakov in https://github.com/zenml-io/zenml/pull/2484 +* Try/catch for Docker client by @christianversloot in https://github.com/zenml-io/zenml/pull/2513 +* Fix config file in starter guide by @schustmi in https://github.com/zenml-io/zenml/pull/2534 +* Log URL for pipelines and model versions when running a pipeline by @wjayesh in https://github.com/zenml-io/zenml/pull/2506 +* Add security exclude by @schustmi in https://github.com/zenml-io/zenml/pull/2541 +* Update error message around notebook use by @strickvl in https://github.com/zenml-io/zenml/pull/2536 +* Cap `fsspec` for Huggingface integration by @avishniakov in https://github.com/zenml-io/zenml/pull/2542 +* Fix integration materializers' URLs in docs by @strickvl in https://github.com/zenml-io/zenml/pull/2538 +* Bug fix HyperAI orchestrator: Offload scheduled pipeline execution to bash script by @christianversloot in https://github.com/zenml-io/zenml/pull/2535 +* Update `pip check` command to use `uv` by @strickvl in https://github.com/zenml-io/zenml/pull/2520 +* Implemented bitbucket webhook event source by @AlexejPenner in https://github.com/zenml-io/zenml/pull/2481 +* Add ZenMLServiceType and update service registration by @safoinme in https://github.com/zenml-io/zenml/pull/2471 +* Prepare release 0.56.0 by @safoinme in https://github.com/zenml-io/zenml/pull/2546 +* Fix formatting and release workflow by @strickvl in https://github.com/zenml-io/zenml/pull/2549 +* Fix release workflow by @strickvl in https://github.com/zenml-io/zenml/pull/2550 +* Fix pipelines and model links for the cloud dashboard by @wjayesh in https://github.com/zenml-io/zenml/pull/2554 +* Make starlette non-must for client by @avishniakov in https://github.com/zenml-io/zenml/pull/2553 +* Bump MLFlow to version 2.11.2 by @christianversloot in https://github.com/zenml-io/zenml/pull/2552 +* Prepare release 0.56.1 by @avishniakov in https://github.com/zenml-io/zenml/pull/2555 +* Updated neptune documentation by @SiddhantSadangi in https://github.com/zenml-io/zenml/pull/2548 +* 0.56.0 and 0.56.1 in testing by @avishniakov in https://github.com/zenml-io/zenml/pull/2557 +* Only install uv once by @schustmi in https://github.com/zenml-io/zenml/pull/2558 +* Bump MLFlow to version 2.11.3 by @christianversloot in https://github.com/zenml-io/zenml/pull/2559 +* Update docs with warning about pickle materializer insecurity by @avishniakov in https://github.com/zenml-io/zenml/pull/2561 +* Add service table migration by @safoinme in https://github.com/zenml-io/zenml/pull/2563 + +## New Contributors +* @dudeperf3ct made their first contribution in https://github.com/zenml-io/zenml/pull/2376 +* @moesio-f made their first contribution in https://github.com/zenml-io/zenml/pull/2509 +* @SiddhantSadangi made their first contribution in https://github.com/zenml-io/zenml/pull/2548 + +**Full Changelog**: https://github.com/zenml-io/zenml/compare/0.55.5...0.56.2 + # 0.55.5 This patch contains a number of bug fixes and security improvements. diff --git a/docker/base.Dockerfile b/docker/base.Dockerfile index 5da6ab98df4..da53fbe8c69 100644 --- a/docker/base.Dockerfile +++ b/docker/base.Dockerfile @@ -10,6 +10,9 @@ ENV PYTHONFAULTHANDLER=1 \ ARG ZENML_VERSION +# Upgrade pip to the latest version +RUN pip install --upgrade pip + # install the given zenml version (default to latest) RUN pip install zenml${ZENML_VERSION:+==$ZENML_VERSION} diff --git a/docker/zenml-dev.Dockerfile b/docker/zenml-dev.Dockerfile index 6931d4db792..91d8cd43e5c 100644 --- a/docker/zenml-dev.Dockerfile +++ b/docker/zenml-dev.Dockerfile @@ -8,8 +8,7 @@ ENV PYTHONFAULTHANDLER=1 \ PIP_DISABLE_PIP_VERSION_CHECK=1 \ ZENML_DEBUG=1 \ ZENML_LOGGING_VERBOSITY=INFO \ - ZENML_CONTAINER=1 - + ZENML_CONTAINER=1 WORKDIR /zenml @@ -21,5 +20,8 @@ COPY src/zenml/__init__.py ./src/zenml/ ENV ZENML_DEBUG=true \ ZENML_ANALYTICS_OPT_IN=false +# Upgrade pip to the latest version +RUN pip install --upgrade pip + RUN pip install -e . COPY src src \ No newline at end of file diff --git a/docker/zenml-server-dev.Dockerfile b/docker/zenml-server-dev.Dockerfile index d775817edb3..26272853857 100644 --- a/docker/zenml-server-dev.Dockerfile +++ b/docker/zenml-server-dev.Dockerfile @@ -8,7 +8,9 @@ ENV PYTHONFAULTHANDLER=1 \ PIP_DISABLE_PIP_VERSION_CHECK=1 \ ZENML_DEBUG=1 \ ZENML_LOGGING_VERBOSITY=INFO \ - ZENML_CONTAINER=1 + ZENML_CONTAINER=1 \ + ZENML_SERVER_RATE_LIMIT_ENABLED=1 \ + ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE=100 ARG USERNAME=zenml ARG USER_UID=1000 @@ -31,6 +33,8 @@ COPY README.md pyproject.toml ./ # copying our source files which would invalidate caching COPY src/zenml/__init__.py ./src/zenml/ +# Upgrade pip to the latest version +RUN pip install --upgrade pip RUN pip install -e .[server,secrets-aws,secrets-gcp,secrets-azure,secrets-hashicorp,s3fs,gcsfs,adlfs,connectors-aws,connectors-gcp,connectors-azure] COPY src src diff --git a/docs/book/deploying-zenml/zenml-self-hosted/deploy-with-docker.md b/docs/book/deploying-zenml/zenml-self-hosted/deploy-with-docker.md index 07148bc2eb7..d9103dfecf2 100644 --- a/docs/book/deploying-zenml/zenml-self-hosted/deploy-with-docker.md +++ b/docs/book/deploying-zenml/zenml-self-hosted/deploy-with-docker.md @@ -44,6 +44,9 @@ The following environment variables can be passed to the container: * **ZENML\_STORE\_SSL\_VERIFY\_SERVER\_CERT**: This boolean variable controls whether the SSL certificate in use by the MySQL server is verified. Only valid when `ZENML_STORE_URL` points to a MySQL database that uses SSL-secured connections. Defaults to `False`. * **ZENML\_LOGGING\_VERBOSITY**: Use this variable to control the verbosity of logs inside the container. It can be set to one of the following values: `NOTSET`, `ERROR`, `WARN`, `INFO` (default), `DEBUG` or `CRITICAL`. * **ZENML\_STORE\_BACKUP\_STRATEGY**: This variable controls the database backup strategy used by the ZenML server. See the [Database backup and recovery](#database-backup-and-recovery) section for more details about this feature and other related environment variables. Defaults to `in-memory`. +* **ZENML\_SERVER\_RATE\_LIMIT\_ENABLED**: This variable controls the rate limiting for ZenML API (currently only for the `LOGIN` endpoint). It is disabled by default, so set it to `1` only if you need to enable rate limiting. To determine unique users a `X_FORWARDED_FOR` header or `request.client.host` is used, so before enabling this make sure that your network configuration is associating proper information with your clients in order to avoid disruptions for legitimate requests. +* **ZENML\_SERVER\_LOGIN\_RATE\_LIMIT\_MINUTE**: If rate limiting is enabled, this variable controls how many requests will be allowed to query the login endpoint in a one minute interval. Set it to a desired integer value; defaults to `5`. +* **ZENML\_SERVER\_LOGIN\_RATE\_LIMIT\_DAY**: If rate limiting is enabled, this variable controls how many requests will be allowed to query the login endpoint in an interval of day interval. Set it to a desired integer value; defaults to `1000`. If none of the `ZENML_STORE_*` variables are set, the container will default to creating and using an SQLite database file stored at `/zenml/.zenconfig/local_stores/default_zen_store/zenml.db` inside the container. The `/zenml/.zenconfig/local_stores` base path where the default SQLite database is located can optionally be overridden by setting the `ZENML_LOCAL_STORES_PATH` environment variable to point to a different path (e.g. a persistent volume or directory that is mounted from the host). diff --git a/docs/book/stacks-and-components/component-guide/experiment-trackers/neptune.md b/docs/book/stacks-and-components/component-guide/experiment-trackers/neptune.md index 0b5f8d5bc61..7625ec0187e 100644 --- a/docs/book/stacks-and-components/component-guide/experiment-trackers/neptune.md +++ b/docs/book/stacks-and-components/component-guide/experiment-trackers/neptune.md @@ -168,29 +168,6 @@ def tf_trainer(...): ``` {% endhint %} -### Neptune UI - -Neptune comes with a web-based UI that you can use to find further details about -your tracked experiments. Each pipeline run will be logged as a separate -experiment run in Neptune, which you can inspect in the Neptune UI: - -![Neptune UI](../../../.gitbook/assets/NeptuneUI.png) - -You can find the URL of the Neptune experiment linked to a specific ZenML run -via the metadata of the step in which the experiment tracker was used: - -```python -from zenml.client import Client - -last_run = client.get_pipeline("").last_run -trainer_step = last_run.get_step("") -tracking_url = trainer_step.run_metadata["experiment_tracker_url"].value -print(tracking_url) -``` - -Alternatively, you can see an overview of all experiment runs at -https://app.neptune.ai/{ACCOUNT_USERNAME}/{PROJECT_NAME}. - #### Additional configuration You can pass a set of tags to the Neptune run by using the `NeptuneExperimentTrackerSettings` class, like in the example @@ -226,5 +203,17 @@ def my_step( ... ``` +### Neptune UI + +Neptune comes with a web-based UI that you can use to find further details about +your tracked experiments. Each pipeline run will be logged as a separate +experiment run in Neptune, which you can inspect in the Neptune UI. + +You can find the URL of the Neptune run linked to a specific ZenML run printed on the console whenever a Neptune run is initialized. + +### Further reading + +Check [Neptune's docs](https://docs.neptune.ai/integrations/zenml/) for further information on how to use this integration and Neptune in general. +
ZenML Scarf
diff --git a/docs/book/stacks-and-components/component-guide/image-builders/kaniko.md b/docs/book/stacks-and-components/component-guide/image-builders/kaniko.md index dcb26b9960a..6aaad396c2a 100644 --- a/docs/book/stacks-and-components/component-guide/image-builders/kaniko.md +++ b/docs/book/stacks-and-components/component-guide/image-builders/kaniko.md @@ -34,6 +34,7 @@ To use the Kaniko image builder, we need: transfer the build context by storing it in the artifact store, you need to register it with the `store_context_in_artifact_store` attribute set to `True`. In this case, you also need a [remote artifact store](../artifact-stores/artifact-stores.md) as part of your stack. +* Optionally, you can change the timeout (in seconds) until the Kaniko pod is running in the orchestrator using the `pod_running_timeout` attribute. We can then register the image builder and use it in our active stack: @@ -41,6 +42,7 @@ We can then register the image builder and use it in our active stack: zenml image-builder register \ --flavor=kaniko \ --kubernetes_context= + [ --pod_running_timeout= ] # Register and activate a stack with the new image builder zenml stack register -i ... --set diff --git a/docs/book/stacks-and-components/component-guide/model-deployers/custom.md b/docs/book/stacks-and-components/component-guide/model-deployers/custom.md index e33646e7eb5..df7a09efd94 100644 --- a/docs/book/stacks-and-components/component-guide/model-deployers/custom.md +++ b/docs/book/stacks-and-components/component-guide/model-deployers/custom.md @@ -16,7 +16,7 @@ When present in a stack, the model deployer can also act as a registry for model In ZenML, the base abstraction of the model deployer is built on top of three major criteria: -1. It needs to contain all the stack-related configuration attributes required to interact with the remote model serving tool, service, or platform (e.g. hostnames, URLs, references to credentials, and other client-related configuration parameters). +1. It needs to ensure efficient deployment and management of models in accordance with the specific requirements of the serving infrastructure, by holding all the stack-related configuration attributes required to interact with the remote model serving tool, service, or platform. 2. It needs to implement the continuous deployment logic necessary to deploy models in a way that updates an existing model server that is already serving a previous version of the same model instead of creating a new model server for every new model version (see the `deploy_model` abstract method). This functionality can be consumed directly from ZenML pipeline steps, but it can also be used outside the pipeline to deploy ad-hoc models. It is also usually coupled with a standard model deployer step, implemented by each integration, that hides the details of the deployment process from the user. 3. It needs to act as a ZenML BaseService registry, where every BaseService instance is used as an internal representation of a remote model server (see the `find_model_server` abstract method). To achieve this, it must be able to re-create the configuration of a BaseService from information that is persisted externally, alongside, or even as part of the remote model server configuration itself. For example, for model servers that are implemented as Kubernetes resources, the BaseService instances can be serialized and saved as Kubernetes resource annotations. This allows the model deployer to keep track of all externally running model servers and to re-create their corresponding BaseService instance representations at any given time. The model deployer also defines methods that implement basic life-cycle management on remote model servers outside the coverage of a pipeline (see `stop_model_server` , `start_model_server` and `delete_model_server`). @@ -42,11 +42,11 @@ class BaseModelDeployer(StackComponent, ABC): """Base class for all ZenML model deployers.""" @abstractmethod - def deploy_model( - self, - config: ServiceConfig, - replace: bool = False, - timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + def perform_deploy_model( + self, + id: UUID, + config: ServiceConfig, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, ) -> BaseService: """Abstract method to deploy a model.""" @@ -59,43 +59,28 @@ class BaseModelDeployer(StackComponent, ABC): properties for the user.""" @abstractmethod - def find_model_server( - self, - running: bool = False, - service_uuid: Optional[UUID] = None, - pipeline_name: Optional[str] = None, - run_name: Optional[str] = None, - pipeline_step_name: Optional[str] = None, - model_name: Optional[str] = None, - model_uri: Optional[str] = None, - model_type: Optional[str] = None, - ) -> List[BaseService]: - """Abstract method to find one or more model servers that match the - given criteria.""" - - @abstractmethod - def stop_model_server( - self, - uuid: UUID, - timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, - force: bool = False, - ) -> None: + def perform_stop_model( + self, + service: BaseService, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + force: bool = False, + ) -> BaseService: """Abstract method to stop a model server.""" @abstractmethod - def start_model_server( - self, - uuid: UUID, - timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, - ) -> None: + def perform_start_model( + self, + service: BaseService, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + ) -> BaseService: """Abstract method to start a model server.""" @abstractmethod - def delete_model_server( - self, - uuid: UUID, - timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, - force: bool = False, + def perform_delete_model( + self, + service: BaseService, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + force: bool = False, ) -> None: """Abstract method to delete a model server.""" @@ -143,6 +128,7 @@ If you want to create your own custom flavor for a model deployer, you can follo 1. Create a class that inherits from the `BaseModelDeployer` class and implements the abstract methods. 2. If you need to provide any configuration, create a class that inherits from the `BaseModelDeployerConfig` class and add your configuration parameters. 3. Bring both the implementation and the configuration together by inheriting from the `BaseModelDeployerFlavor` class. Make sure that you give a `name` to the flavor through its abstract property. +4. Create a service class that inherits from the `BaseService` class and implements the abstract methods. This class will be used to represent the deployed model server in ZenML. Once you are done with the implementation, you can register it through the CLI. Please ensure you **point to the flavor class via dot notation**: diff --git a/docs/book/stacks-and-components/component-guide/model-deployers/mlflow.md b/docs/book/stacks-and-components/component-guide/model-deployers/mlflow.md index 026effbdee1..214fdd693cb 100644 --- a/docs/book/stacks-and-components/component-guide/model-deployers/mlflow.md +++ b/docs/book/stacks-and-components/component-guide/model-deployers/mlflow.md @@ -52,53 +52,98 @@ the background to serve the latest MLflow model. ### Deploy a logged model -ZenML provides a predefined `mlflow_model_deployer_step` that you can use to -deploy an MLflfow prediction service based on a model that you have -previously logged in your -[MLflow experiment tracker](../experiment-trackers/mlflow.md): +Following [MLflow's documentation](https://mlflow.org/docs/latest/deployment/deploy-model-locally.html#deploy-mlflow-model-as-a-local-inference-server), if we want to deploy a model as a local inference server, we need the model to be +logged in the MLflow experiment tracker first. Once the model is logged, we can use the model URI either from the +artifact path saved with the MLflow run or using model name and version if a model is registered in the MLflow model +registry. + +In the following examples, we will show how to deploy a model using the MLflow Model Deployer, in two different scenarios: + +1. We already know the logged model URI and we want to deploy it as a local inference server. ```python -from zenml import pipeline -from zenml.integrations.mlflow.steps import mlflow_model_deployer_step +from zenml import pipeline, step, get_step_context +from zenml.client import Client -@pipeline -def mlflow_train_deploy_pipeline(): - model = ... - deployed_model = mlflow_model_deployer_step(model=model) +@step +def deploy_model() -> Optional[MLFlowDeploymentService]: + # Deploy a model using the MLflow Model Deployer + zenml_client = Client() + model_deployer = zenml_client.active_stack.model_deployer + mlflow_deployment_config = MLFlowDeploymentConfig( + name: str = "mlflow-model-deployment-example", + description: str = "An example of deploying a model using the MLflow Model Deployer", + pipeline_name: str = get_step_context().pipeline_name, + pipeline_step_name: str = get_step_context().step_name, + model_uri: str = "runs://model" or "models://", + model_name: str = "model", + workers: int = 1 + mlserver: bool = False + timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT + ) + service = model_deployer.deploy_model(mlflow_deployment_config) + logger.info(f"The deployed service info: {model_deployer.get_model_server_info(service)}") + return service ``` -{% hint style="warning" %} -The `mlflow_model_deployer_step` expects that the `model` it receives has -already been logged to MLflow in a previous step. E.g., for a scikit-learn -model, you would need to have used `mlflow.sklearn.autolog()` or -`mlflow.sklearn.log_model(model)` in a previous step. See the -[MLflow experiment tracker documentation](../experiment-trackers/mlflow.md) for -more information on how to log models to MLflow from your ZenML steps. -{% endhint %} +2. We don't know the logged model URI, since the model was logged in a previous step. We want to deploy the model as a local inference server. ZenML provides set of functionalities that would make it easier to get the model URI from the current run and deploy it. -### Deploy from model registry +```python +from zenml import pipeline, step, get_step_context +from zenml.client import Client +from mlflow.tracking import MlflowClient, artifact_utils -Alternatively, if you are already using the -[MLflow model registry](../model-registries/mlflow.md), you can use the -`mlflow_model_registry_deployer_step` to directly deploy an MLflow prediction -service based on a model in your model registry: -```python -from zenml import pipeline -from zenml.integrations.mlflow.steps import mlflow_model_registry_deployer_step - -@pipeline -def mlflow_registry_deploy_pipeline(): - deployed_model = mlflow_model_registry_deployer_step( - registry_model_name="tensorflow-mnist-model", - registry_model_version="1", # Either specify a model version - # or use the model stage if you have set it in the MLflow registry: - # registered_model_stage="Staging" +@step +def deploy_model() -> Optional[MLFlowDeploymentService]: + # Deploy a model using the MLflow Model Deployer + zenml_client = Client() + model_deployer = zenml_client.active_stack.model_deployer + experiment_tracker = zenml_client.active_stack.experiment_tracker + # Let's get the run id of the current pipeline + mlflow_run_id = experiment_tracker.get_run_id( + experiment_name=get_step_context().pipeline_name, + run_name=get_step_context().run_name, + ) + # Once we have the run id, we can get the model URI using mlflow client + experiment_tracker.configure_mlflow() + client = MlflowClient() + model_name = "model" # set the model name that was logged + model_uri = artifact_utils.get_artifact_uri( + run_id=mlflow_run_id, artifact_path=model_name + ) + mlflow_deployment_config = MLFlowDeploymentConfig( + name: str = "mlflow-model-deployment-example", + description: str = "An example of deploying a model using the MLflow Model Deployer", + pipeline_name: str = get_step_context().pipeline_name, + pipeline_step_name: str = get_step_context().step_name, + model_uri: str = model_uri, + model_name: str = model_name, + workers: int = 1, + mlserver: bool = False, + timeout: int = 300, ) + service = model_deployer.deploy_model(mlflow_deployment_config) + return service ``` -See the [MLflow model registry documentation](../model-registries/mlflow.md) -for more information on how to register models in the MLflow registry. +#### Configuration + +Within the `MLFlowDeploymentService` you can configure: + +* `name`: The name of the deployment. +* `description`: The description of the deployment. +* `pipeline_name`: The name of the pipeline that deployed the MLflow prediction server. +* `pipeline_step_name`: The name of the step that deployed the MLflow prediction server. +* `model_name`: The name of the model that is deployed in case of model registry the name must be a valid registered model name. +* `model_version`: The version of the model that is deployed in case of model registry the version must be a valid registered model version. +* `silent_daemon`: set to True to suppress the output of the daemon +(i.e., redirect stdout and stderr to /dev/null). If False, the daemon output will be redirected to a log file. +* `blocking`: set to True to run the service in the context of the current process and block until the service is stopped instead of running the service as a daemon process. Useful for operating systems that do not support daemon processes. +* `model_uri`: The URI of the model to be deployed. This can be a local file path, a run ID, or a model name and version. +* `workers`: The number of workers to be used by the MLflow prediction server. +* `mlserver`: If True, the MLflow prediction server will be started as a MLServer instance. +* `timeout`: The timeout in seconds to wait for the MLflow prediction server to start or stop. ### Run inference on a deployed model @@ -106,7 +151,11 @@ The following code example shows how you can load a deployed model in Python and run inference against it: +1. Load a prediction service deployed in another pipeline + ```python +import json +import requests from zenml import step from zenml.integrations.mlflow.model_deployers.mlflow_model_deployer import ( MLFlowModelDeployer, @@ -119,9 +168,8 @@ from zenml.integrations.mlflow.services import MLFlowDeploymentService def prediction_service_loader( pipeline_name: str, pipeline_step_name: str, - running: bool = True, model_name: str = "model", -) -> MLFlowDeploymentService: +) -> None: """Get the prediction service started by the deployment pipeline. Args: @@ -140,7 +188,6 @@ def prediction_service_loader( pipeline_name=pipeline_name, pipeline_step_name=pipeline_step_name, model_name=model_name, - running=running, ) if not existing_services: @@ -150,8 +197,35 @@ def prediction_service_loader( f"'{model_name}' is currently running." ) - return existing_services[0] + service = existing_services[0] + # Let's try run a inference request against the prediction service + + payload = json.dumps( + { + "inputs": {"messages": [{"role": "user", "content": "Tell a joke!"}]}, + "params": { + "temperature": 0.5, + "max_tokens": 20, + }, + } + ) + response = requests.post( + url=service.get_prediction_url(), + data=payload, + headers={"Content-Type": "application/json"}, + ) + + response.json() +``` + +2. Within the same pipeline, use the service from previous step to run inference this time using pre-built predict method + +```python +from typing_extensions import Annotated +import numpy as np +from zenml import step +from zenml.integrations.mlflow.services import MLFlowDeploymentService # Use the service for inference @step @@ -161,23 +235,10 @@ def predictor( ) -> Annotated[np.ndarray, "predictions"]: """Run a inference request against a prediction service""" - service.start(timeout=10) # should be a NOP if already started prediction = service.predict(data) prediction = prediction.argmax(axis=-1) return prediction - - -@pipeline -def mlflow_deployment_inference_pipeline( - pipeline_name: str, pipeline_step_name: str = "mlflow_model_deployer_step", -): - inference_data = ... - model_deployment_service = prediction_service_loader( - pipeline_name=pipeline_name, - pipeline_step_name=pipeline_step_name, - ) - predictions = predictor(model_deployment_service, inference_data) ``` For more information and a full list of configurable attributes of the MLflow Model Deployer, check out diff --git a/docs/book/stacks-and-components/component-guide/model-deployers/model-deployers.md b/docs/book/stacks-and-components/component-guide/model-deployers/model-deployers.md index 428cfb6a13a..5d73ce44e1a 100644 --- a/docs/book/stacks-and-components/component-guide/model-deployers/model-deployers.md +++ b/docs/book/stacks-and-components/component-guide/model-deployers/model-deployers.md @@ -23,11 +23,9 @@ stored as files or in a database for end users or business applications. ### When to use it? The model deployers are optional components in the ZenML stack. They are used to deploy machine learning models to a -target environment either a development (local) or a production (Kubernetes), the model deployers are mainly used to -deploy models for real-time inference use cases. With the model deployers and other stack components, you can build -pipelines that are continuously trained and deployed to production. +target environment, either a development (local) or a production (Kubernetes or cloud) environment. The model deployers are mainly used to deploy models for real-time inference use cases. With the model deployers and other stack components, you can build pipelines that are continuously trained and deployed to production. -### How they experiment trackers slot into the stack +### How model deployers slot into the stack Here is an architecture diagram that shows how model deployers fit into the overall story of a remote stack. @@ -67,10 +65,7 @@ zenml model-deployer register seldon --flavor=seldon \ #### The role that a model deployer plays in a ZenML Stack -1. Holds all the stack-related configuration attributes required to interact with the remote model serving tool, - service, or platform (e.g. hostnames, URLs, references to credentials, and other client-related configuration - parameters). The following are examples of configuring the MLflow and Seldon Core Model Deployers and registering - them as a Stack component: +* Seamless Model Deployment: Facilitates the deployment of machine learning models to various serving environments, such as local servers, Kubernetes clusters, or cloud platforms, ensuring that models can be deployed and managed efficiently in accordance with the specific requirements of the serving infrastructure by holds all the stack-related configuration attributes required to interact with the remote model serving tool, service, or platform (e.g. hostnames, URLs, references to credentials, and other client-related configuration parameters). The following are examples of configuring the MLflow and Seldon Core Model Deployers and registering them as a Stack component: ```bash zenml integration install mlflow @@ -87,46 +82,52 @@ zenml model-deployer register seldon --flavor=seldon \ zenml stack register seldon_stack -m default -a aws -o default -d seldon ``` -2. Implements the continuous deployment logic necessary to deploy models in a way that updates an existing model server - that is already serving a previous version of the same model instead of creating a new model server for every new - model version. Every model server that the Model Deployer provisions externally to deploy a model is represented - internally as a `Service` object that may be accessed for visibility and control over a single model deployment. This - functionality can be consumed directly from ZenML pipeline steps, but it can also be used outside the pipeline to - deploy ad-hoc models. See the [seldon_model_deployer_step](https://sdkdocs.zenml.io/latest/integration_code_docs/integrations-seldon/#zenml.integrations.seldon.steps.seldon_deployer.seldon_model_deployer_step) for an example of using the Seldon Core Model Deployer to deploy a model inside a ZenML pipeline step. -3. Acts as a registry for all Services that represent remote model servers. External model deployment servers can be - listed and filtered using a variety of criteria, such as the name of the model or the names of the pipeline and step - that was used to deploy the model. The Service objects returned by the Model Deployer can be used to interact with - the remote model server, e.g. to get the operational status of a model server, the prediction URI that it exposes, or - to stop or delete a model server: +* Lifecycle Management: Provides mechanisms for comprehensive lifecycle management of model servers, including the ability to start, stop, and delete model servers, as well as to update existing servers with new model versions, thereby optimizing resource utilization and facilitating continuous delivery of model updates. Some core methods that can be used to interact with the remote model server include: + +`deploy_model` - Deploys a model to the serving environment and returns a Service object that represents the deployed model server. +`find_model_server` - Finds and returns a list of Service objects that represent model servers that have been deployed to the serving environment, the +services are stored in the DB and can be used as a reference to know what and where the model is deployed. +`stop_model_server` - Stops a model server that is currently running in the serving environment. +`start_model_server` - Starts a model server that has been stopped in the serving environment. +`delete_model_server` - Deletes a model server from the serving environment and from the DB. + +{% hint style="info" %} +ZenML uses the Service object to represent a model server that has been deployed to a serving environment. The Service object is saved in the DB and can be used as a reference to know what and where the model is deployed. The Service object consists of 2 main attributes, the `config` and the `status`. The `config` attribute holds all the deployment configuration attributes required to create a new deployment, while the `status` attribute holds the operational status of the deployment, such as the last error message, the prediction URL, and the deployment status. +{% endhint %} ```python - from zenml.integrations.seldon.model_deployers import SeldonModelDeployer + from zenml.integrations.huggingface.model_deployers import HuggingFaceModelDeployer - model_deployer = SeldonModelDeployer.get_active_model_deployer() + model_deployer = HuggingFaceModelDeployer.get_active_model_deployer() services = model_deployer.find_model_server( - pipeline_name="continuous-deployment-pipeline", - pipeline_step_name="seldon_model_deployer_step", - model_name="my-model", + pipeline_name="LLM_pipeline", + pipeline_step_name="huggingface_model_deployer_step", + model_name="LLAMA-7B", ) if services: if services[0].is_running: print( - f"Seldon deployment service started and reachable at:\n" - f" {services[0].prediction_url}\n" - ) - elif services[0].is_failed: - print( - f"Seldon deployment service is in a failure state. " - f"The last error message was: {services[0].status.last_error}" - ) - else: - print(f"Seldon deployment service is not running") - - # start the service - services[0].start(timeout=100) - - # delete the service - model_deployer.delete_service(services[0].uuid, timeout=100, force=False) + f"Model server {services[0].config['model_name']} is running at {services[0].status['prediction_url']}" + ) + else: + print(f"Model server {services[0].config['model_name']} is not running") + model_deployer.start_model_server(services[0]) + else: + print("No model server found") + service = model_deployer.deploy_model( + pipeline_name="LLM_pipeline", + pipeline_step_name="huggingface_model_deployer_step", + model_name="LLAMA-7B", + model_uri="s3://zenprojects/huggingface_model_deployer_step/output/884/huggingface", + revision="main", + task="text-classification", + region="us-east-1", + vendor="aws", + token="huggingface_token", + namespace="zenml-workloads", + endpoint_type="public", + ) + print(f"Model server {service.config['model_name']} is deployed at {service.status['prediction_url']}") ``` #### How to Interact with a model deployer after deployment? @@ -187,20 +188,6 @@ deployer_step = pipeline_run.steps[""] deployed_model_url = deployer_step.run_metadata["deployed_model_url"].value ``` -Services can be passed through steps like any other object, and used to interact with the external systems that they -represent: - -```python -from zenml import step - - -@step -def my_step(my_service: MyService) -> ...: - if not my_service.is_running: - my_service.start() # starts service - my_service.stop() # stops service -``` - The ZenML integrations that provide Model Deployer stack components also include standard pipeline steps that can directly be inserted into any pipeline to achieve a continuous model deployment workflow. These steps take care of all the aspects of continuously deploying models to an external server and saving the Service configuration into the diff --git a/docs/book/toc.md b/docs/book/toc.md index 136e440e1b6..f1b6dfbc2e9 100644 --- a/docs/book/toc.md +++ b/docs/book/toc.md @@ -171,6 +171,7 @@ * [🚀 Quickstart](https://github.com/zenml-io/zenml/blob/main/examples/quickstart) * [🔏 End-to-End Batch Inference](https://github.com/zenml-io/zenml/tree/main/examples/e2e) * [📚 Basic NLP with BERT](https://github.com/zenml-io/zenml/tree/main/examples/e2e_nlp) +* [📖 LLM Finetuning](https://github.com/zenml-io/zenml/tree/main/examples/llm_finetuning) * [🧩 More Projects...](https://github.com/zenml-io/zenml-projects) ## Reference diff --git a/docs/book/user-guide/advanced-guide/data-management/artifact-versioning.md b/docs/book/user-guide/advanced-guide/data-management/artifact-versioning.md index 199a48ea356..3675c34fc11 100644 --- a/docs/book/user-guide/advanced-guide/data-management/artifact-versioning.md +++ b/docs/book/user-guide/advanced-guide/data-management/artifact-versioning.md @@ -32,6 +32,12 @@ By tracking the lineage of artifacts across environments and stacks, ZenML enabl Materializers are designed to be extensible and customizable, allowing you to define your own serialization and deserialization logic for specific data types or storage systems. By default, ZenML provides built-in materializers for common data types and uses `cloudpickle` to pickle objects where there is no default materializer. If you want direct control over how objects are serialized, you can easily create custom materializers by extending the `BaseMaterializer` class and implementing the required methods for your specific use case. Read more about materializers [here](handle-custom-data-types.md). +{% hint style="warning" %} +ZenML provides a built-in [CloudpickleMaterializer](https://sdkdocs.zenml.io/latest/core\_code\_docs/core-materializers/#zenml.materializers.cloudpickle\_materializer.CloudpickleMaterializer) that can handle any object by saving it with [cloudpickle](https://github.com/cloudpipe/cloudpickle). However, this is not production-ready because the resulting artifacts cannot be loaded when running with a different Python version. In such cases, you should consider building a [custom Materializer](handle-custom-data-types.md#custom-materializers) to save your objects in a more robust and efficient format. + +Moreover, using the `CloudpickleMaterializer` could allow users to upload of any kind of object. This could be exploited to upload a malicious file, which could execute arbitrary code on the vulnerable system. +{% endhint %} + When a pipeline runs, ZenML uses the appropriate materializers to save and load artifacts using the ZenML `fileio` system (built to work across multiple artifact stores). This not only simplifies the process of working with different data formats and storage systems but also enables artifact caching and lineage tracking. You can see an example of a default materializer (the `numpy` materializer) in action [here](https://github.com/zenml-io/zenml/blob/main/src/zenml/materializers/numpy\_materializer.py). diff --git a/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md b/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md index 360fb2349f9..c925359c55a 100644 --- a/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md +++ b/docs/book/user-guide/advanced-guide/data-management/handle-custom-data-types.md @@ -14,15 +14,17 @@ ZenML already includes built-in materializers for many common data types. These
MaterializerHandled Data TypesStorage Format
BuiltInMaterializerbool, float, int, str, None.json
BytesInMaterializerbytes.txt
BuiltInContainerMaterializerdict, list, set, tupleDirectory
NumpyMaterializernp.ndarray.npy
PandasMaterializerpd.DataFrame, pd.Series.csv (or .gzip if parquet is installed)
PydanticMaterializerpydantic.BaseModel.json
ServiceMaterializerzenml.services.service.BaseService.json
StructuredStringMaterializerzenml.types.CSVString, zenml.types.HTMLString, zenml.types.MarkdownString.csv / .html / .md (depending on type)
-{% hint style="info" %} -ZenML also provides a built-in [CloudpickleMaterializer](https://sdkdocs.zenml.io/latest/core\_code\_docs/core-materializers/#zenml.materializers.cloudpickle\_materializer.CloudpickleMaterializer) that can handle any object by saving it with [cloudpickle](https://github.com/cloudpipe/cloudpickle). However, this is not production-ready because the resulting artifacts cannot be loaded when running with a different Python version. In such cases, you should consider building a [custom Materializer](handle-custom-data-types.md#custom-materializers) to save your objects in a more robust and efficient format. +{% hint style="warning" %} +ZenML provides a built-in [CloudpickleMaterializer](https://sdkdocs.zenml.io/latest/core\_code\_docs/core-materializers/#zenml.materializers.cloudpickle\_materializer.CloudpickleMaterializer) that can handle any object by saving it with [cloudpickle](https://github.com/cloudpipe/cloudpickle). However, this is not production-ready because the resulting artifacts cannot be loaded when running with a different Python version. In such cases, you should consider building a [custom Materializer](handle-custom-data-types.md#custom-materializers) to save your objects in a more robust and efficient format. + +Moreover, using the `CloudpickleMaterializer` could allow users to upload of any kind of object. This could be exploited to upload a malicious file, which could execute arbitrary code on the vulnerable system. {% endhint %} ## Integration Materializers -In addition to the built-in materializers, ZenML also provides several integration-specific materializers that can be activated by installing the respective [integration](../../component-guide/integration-overview.md): +In addition to the built-in materializers, ZenML also provides several integration-specific materializers that can be activated by installing the respective [integration](../../../stacks-and-components/component-guide/integration-overview.md): -
IntegrationMaterializerHandled Data TypesStorage Format
bentomlBentoMaterializerbentoml.Bento.bento
deepchecksDeepchecksResultMateriailzerdeepchecks.CheckResult, deepchecks.SuiteResult.json
evidentlyEvidentlyProfileMaterializerevidently.Profile.json
great_expectationsGreatExpectationsMaterializergreat_expectations.ExpectationSuite, great_expectations.CheckpointResult.json
huggingfaceHFDatasetMaterializerdatasets.Dataset, datasets.DatasetDictDirectory
huggingfaceHFPTModelMaterializertransformers.PreTrainedModelDirectory
huggingfaceHFTFModelMaterializertransformers.TFPreTrainedModelDirectory
huggingfaceHFTokenizerMaterializertransformers.PreTrainedTokenizerBaseDirectory
lightgbmLightGBMBoosterMaterializerlgbm.Booster.txt
lightgbmLightGBMDatasetMaterializerlgbm.Dataset.binary
neural_prophetNeuralProphetMaterializerNeuralProphet.pt
pillowPillowImageMaterializerPillow.Image.PNG
polarsPolarsMaterializerpl.DataFrame, pl.Series.parquet
pycaretPyCaretMaterializerAny sklearn, xgboost, lightgbm or catboost model.pkl
pytorchPyTorchDataLoaderMaterializertorch.Dataset, torch.DataLoader.pt
pytorchPyTorchModuleMaterializertorch.Module.pt
scipySparseMaterializerscipy.spmatrix.npz
sparkSparkDataFrameMaterializerpyspark.DataFrame.parquet
sparkSparkModelMaterializerpyspark.Transformerpyspark.Estimator
tensorflowKerasMaterializertf.keras.ModelDirectory
tensorflowTensorflowDatasetMaterializertf.DatasetDirectory
whylogsWhylogsMaterializerwhylogs.DatasetProfileView.pb
xgboostXgboostBoosterMaterializerxgb.Booster.json
xgboostXgboostDMatrixMaterializerxgb.DMatrix.binary
+
IntegrationMaterializerHandled Data TypesStorage Format
bentomlBentoMaterializerbentoml.Bento.bento
deepchecksDeepchecksResultMateriailzerdeepchecks.CheckResult, deepchecks.SuiteResult.json
evidentlyEvidentlyProfileMaterializerevidently.Profile.json
great_expectationsGreatExpectationsMaterializergreat_expectations.ExpectationSuite, great_expectations.CheckpointResult.json
huggingfaceHFDatasetMaterializerdatasets.Dataset, datasets.DatasetDictDirectory
huggingfaceHFPTModelMaterializertransformers.PreTrainedModelDirectory
huggingfaceHFTFModelMaterializertransformers.TFPreTrainedModelDirectory
huggingfaceHFTokenizerMaterializertransformers.PreTrainedTokenizerBaseDirectory
lightgbmLightGBMBoosterMaterializerlgbm.Booster.txt
lightgbmLightGBMDatasetMaterializerlgbm.Dataset.binary
neural_prophetNeuralProphetMaterializerNeuralProphet.pt
pillowPillowImageMaterializerPillow.Image.PNG
polarsPolarsMaterializerpl.DataFrame, pl.Series.parquet
pycaretPyCaretMaterializerAny sklearn, xgboost, lightgbm or catboost model.pkl
pytorchPyTorchDataLoaderMaterializertorch.Dataset, torch.DataLoader.pt
pytorchPyTorchModuleMaterializertorch.Module.pt
scipySparseMaterializerscipy.spmatrix.npz
sparkSparkDataFrameMaterializerpyspark.DataFrame.parquet
sparkSparkModelMaterializerpyspark.Transformerpyspark.Estimator
tensorflowKerasMaterializertf.keras.ModelDirectory
tensorflowTensorflowDatasetMaterializertf.DatasetDirectory
whylogsWhylogsMaterializerwhylogs.DatasetProfileView.pb
xgboostXgboostBoosterMaterializerxgb.Booster.json
xgboostXgboostDMatrixMaterializerxgb.DMatrix.binary
{% hint style="warning" %} If you are running pipelines with a Docker-based [orchestrator](../../component-guide/orchestrators/orchestrators.md), you need to specify the corresponding integration as `required_integrations` in the `DockerSettings` of your pipeline in order to have the integration materializer available inside your Docker container. See the [pipeline configuration documentation](../pipelining-features/pipeline-settings.md) for more information. @@ -683,4 +685,4 @@ if __name__ == "__main__": -
ZenML Scarf
\ No newline at end of file +
ZenML Scarf
diff --git a/docs/book/user-guide/advanced-guide/infrastructure-management/containerize-your-pipeline.md b/docs/book/user-guide/advanced-guide/infrastructure-management/containerize-your-pipeline.md index 388b89c54be..09859f812ff 100644 --- a/docs/book/user-guide/advanced-guide/infrastructure-management/containerize-your-pipeline.md +++ b/docs/book/user-guide/advanced-guide/infrastructure-management/containerize-your-pipeline.md @@ -192,7 +192,7 @@ def my_pipeline(...): ... ``` -* Specify a list of pip requirements in code: +* Specify a list of requirements in code: ```python docker_settings = DockerSettings(requirements=["torch==1.12.0", "torchvision"]) @@ -201,7 +201,7 @@ def my_pipeline(...): def my_pipeline(...): ... ``` -* Specify a pip requirements file: +* Specify a requirements file: ```python docker_settings = DockerSettings(requirements="/path/to/requirements.txt") @@ -253,7 +253,7 @@ def my_training_step(...): ``` {% hint style="info" %} -You can combine these methods but do make sure that your list of pip requirements does not overlap with the ones specified explicitly in the docker settings. +You can combine these methods but do make sure that your list of requirements does not overlap with the ones specified explicitly in the Docker settings. {% endhint %} Depending on the options specified in your Docker settings, ZenML installs the requirements in the following order (each step optional): @@ -262,6 +262,20 @@ Depending on the options specified in your Docker settings, ZenML installs the r * The packages specified via the `requirements` attribute (step level overwrites pipeline level) * The packages specified via the `required_integrations` and potentially stack requirements +* **Experimental**: If you want to use [`uv`](https://github.com/astral-sh/uv) for faster resolving and installation of your Python packages, you can use by it as follows: + +```python +docker_settings = DockerSettings(python_package_installer="uv") + +@pipeline(settings={"docker": docker_settings}) +def my_pipeline(...): + ... +``` +{% hint style="info" %} +`uv` is a relatively new project and not as stable as `pip` yet, which might lead to errors during package installation. +If this happens, try switching the installer back to `pip` and see if that solves the issue. +{% endhint %} + ### Using a custom parent image By default, ZenML performs all the steps described above on top of the [official ZenML image](https://hub.docker.com/r/zenmldocker/zenml/) for the Python and ZenML version in the active Python environment. To have more control over the entire environment used to execute your pipelines, you can either specify a custom pre-built parent image or a Dockerfile that ZenML uses to build a parent image for you. diff --git a/docs/book/user-guide/advanced-guide/pipelining-features/managing-steps.md b/docs/book/user-guide/advanced-guide/pipelining-features/managing-steps.md index 6d305f04a98..f29fd83feaf 100644 --- a/docs/book/user-guide/advanced-guide/pipelining-features/managing-steps.md +++ b/docs/book/user-guide/advanced-guide/pipelining-features/managing-steps.md @@ -13,6 +13,12 @@ Your functions will work as ZenML steps even if you don't provide any type annot * **Type validation of your step inputs**: ZenML makes sure that your step functions receive an object of the correct type from the upstream steps in your pipeline. * **Better serialization**: Without type annotations, ZenML uses [Cloudpickle](https://github.com/cloudpipe/cloudpickle) to serialize your step outputs. When provided with type annotations, ZenML can choose a [materializer](../../../getting-started/core-concepts.md#materializers) that is best suited for the output. In case none of the builtin materializers work, you can even [write a custom materializer](../data-management/handle-custom-data-types.md). +{% hint style="warning" %} +ZenML provides a built-in [CloudpickleMaterializer](https://sdkdocs.zenml.io/latest/core\_code\_docs/core-materializers/#zenml.materializers.cloudpickle\_materializer.CloudpickleMaterializer) that can handle any object by saving it with [cloudpickle](https://github.com/cloudpipe/cloudpickle). However, this is not production-ready because the resulting artifacts cannot be loaded when running with a different Python version. In such cases, you should consider building a [custom Materializer](handle-custom-data-types.md#custom-materializers) to save your objects in a more robust and efficient format. + +Moreover, using the `CloudpickleMaterializer` could allow users to upload of any kind of object. This could be exploited to upload a malicious file, which could execute arbitrary code on the vulnerable system. +{% endhint %} + ```python from typing import Tuple from zenml import step diff --git a/docs/book/user-guide/starter-guide/create-an-ml-pipeline.md b/docs/book/user-guide/starter-guide/create-an-ml-pipeline.md index eaa21d2599c..2e5910b0257 100644 --- a/docs/book/user-guide/starter-guide/create-an-ml-pipeline.md +++ b/docs/book/user-guide/starter-guide/create-an-ml-pipeline.md @@ -260,10 +260,8 @@ check them into git history. A simple version of such a YAML file could be: ```yaml -steps: - svc_trainer: - parameters: - gamma: 0.01 +parameters: + gamma: 0.01 ``` Please note that this would take precedence over any parameters passed in the code. diff --git a/examples/llm_finetuning/.assets/model.png b/examples/llm_finetuning/.assets/model.png new file mode 100644 index 00000000000..12b95c69d01 Binary files /dev/null and b/examples/llm_finetuning/.assets/model.png differ diff --git a/examples/llm_finetuning/.copier-answers.yml b/examples/llm_finetuning/.copier-answers.yml new file mode 100644 index 00000000000..d87c3c5df5b --- /dev/null +++ b/examples/llm_finetuning/.copier-answers.yml @@ -0,0 +1,15 @@ +# Changes here will be overwritten by Copier +_commit: 2024.03.18 +_src_path: gh:zenml-io/template-llm-finetuning +cuda_version: cuda11.8 +email: '' +from_safetensors: false +full_name: ZenML GmbH +huggingface_adapter_model_repository: '' +huggingface_merged_model_repository: '' +model_repository: mistralai/Mistral-7B-Instruct-v0.1 +open_source_license: apache +product_name: llm_lora +project_name: ZenML LLM Finetuning project +version: 0.1.0 +zenml_server_url: '' diff --git a/examples/llm_finetuning/.dockerignore b/examples/llm_finetuning/.dockerignore new file mode 100644 index 00000000000..496552c8c5f --- /dev/null +++ b/examples/llm_finetuning/.dockerignore @@ -0,0 +1,9 @@ +* +!/pipelines/** +!/steps/** +!/materializers/** +!/evaluate/** +!/finetune/** +!/generate/** +!/lit_gpt/** +!/scripts/** diff --git a/examples/llm_finetuning/LICENSE b/examples/llm_finetuning/LICENSE new file mode 100644 index 00000000000..75d01fb4544 --- /dev/null +++ b/examples/llm_finetuning/LICENSE @@ -0,0 +1,15 @@ +Apache Software License 2.0 + +Copyright (c) ZenML GmbH 2024. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/examples/llm_finetuning/README.md b/examples/llm_finetuning/README.md new file mode 100644 index 00000000000..9426738597f --- /dev/null +++ b/examples/llm_finetuning/README.md @@ -0,0 +1,128 @@ +# ☮️ Fine-tuning open source LLMs using MLOps pipelines + +Welcome to your newly generated "ZenML LLM Finetuning project" project! This is +a great way to get hands-on with ZenML using production-like template. +The project contains a collection of ZenML steps, pipelines and other artifacts +and useful resources that can serve as a solid starting point for finetuning open-source LLMs using ZenML. + +Using these pipelines, we can run the data-preparation and model finetuning with a single command while using YAML files for [configuration](https://docs.zenml.io/user-guide/production-guide/configure-pipeline) and letting ZenML take care of tracking our metadata and [containerizing our pipelines](https://docs.zenml.io/user-guide/advanced-guide/infrastructure-management/containerize-your-pipeline). + +
+
+ + Model version metadata + +
+
+ +## :earth_americas: Inspiration and Credit + +This project heavily relies on the [Lit-GPT project](https://github.com/Lightning-AI/litgpt) of the amazing people at Lightning AI. We used [this blogpost](https://lightning.ai/pages/community/lora-insights/#toc14) to get started with LoRA and QLoRA and modified the commands they recommend to make them work using ZenML. + +## 🏃 How to run + +In this project we provide a few predefined configuration files for finetuning models on the [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) dataset. Before we're able to run any pipeline, we need to set up our environment as follows: + +```bash +# Set up a Python virtual environment, if you haven't already +python3 -m venv .venv +source .venv/bin/activate + +# Install requirements +pip install -r requirements.txt +``` + +### Combined feature engineering and finetuning pipeline + +The easiest way to get started with just a single command is to run the finetuning pipeline with the `finetune-alpaca.yaml` configuration file, which will do both feature engineering and finetuning: + +```shell +python run.py --finetuning-pipeline --config finetune-alpaca.yaml +``` + +When running the pipeline like this, the trained adapter will be stored in the ZenML artifact store. You can optionally upload the adapter, the merged model or both by specifying the `adapter_output_repo` and `merged_output_repo` parameters in the configuration file. + + +### Evaluation pipeline + +Before running this pipeline, you will need to fill in the `adapter_repo` in the `eval.yaml` configuration file. This should point to a huggingface repository that contains the finetuned adapter you got by running the finetuning pipeline. + +```shell +python run.py --eval-pipeline --config eval.yaml +``` + +### Merging pipeline + +In case you have trained an adapter using the finetuning pipeline, you can merge it with the base model by filling in the `adapter_repo` and `output_repo` parameters in the `merge.yaml` file, and then running: + +```shell +python run.py --merge-pipeline --config merge.yaml +``` + +### Feature Engineering followed by Finetuning + +If you want to finetune your model on a different dataset, you can do so by running the feature engineering pipeline followed by the finetuning pipeline. To define your dataset, take a look at the `scripts/prepare_*` scripts and set the dataset name in the `feature-alpaca.yaml` config file. + +```shell +python run.py --feature-pipeline --config feature-alpaca.yaml +python run.py --finetuning-pipeline --config finetune-from-dataset.yaml +``` + +## ☁️ Running with a remote stack + +To finetune an LLM on remote infrastructure, you can either use a remote orchestrator or a remote step operator. Follow these steps to set up a complete remote stack: +- Register the [orchestrator](https://docs.zenml.io/stacks-and-components/component-guide/orchestrators) (or [step operator](https://docs.zenml.io/stacks-and-components/component-guide/step-operators)) and make sure to configure it in a way so that the finetuning step has access to a GPU with at least 24GB of VRAM. Check out our docs for more [details](https://docs.zenml.io/stacks-and-components/component-guide). + - To access GPUs with this amount of VRAM, you might need to increase your GPU quota ([AWS](https://docs.aws.amazon.com/servicequotas/latest/userguide/request-quota-increase.html), [GCP](https://console.cloud.google.com/iam-admin/quotas), [Azure](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-manage-quotas?view=azureml-api-2#request-quota-and-limit-increases)). + - The GPU instance that your finetuning will be running on will have CUDA drivers of a specific version installed. If that CUDA version is not compatible with the one provided by the default Docker image of the finetuning pipeline, you will need to modify it in the configuration file. See [here](https://hub.docker.com/r/pytorch/pytorch/tags) for a list of available PyTorch images. + - If you're running out of memory, you can experiment with quantized LoRA (QLoRA) by setting a different value for the `quantize` parameter in the configuration, or reduce the `global_batch_size`/`micro_batch_size`. +- Register a remote [artifact store](https://docs.zenml.io/stacks-and-components/component-guide/artifact-stores) and [container registry](https://docs.zenml.io/stacks-and-components/component-guide/container-registries). +- Register a stack with all these components + ```shell + zenml stack register llm-finetuning-stack -o \ + -a \ + -c \ + [-s ] + ``` + +## 💾 Running with custom data + +To finetune a model with your custom data, you will need to convert it to a CSV file with the columns described +[here](https://github.com/Lightning-AI/litgpt/blob/main/tutorials/prepare_dataset.md#preparing-custom-datasets-from-a-csv-file). + +Next, update the `configs/feature-custom.yaml` file and set the value of the `csv_path` parameter to that CSV file. +With all that in place, you can now run the feature engineering pipeline to convert your CSV into the correct format for training and then run the finetuning pipeline as follows: +```shell +python run.py --feature-pipeline --config feature-custom.yaml +python run.py --finetuning-pipeline --config finetune-from-dataset.yaml +``` + +## 📜 Project Structure + +The project loosely follows [the recommended ZenML project structure](https://docs.zenml.io/user-guide/starter-guide/follow-best-practices): + +``` +. +├── configs # pipeline configuration files +│ ├── eval.yaml # configuration for the evaluation pipeline +│ ├── feature-alpaca.yaml # configuration for the feature engineering pipeline +│ ├── feature-custom.yaml # configuration for the feature engineering pipeline +│ ├── finetune-alpaca.yaml # configuration for the finetuning pipeline +│ ├── finetune-from-dataset.yaml # configuration for the finetuning pipeline +│ └── merge.yaml # configuration for the merging pipeline +├── pipelines # `zenml.pipeline` implementations +│ ├── evaluate.py # Evaluation pipeline +│ ├── feature_engineering.py # Feature engineering pipeline +│ ├── finetuning.py # Finetuning pipeline +│ └── merge.py # Merging pipeline +├── steps # logically grouped `zenml.steps` implementations +│ ├── evaluate.py # evaluate model performance +│ ├── feature_engineering.py # preprocess data +│ ├── finetune.py # finetune a model +│ ├── merge.py # merge model and adapter +│ ├── params.py # shared parameters for steps +│ └── utils.py # utility functions +├── .dockerignore +├── README.md # this file +├── requirements.txt # extra Python dependencies +└── run.py # CLI tool to run pipelines on ZenML Stack +``` diff --git a/examples/llm_finetuning/configs/eval.yaml b/examples/llm_finetuning/configs/eval.yaml new file mode 100644 index 00000000000..03a73adbd7f --- /dev/null +++ b/examples/llm_finetuning/configs/eval.yaml @@ -0,0 +1,21 @@ +model: + name: llm_lora-Mistral-7B-Instruct-v0.1 + description: "Fine-tune `mistralai/Mistral-7B-Instruct-v0.1`." + tags: + - llm + - lora + - mistralai/Mistral-7B-Instruct-v0.1 + +settings: + docker: + parent_image: pytorch/pytorch:2.2.0-cuda11.8-cudnn8-runtime + +steps: + evaluate: + enable_step_logs: False + parameters: + config: + model_repo: mistralai/Mistral-7B-Instruct-v0.1 + from_safetensors: False + adapter_repo: + precision: bf16-true \ No newline at end of file diff --git a/examples/llm_finetuning/configs/feature-alpaca.yaml b/examples/llm_finetuning/configs/feature-alpaca.yaml new file mode 100644 index 00000000000..2a0ca7f85c4 --- /dev/null +++ b/examples/llm_finetuning/configs/feature-alpaca.yaml @@ -0,0 +1,16 @@ +model: + name: llm_lora-Mistral-7B-Instruct-v0.1 + description: "Fine-tune `mistralai/Mistral-7B-Instruct-v0.1`." + tags: + - llm + - lora + - mistralai/Mistral-7B-Instruct-v0.1 + - alpaca + +steps: + feature_engineering: + enable_step_logs: False + parameters: + config: + model_repo: mistralai/Mistral-7B-Instruct-v0.1 + dataset_name: alpaca diff --git a/examples/llm_finetuning/configs/feature-custom.yaml b/examples/llm_finetuning/configs/feature-custom.yaml new file mode 100644 index 00000000000..6611ede26be --- /dev/null +++ b/examples/llm_finetuning/configs/feature-custom.yaml @@ -0,0 +1,19 @@ +model: + name: llm_lora-Mistral-7B-Instruct-v0.1 + description: "Fine-tune `mistralai/Mistral-7B-Instruct-v0.1`." + tags: + - llm + - lora + - mistralai/Mistral-7B-Instruct-v0.1 + +steps: + feature_engineering: + enable_step_logs: False + parameters: + config: + model_repo: mistralai/Mistral-7B-Instruct-v0.1 + dataset_name: csv + prepare_kwargs: + # REQUIRED: Path the the .csv file containing the data. Format must be as described here + # https://github.com/Lightning-AI/litgpt/blob/main/tutorials/prepare_dataset.md#preparing-custom-datasets-from-a-csv-file + csv_path: null diff --git a/examples/llm_finetuning/configs/finetune-alpaca.yaml b/examples/llm_finetuning/configs/finetune-alpaca.yaml new file mode 100644 index 00000000000..6e9a4502c69 --- /dev/null +++ b/examples/llm_finetuning/configs/finetune-alpaca.yaml @@ -0,0 +1,35 @@ +model: + name: llm_lora-Mistral-7B-Instruct-v0.1 + description: "Fine-tune `mistralai/Mistral-7B-Instruct-v0.1`." + tags: + - llm + - lora + - mistralai/Mistral-7B-Instruct-v0.1 + - alpaca + +settings: + docker: + parent_image: pytorch/pytorch:2.2.0-cuda11.8-cudnn8-runtime + +steps: + finetune: + # Uncomment and set value to use a step operator for this step + # step_operator: + enable_step_logs: False + parameters: + config: + base_model_repo: mistralai/Mistral-7B-Instruct-v0.1 + from_safetensors: False + precision: bf16-true + quantize: bnb.nf4 # Enable quantization with 4-bit normal float + # OPTIONAL: Configure Huggingface repository to which the merged model should be pushed + # merged_output_repo: + # OPTIONAL: Configure Huggingface repository to which the adapter should be pushed + # adapter_output_repo: + training: + save_interval: 1 + epochs: 5 + epoch_size: 50000 + global_batch_size: 128 + micro_batch_size: 4 + learning_rate: 3e-4 diff --git a/examples/llm_finetuning/configs/finetune-from-dataset.yaml b/examples/llm_finetuning/configs/finetune-from-dataset.yaml new file mode 100644 index 00000000000..653bcbecaa2 --- /dev/null +++ b/examples/llm_finetuning/configs/finetune-from-dataset.yaml @@ -0,0 +1,33 @@ +parameters: + dataset_artifact_name: dataset + +model: + name: llm_lora-Mistral-7B-Instruct-v0.1 + version: latest + +settings: + docker: + parent_image: pytorch/pytorch:2.2.0-cuda11.8-cudnn8-runtime + +steps: + finetune: + # Uncomment and set value to use a step operator for this step + # step_operator: + enable_step_logs: False + parameters: + config: + base_model_repo: mistralai/Mistral-7B-Instruct-v0.1 + from_safetensors: False + precision: bf16-true + quantize: bnb.nf4 # Enable quantization with 4-bit normal float + # OPTIONAL: Configure Huggingface repository to which the merged model should be pushed + # merged_output_repo: + # OPTIONAL: Configure Huggingface repository to which the adapter should be pushed + # adapter_output_repo: + training: + save_interval: 1 + epochs: 5 + epoch_size: 50000 + global_batch_size: 128 + micro_batch_size: 4 + learning_rate: 3e-4 diff --git a/examples/llm_finetuning/configs/merge.yaml b/examples/llm_finetuning/configs/merge.yaml new file mode 100644 index 00000000000..2f7326ebe7d --- /dev/null +++ b/examples/llm_finetuning/configs/merge.yaml @@ -0,0 +1,19 @@ +model: + name: llm_lora-Mistral-7B-Instruct-v0.1 + description: "Fine-tune `mistralai/Mistral-7B-Instruct-v0.1`." + tags: + - llm + - lora + - mistralai/Mistral-7B-Instruct-v0.1 + +steps: + merge: + parameters: + config: + base_model_repo: mistralai/Mistral-7B-Instruct-v0.1 + from_safetensors: False + # REQUIRED: Huggingface repository in which to adapter is stored + adapter_repo: null + # REQUIRED: Huggingface repository to which the merged model should be pushed + output_repo: null + precision: bf16-true \ No newline at end of file diff --git a/examples/llm_finetuning/evaluate/lm_eval_harness.py b/examples/llm_finetuning/evaluate/lm_eval_harness.py new file mode 100644 index 00000000000..6f90c19f14e --- /dev/null +++ b/examples/llm_finetuning/evaluate/lm_eval_harness.py @@ -0,0 +1,231 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import json +import sys +from pathlib import Path +from typing import Dict, List, Literal, Optional + +import lightning as L +import torch +from lightning.fabric.plugins import BitsandbytesPrecision +from lm_eval import base, evaluator, tasks +from lm_eval.base import BaseLM + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from generate.base import generate +from lit_gpt import GPT, Config, Tokenizer +from lit_gpt.utils import ( + CLI, + check_valid_checkpoint_dir, + get_default_supported_precision, + load_checkpoint, +) + + +class EvalHarnessBase(BaseLM): + # Credits: + # https://github.com/EleutherAI/gpt-neox/blob/main/eval_tasks/eval_adapter.py + def __init__( + self, + fabric: L.Fabric, + model: GPT, + tokenizer: Tokenizer, + batch_size: int, + ): + super().__init__() + self.fabric = fabric + self.model = model + self.tokenizer = tokenizer + self.batch_size_per_gpu = batch_size + with fabric.init_tensor(): + model.set_kv_cache(batch_size=batch_size) + + @classmethod + def create_from_arg_string(cls, arg_string, additional_config=None): + kwargs = { + el.split("=")[0]: el.split("=")[1] for el in arg_string.split(",") + } + return cls(**kwargs, **additional_config) + + @property + def eot_token_id(self): + # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence* + return self.tokenizer.eos_id + + @property + def max_length(self): + return self.model.max_seq_length + + @property + def vocab_size(self): + return self.tokenizer.vocab_size + + @property + def max_gen_toks(self): + return 256 + + @property + def batch_size(self): + return self.batch_size_per_gpu * self.fabric.world_size + + @property + def device(self): + return self.fabric.device + + def tok_encode(self, string: str) -> List[int]: + return self.tokenizer.encode(string, bos=False, eos=False).tolist() + + def tok_decode(self, tokens: List[int]) -> str: + t = torch.tensor(tokens) + return self.tokenizer.decode(t) + + @torch.inference_mode() + def _model_call(self, inps): + return self.model(inps) + + @torch.inference_mode() + def _model_generate( + self, context, max_length, eos_token_id + ) -> torch.Tensor: + # this only supports batch size 1 + assert context.shape[0] == 1 + out = generate(self.model, context[0], max_length, eos_id=eos_token_id) + for block in self.model.transformer.h: + block.attn.kv_cache.reset_parameters() + return out.unsqueeze(0) + + @torch.inference_mode() + def run_eval( + self, + eval_tasks: List[str], + num_fewshot: int, + limit: Optional[int], + bootstrap_iters: int, + no_cache: bool, + ) -> Dict: + # Returns a list containing all values of the task registry that + # match at least one of the patterns + import fnmatch + + def pattern_match(patterns, source_list): + task_names = set() + for pattern in patterns: + for matching in fnmatch.filter(source_list, pattern): + task_names.add(matching) + return list(task_names) + + eval_tasks = pattern_match(eval_tasks, tasks.ALL_TASKS) + print(f"Found tasks: {eval_tasks}") + + # **HACK INCOMING**: + # first get task dict on local main rank + # the tasks are downloaded *as they are initialized*, and the downloads don't like multithreading. + # so we download them once on the local main rank, wait, and then initialize them on all other ranks, which *should* load from the cache. + if self.fabric.local_rank == 0: + tasks.get_task_dict(eval_tasks) + # torch barrier + self.fabric.barrier() + tasks.get_task_dict(eval_tasks) + + lm = self + if not no_cache: + lm = base.CachingLM(lm, "lm_cache/lit-gpt.db") + + results = evaluator.evaluate( + lm=lm, + task_dict=tasks.get_task_dict(eval_tasks), + num_fewshot=num_fewshot, + limit=limit, + bootstrap_iters=bootstrap_iters, + ) + results["config"] = dict( + model=self.model.config.name, + batch_size=self.batch_size, + device=str(self.device), + num_fewshot=num_fewshot, + limit=limit, + bootstrap_iters=bootstrap_iters, + no_cache=no_cache, + ) + return results + + +@torch.inference_mode() +def run_eval_harness( + checkpoint_dir: Path, + precision: Optional[str] = None, + quantize: Optional[ + Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"] + ] = None, + eval_tasks: List[str] = [ + "arc_challenge", + "piqa", + "hellaswag", + "hendrycksTest-*", + ], + save_filepath: Optional[Path] = None, + num_fewshot: int = 0, + limit: Optional[int] = None, + bootstrap_iters: int = 100000, + no_cache: bool = True, +): + if precision is None: + precision = get_default_supported_precision(training=False) + + plugins = None + if quantize is not None and quantize.startswith("bnb."): + if "mixed" in precision: + raise ValueError( + "Quantization and mixed precision is not supported." + ) + dtype = { + "16-true": torch.float16, + "bf16-true": torch.bfloat16, + "32-true": torch.float32, + }[precision] + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + + fabric = L.Fabric(devices=1, precision=precision, plugins=plugins) + + check_valid_checkpoint_dir(checkpoint_dir) + tokenizer = Tokenizer(checkpoint_dir) + + config = Config.from_json(checkpoint_dir / "lit_config.json") + + checkpoint_path = checkpoint_dir / "lit_model.pth" + + print( + f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", + file=sys.stderr, + ) + with fabric.init_module(empty_init=True): + model = GPT(config) + + model.eval() + model = fabric.setup_module(model) + + load_checkpoint(fabric, model, checkpoint_path) + + eval_harness = EvalHarnessBase(fabric, model, tokenizer, 1) + + results = eval_harness.run_eval( + eval_tasks, num_fewshot, limit, bootstrap_iters, no_cache + ) + if save_filepath is None: + print(results) + else: + print(f"Saving results to {str(save_filepath)!r}") + save_filepath.parent.mkdir(parents=True, exist_ok=True) + data = json.dumps(results) + with open(save_filepath, "w") as fw: + fw.write(data) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + CLI(run_eval_harness) diff --git a/examples/llm_finetuning/finetune/adapter.py b/examples/llm_finetuning/finetune/adapter.py new file mode 100644 index 00000000000..acf8f6d414d --- /dev/null +++ b/examples/llm_finetuning/finetune/adapter.py @@ -0,0 +1,451 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. +import dataclasses +import os +import sys +import time +from pathlib import Path +from typing import Dict, List, Literal, Optional, Tuple + +import lightning as L +import torch +from lightning.fabric.loggers import CSVLogger +from lightning.fabric.plugins import BitsandbytesPrecision +from lightning.fabric.strategies import FSDPStrategy +from lightning.fabric.utilities import ThroughputMonitor + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from generate.base import generate +from lit_gpt.adapter import ( + GPT, + Block, + Config, + adapter_filter, + mark_only_adapter_as_trainable, +) +from lit_gpt.args import EvalArgs, IOArgs, TrainArgs +from lit_gpt.tokenizer import Tokenizer +from lit_gpt.utils import ( + CLI, + check_valid_checkpoint_dir, + chunked_cross_entropy, + get_default_supported_precision, + load_checkpoint, + num_parameters, +) + +from scripts.prepare_alpaca import generate_prompt + + +def setup( + precision: Optional[str] = None, + quantize: Optional[ + Literal[ + "bnb.nf4", + "bnb.nf4-dq", + "bnb.fp4", + "bnb.fp4-dq", + "bnb.int8-training", + ] + ] = None, + devices: int = 1, + io: IOArgs = IOArgs( + train_data_dir=Path("data/alpaca"), + val_data_dir=Path("data/alpaca"), + checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), + out_dir=Path("out/adapter/alpaca"), + ), + train: TrainArgs = TrainArgs( + save_interval=1000, + log_interval=1, + global_batch_size=64, + micro_batch_size=4, + lr_warmup_steps=100, + epochs=5, + epoch_size=50000, + learning_rate=1e-3, + max_seq_length=None, + ), + eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100), +) -> None: + print(locals()) + precision = precision or get_default_supported_precision(training=True) + + plugins = None + if quantize is not None and quantize.startswith("bnb."): + if "mixed" in precision: + raise ValueError( + "Quantization and mixed precision is not supported." + ) + dtype = { + "16-true": torch.float16, + "bf16-true": torch.bfloat16, + "32-true": torch.float32, + }[precision] + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + + if devices > 1: + if quantize: + raise NotImplementedError( + "Quantization is currently not supported for multi-GPU training. Please set devices=1 when using the" + " --quantize flag." + ) + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy={Block}, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + logger = CSVLogger( + io.out_dir.parent, + io.out_dir.name, + flush_logs_every_n_steps=train.log_interval, + ) + fabric = L.Fabric( + devices=devices, + strategy=strategy, + precision=precision, + loggers=logger, + plugins=plugins, + ) + fabric.launch( + main, + devices, + Config.from_name(name=io.checkpoint_dir.name), + io, + train, + eval, + ) + + +def main( + fabric: L.Fabric, + devices: int, + config: Config, + io: IOArgs, + train: TrainArgs, + eval: EvalArgs, +) -> None: + validate_args(io, train, eval) + + steps_per_epoch = train.epoch_size // devices // train.batch_size(devices) + lr_max_steps = train.epochs * steps_per_epoch + + check_valid_checkpoint_dir(io.checkpoint_dir) + + fabric.seed_everything( + 1337 + ) # same seed for every process to init model (FSDP) + + if fabric.global_rank == 0: + os.makedirs(io.out_dir, exist_ok=True) + + train_data = torch.load(io.train_data_dir / "train.pt") + val_data = torch.load(io.val_data_dir / "test.pt") + + checkpoint_path = io.checkpoint_dir / "lit_model.pth" + fabric.print( + f"Loading model {str(checkpoint_path)!r} with {config.__dict__}" + ) + with fabric.init_module(empty_init=(devices > 1)): + model = GPT(config) + mark_only_adapter_as_trainable(model) + + fabric.print( + f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}" + ) + fabric.print( + f"Number of non trainable parameters: {num_parameters(model, requires_grad=False):,}" + ) + + model = fabric.setup_module(model) + + trainable_params = [p for p in model.parameters() if p.requires_grad] + if isinstance(fabric.strategy.precision, BitsandbytesPrecision): + import bitsandbytes as bnb + + optimizer_cls = bnb.optim.PagedAdamW + else: + optimizer_cls = torch.optim.AdamW + optimizer = optimizer_cls( + trainable_params, + lr=train.learning_rate, + weight_decay=train.weight_decay, + betas=(train.beta1, train.beta2), + ) + optimizer = fabric.setup_optimizers(optimizer) + scheduler = get_lr_scheduler( + optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps + ) + + # strict=False because missing keys due to Adapter weights not contained in state dict + load_checkpoint(fabric, model, checkpoint_path, strict=False) + + fabric.seed_everything(1337 + fabric.global_rank) + + train_time = time.perf_counter() + fit( + fabric, + model, + optimizer, + scheduler, + train_data, + val_data, + devices, + io, + train, + eval, + ) + fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") + if fabric.device.type == "cuda": + fabric.print( + f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB" + ) + + # Save the final checkpoint at the end of training + save_path = io.out_dir / "lit_model_adapter_finetuned.pth" + save_adapter_checkpoint(fabric, model, save_path) + + +def fit( + fabric: L.Fabric, + model: GPT, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler, + train_data: List[Dict], + val_data: List[Dict], + devices: int, + io: IOArgs, + train: TrainArgs, + eval: EvalArgs, +) -> None: + tokenizer = Tokenizer(io.checkpoint_dir) + longest_seq_length, longest_seq_ix = get_longest_seq_length(train_data) + model.max_seq_length = min( + longest_seq_length, train.max_seq_length or float("inf") + ) + fabric.print( + f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" + f" {model.max_seq_length} and context length is {model.config.block_size}" + ) + + validate( + fabric, + model, + val_data, + tokenizer, + dataclasses.replace(eval, max_iters=2), + train, + ) # sanity check + + throughput = ThroughputMonitor(fabric, window_size=50) + step_count = 0 + total_lengths = 0 + total_t0 = time.perf_counter() + + for iter_num in range(1, train.max_iters(devices) + 1): + iter_t0 = time.perf_counter() + + input_ids, targets = get_batch( + fabric, + train_data, + train.micro_batch_size, + train.max_seq_length, + longest_seq_ix if iter_num == 1 else None, + ) + + is_accumulating = ( + iter_num % train.gradient_accumulation_iters(devices) != 0 + ) + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids, lm_head_chunk_size=128) + # shift the targets such that output n predicts token n+1 + logits[-1] = logits[-1][..., :-1, :] + loss = chunked_cross_entropy(logits, targets[..., 1:]) + fabric.backward(loss / train.gradient_accumulation_iters(devices)) + + if not is_accumulating: + optimizer.step() + optimizer.zero_grad() + scheduler.step() + step_count += 1 + + total_lengths += input_ids.numel() + if iter_num % train.log_interval == 0: + loss_item = loss.item() # expensive device-to-host synchronization + t1 = time.perf_counter() + throughput.update( + time=t1 - total_t0, + batches=iter_num, + samples=iter_num * train.micro_batch_size, + lengths=total_lengths, + ) + throughput.compute_and_log(step=iter_num) + fabric.print( + f"iter {iter_num} | step {step_count}: loss {loss_item:.4f}, iter time:" + f" {(t1 - iter_t0) * 1000:.2f} ms{' (optimizer.step)' if not is_accumulating else ''}" + ) + + if not is_accumulating and step_count % eval.interval == 0: + t0 = time.perf_counter() + val_loss = validate( + fabric, model, val_data, tokenizer, eval, train + ) + t1 = time.perf_counter() - t0 + fabric.print( + f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms" + ) + fabric.barrier() + if not is_accumulating and step_count % train.save_interval == 0: + checkpoint_path = io.out_dir / f"iter-{iter_num:06d}-ckpt.pth" + save_adapter_checkpoint(fabric, model, checkpoint_path) + + +# the adapter "kv cache" cannot be initialized under `inference_mode` +@torch.no_grad() +def validate( + fabric: L.Fabric, + model: GPT, + val_data: List[Dict], + tokenizer: Tokenizer, + eval: EvalArgs, + train: TrainArgs, +) -> torch.Tensor: + fabric.print("Validating ...") + model.eval() + losses = torch.zeros(eval.max_iters) + for k in range(eval.max_iters): + input_ids, targets = get_batch( + fabric, val_data, train.micro_batch_size, train.max_seq_length + ) + logits = model(input_ids) + losses[k] = chunked_cross_entropy( + logits[..., :-1, :], targets[..., 1:], chunk_size=0 + ) + val_loss = losses.mean() + + # produce an example: + instruction = "Recommend a movie for me to watch during the weekend and explain the reason." + fabric.print(instruction) + sample = {"instruction": instruction, "input": ""} + prompt = generate_prompt(sample) + encoded = tokenizer.encode(prompt, device=fabric.device) + with fabric.init_tensor(): + # do not set `max_seq_length=max_returned_token` because memory is not a concern here + model.set_kv_cache(batch_size=1) + output = generate( + model, + encoded, + max_returned_tokens=len(encoded) + eval.max_new_tokens, + temperature=0.8, + eos_id=tokenizer.eos_id, + ) + model.clear_kv_cache() + output = tokenizer.decode(output) + fabric.print(output) + + model.train() + return val_loss + + +def get_batch( + fabric: L.Fabric, + data: List[Dict], + micro_batch_size: int, + max_seq_length: Optional[int], + longest_seq_ix: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + ix = torch.randint(len(data), (micro_batch_size,)) + if longest_seq_ix is not None: + # force the longest sample at the beginning so potential OOMs happen right away + ix[0] = longest_seq_ix + + input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix] + labels = [data[i]["labels"].type(torch.int64) for i in ix] + + # this could be `longest_seq_length` to have a fixed size for all batches + max_len = max(len(s) for s in input_ids) + + def pad_right(x, pad_id): + # pad right based on the longest sequence + n = max_len - len(x) + return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype))) + + x = torch.stack([pad_right(x, pad_id=0) for x in input_ids]) + y = torch.stack([pad_right(x, pad_id=-1) for x in labels]) + + # Truncate if needed + if max_seq_length: + x = x[:, :max_seq_length] + y = y[:, :max_seq_length] + + if fabric.device.type == "cuda" and x.device.type == "cpu": + x, y = fabric.to_device((x.pin_memory(), y.pin_memory())) + else: + x, y = fabric.to_device((x, y)) + return x, y + + +def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): + # linear warmup followed by cosine annealing + scheduler1 = torch.optim.lr_scheduler.LambdaLR( + optimizer, lambda step: step / warmup_steps + ) + scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=(max_steps - warmup_steps) + ) + return torch.optim.lr_scheduler.SequentialLR( + optimizer, [scheduler1, scheduler2], milestones=[warmup_steps] + ) + + +def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: + # find out the minimum max_seq_length required during fine-tuning (saves memory!) + lengths = [len(d["input_ids"]) for d in data] + longest_seq_length = max(lengths) + longest_seq_ix = lengths.index(longest_seq_length) + return longest_seq_length, longest_seq_ix + + +def save_adapter_checkpoint( + fabric: L.Fabric, model: torch.nn.Module, file_path: Path +) -> None: + fabric.print(f"Saving adapter weights to {str(file_path)!r}") + fabric.save(file_path, {"model": model}, filter={"model": adapter_filter}) + + +def validate_args(io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None: + issues = [] + unsupported = [(train, ["max_tokens", "max_norm"])] + for args, names in unsupported: + for name in names: + if getattr(args, name) is not None: + issues.append( + f"{__file__} doesn't support the {name!r} argument. This is set in {args}" + ) + required = [ + (io, ["checkpoint_dir", "train_data_dir", "val_data_dir"]), + (train, ["epoch_size", "epochs"]), + (eval, ["max_new_tokens"]), + ] + for args, names in required: + for name in names: + if getattr(args, name) is None: + issues.append( + f"{__file__} requires the {name!r} argument. This is set in {args}" + ) + if issues: + raise ValueError("\n".join(issues)) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + CLI(setup) diff --git a/examples/llm_finetuning/finetune/adapter_v2.py b/examples/llm_finetuning/finetune/adapter_v2.py new file mode 100644 index 00000000000..ac7de327a49 --- /dev/null +++ b/examples/llm_finetuning/finetune/adapter_v2.py @@ -0,0 +1,451 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. +import dataclasses +import os +import sys +import time +from pathlib import Path +from typing import Dict, List, Literal, Optional, Tuple + +import lightning as L +import torch +from lightning.fabric.loggers import CSVLogger +from lightning.fabric.plugins import BitsandbytesPrecision +from lightning.fabric.strategies import FSDPStrategy +from lightning.fabric.utilities import ThroughputMonitor + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from generate.base import generate +from lit_gpt.adapter_v2 import ( + GPT, + Block, + Config, + adapter_filter, + mark_only_adapter_v2_as_trainable, +) +from lit_gpt.args import EvalArgs, IOArgs, TrainArgs +from lit_gpt.tokenizer import Tokenizer +from lit_gpt.utils import ( + CLI, + check_valid_checkpoint_dir, + chunked_cross_entropy, + get_default_supported_precision, + load_checkpoint, + num_parameters, +) + +from scripts.prepare_alpaca import generate_prompt + + +def setup( + precision: Optional[str] = None, + quantize: Optional[ + Literal[ + "bnb.nf4", + "bnb.nf4-dq", + "bnb.fp4", + "bnb.fp4-dq", + "bnb.int8-training", + ] + ] = None, + devices: int = 1, + io: IOArgs = IOArgs( + train_data_dir=Path("data/alpaca"), + val_data_dir=Path("data/alpaca"), + checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), + out_dir=Path("out/adapter_v2/alpaca"), + ), + train: TrainArgs = TrainArgs( + save_interval=1000, + log_interval=1, + global_batch_size=128, + micro_batch_size=2, + lr_warmup_steps=100, + epochs=5, + epoch_size=50000, + learning_rate=1e-3, + max_seq_length=None, + ), + eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100), +) -> None: + print(locals()) + precision = precision or get_default_supported_precision(training=True) + + plugins = None + if quantize is not None and quantize.startswith("bnb."): + if "mixed" in precision: + raise ValueError( + "Quantization and mixed precision is not supported." + ) + dtype = { + "16-true": torch.float16, + "bf16-true": torch.bfloat16, + "32-true": torch.float32, + }[precision] + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + + if devices > 1: + if quantize: + raise NotImplementedError( + "Quantization is currently not supported for multi-GPU training. Please set devices=1 when using the" + " --quantize flag." + ) + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy={Block}, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + logger = CSVLogger( + io.out_dir.parent, + io.out_dir.name, + flush_logs_every_n_steps=train.log_interval, + ) + fabric = L.Fabric( + devices=devices, + strategy=strategy, + precision=precision, + loggers=logger, + plugins=plugins, + ) + fabric.launch( + main, + devices, + Config.from_name(name=io.checkpoint_dir.name), + io, + train, + eval, + ) + + +def main( + fabric: L.Fabric, + devices: int, + config: Config, + io: IOArgs, + train: TrainArgs, + eval: EvalArgs, +) -> None: + validate_args(io, train, eval) + + steps_per_epoch = train.epoch_size // devices // train.batch_size(devices) + lr_max_steps = train.epochs * steps_per_epoch + + check_valid_checkpoint_dir(io.checkpoint_dir) + + fabric.seed_everything( + 1337 + ) # same seed for every process to init model (FSDP) + + if fabric.global_rank == 0: + os.makedirs(io.out_dir, exist_ok=True) + + train_data = torch.load(io.train_data_dir / "train.pt") + val_data = torch.load(io.val_data_dir / "test.pt") + + checkpoint_path = io.checkpoint_dir / "lit_model.pth" + fabric.print( + f"Loading model {str(checkpoint_path)!r} with {config.__dict__}" + ) + with fabric.init_module(empty_init=(devices > 1)): + model = GPT(config) + mark_only_adapter_v2_as_trainable(model) + + fabric.print( + f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}" + ) + fabric.print( + f"Number of non trainable parameters: {num_parameters(model, requires_grad=False):,}" + ) + + model = fabric.setup_module(model) + + trainable_params = [p for p in model.parameters() if p.requires_grad] + if isinstance(fabric.strategy.precision, BitsandbytesPrecision): + import bitsandbytes as bnb + + optimizer_cls = bnb.optim.PagedAdamW + else: + optimizer_cls = torch.optim.AdamW + optimizer = optimizer_cls( + trainable_params, + lr=train.learning_rate, + weight_decay=train.weight_decay, + betas=(train.beta1, train.beta2), + ) + optimizer = fabric.setup_optimizers(optimizer) + scheduler = get_lr_scheduler( + optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps + ) + + # strict=False because missing keys due to Adapter weights not contained in state dict + load_checkpoint(fabric, model, checkpoint_path, strict=False) + + fabric.seed_everything(1337 + fabric.global_rank) + + train_time = time.perf_counter() + fit( + fabric, + model, + optimizer, + scheduler, + train_data, + val_data, + devices, + io, + train, + eval, + ) + fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") + if fabric.device.type == "cuda": + fabric.print( + f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB" + ) + + # Save the final checkpoint at the end of training + save_path = io.out_dir / "lit_model_adapter_finetuned.pth" + save_adapter_v2_checkpoint(fabric, model, save_path) + + +def fit( + fabric: L.Fabric, + model: GPT, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler, + train_data: List[Dict], + val_data: List[Dict], + devices: int, + io: IOArgs, + train: TrainArgs, + eval: EvalArgs, +) -> None: + tokenizer = Tokenizer(io.checkpoint_dir) + longest_seq_length, longest_seq_ix = get_longest_seq_length(train_data) + model.max_seq_length = min( + longest_seq_length, train.max_seq_length or float("inf") + ) + fabric.print( + f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" + f" {model.max_seq_length} and context length is {model.config.block_size}" + ) + + validate( + fabric, + model, + val_data, + tokenizer, + dataclasses.replace(eval, max_iters=2), + train, + ) # sanity check + + throughput = ThroughputMonitor(fabric, window_size=50) + step_count = 0 + total_lengths = 0 + total_t0 = time.perf_counter() + + for iter_num in range(1, train.max_iters(devices) + 1): + iter_t0 = time.perf_counter() + + input_ids, targets = get_batch( + fabric, + train_data, + train.micro_batch_size, + train.max_seq_length, + longest_seq_ix if iter_num == 1 else None, + ) + + is_accumulating = ( + iter_num % train.gradient_accumulation_iters(devices) != 0 + ) + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids, lm_head_chunk_size=128) + # shift the targets such that output n predicts token n+1 + logits[-1] = logits[-1][..., :-1, :] + loss = chunked_cross_entropy(logits, targets[..., 1:]) + fabric.backward(loss / train.gradient_accumulation_iters(devices)) + + if not is_accumulating: + optimizer.step() + optimizer.zero_grad() + scheduler.step() + step_count += 1 + + total_lengths += input_ids.numel() + if iter_num % train.log_interval == 0: + loss_item = loss.item() # expensive device-to-host synchronization + t1 = time.perf_counter() + throughput.update( + time=t1 - total_t0, + batches=iter_num, + samples=iter_num * train.micro_batch_size, + lengths=total_lengths, + ) + throughput.compute_and_log(step=iter_num) + fabric.print( + f"iter {iter_num} | step {step_count}: loss {loss_item:.4f}, iter time:" + f" {(t1 - iter_t0) * 1000:.2f} ms{' (optimizer.step)' if not is_accumulating else ''}" + ) + + if not is_accumulating and step_count % eval.interval == 0: + t0 = time.perf_counter() + val_loss = validate( + fabric, model, val_data, tokenizer, eval, train + ) + t1 = time.perf_counter() - t0 + fabric.print( + f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms" + ) + fabric.barrier() + if not is_accumulating and step_count % train.save_interval == 0: + checkpoint_path = io.out_dir / f"iter-{iter_num:06d}-ckpt.pth" + save_adapter_v2_checkpoint(fabric, model, checkpoint_path) + + +# the adapter "kv cache" cannot be initialized under `inference_mode` +@torch.no_grad() +def validate( + fabric: L.Fabric, + model: GPT, + val_data: List[Dict], + tokenizer: Tokenizer, + eval: EvalArgs, + train: TrainArgs, +) -> torch.Tensor: + fabric.print("Validating ...") + model.eval() + losses = torch.zeros(eval.max_iters) + for k in range(eval.max_iters): + input_ids, targets = get_batch( + fabric, val_data, train.micro_batch_size, train.max_seq_length + ) + logits = model(input_ids) + losses[k] = chunked_cross_entropy( + logits[..., :-1, :], targets[..., 1:], chunk_size=0 + ) + val_loss = losses.mean() + + # produce an example: + instruction = "Recommend a movie for me to watch during the weekend and explain the reason." + fabric.print(instruction) + sample = {"instruction": instruction, "input": ""} + prompt = generate_prompt(sample) + encoded = tokenizer.encode(prompt, device=fabric.device) + with fabric.init_tensor(): + # do not set `max_seq_length=max_returned_token` because memory is not a concern here + model.set_kv_cache(batch_size=1) + output = generate( + model, + encoded, + max_returned_tokens=len(encoded) + eval.max_new_tokens, + temperature=0.8, + eos_id=tokenizer.eos_id, + ) + model.clear_kv_cache() + output = tokenizer.decode(output) + fabric.print(output) + + model.train() + return val_loss + + +def get_batch( + fabric: L.Fabric, + data: List[Dict], + micro_batch_size: int, + max_seq_length: Optional[int], + longest_seq_ix: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + ix = torch.randint(len(data), (micro_batch_size,)) + if longest_seq_ix is not None: + # force the longest sample at the beginning so potential OOMs happen right away + ix[0] = longest_seq_ix + + input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix] + labels = [data[i]["labels"].type(torch.int64) for i in ix] + + # this could be `longest_seq_length` to have a fixed size for all batches + max_len = max(len(s) for s in input_ids) + + def pad_right(x, pad_id): + # pad right based on the longest sequence + n = max_len - len(x) + return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype))) + + x = torch.stack([pad_right(x, pad_id=0) for x in input_ids]) + y = torch.stack([pad_right(x, pad_id=-1) for x in labels]) + + # Truncate if needed + if max_seq_length: + x = x[:, :max_seq_length] + y = y[:, :max_seq_length] + + if fabric.device.type == "cuda" and x.device.type == "cpu": + x, y = fabric.to_device((x.pin_memory(), y.pin_memory())) + else: + x, y = fabric.to_device((x, y)) + return x, y + + +def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): + # linear warmup followed by cosine annealing + scheduler1 = torch.optim.lr_scheduler.LambdaLR( + optimizer, lambda step: step / warmup_steps + ) + scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=(max_steps - warmup_steps) + ) + return torch.optim.lr_scheduler.SequentialLR( + optimizer, [scheduler1, scheduler2], milestones=[warmup_steps] + ) + + +def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: + # find out the minimum max_seq_length required during fine-tuning (saves memory!) + lengths = [len(d["input_ids"]) for d in data] + longest_seq_length = max(lengths) + longest_seq_ix = lengths.index(longest_seq_length) + return longest_seq_length, longest_seq_ix + + +def save_adapter_v2_checkpoint( + fabric: L.Fabric, model: torch.nn.Module, file_path: Path +) -> None: + fabric.print(f"Saving adapter v2 weights to {str(file_path)!r}") + fabric.save(file_path, {"model": model}, filter={"model": adapter_filter}) + + +def validate_args(io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None: + issues = [] + unsupported = [(train, ["max_tokens", "max_norm"])] + for args, names in unsupported: + for name in names: + if getattr(args, name) is not None: + issues.append( + f"{__file__} doesn't support the {name!r} argument. This is set in {args}" + ) + required = [ + (io, ["checkpoint_dir", "train_data_dir", "val_data_dir"]), + (train, ["epoch_size", "epochs"]), + (eval, ["max_new_tokens"]), + ] + for args, names in required: + for name in names: + if getattr(args, name) is None: + issues.append( + f"{__file__} requires the {name!r} argument. This is set in {args}" + ) + if issues: + raise ValueError("\n".join(issues)) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + CLI(setup) diff --git a/examples/llm_finetuning/finetune/full.py b/examples/llm_finetuning/finetune/full.py new file mode 100644 index 00000000000..02e28a72af3 --- /dev/null +++ b/examples/llm_finetuning/finetune/full.py @@ -0,0 +1,442 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. +import dataclasses +import math +import os +import sys +import time +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import lightning as L +import torch +from lightning.fabric.loggers import CSVLogger +from lightning.fabric.strategies import FSDPStrategy +from torchmetrics import RunningMean + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from generate.base import generate +from lit_gpt.args import EvalArgs, IOArgs, TrainArgs +from lit_gpt.model import GPT, Block, Config +from lit_gpt.tokenizer import Tokenizer +from lit_gpt.utils import ( + CLI, + check_valid_checkpoint_dir, + chunked_cross_entropy, + get_default_supported_precision, + load_checkpoint, + num_parameters, +) + +from scripts.prepare_alpaca import generate_prompt + + +def setup( + precision: Optional[str] = None, + devices: int = 1, + resume: Union[bool, Path] = False, + io: IOArgs = IOArgs( + train_data_dir=Path("data/alpaca"), + val_data_dir=Path("data/alpaca"), + checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), + out_dir=Path("out/full/alpaca"), + ), + train: TrainArgs = TrainArgs( + save_interval=1000, + log_interval=1, + global_batch_size=64, + micro_batch_size=1, + lr_warmup_steps=100, + epochs=5, + epoch_size=50000, + learning_rate=3e-3, + max_seq_length=None, + ), + eval: EvalArgs = EvalArgs(interval=600, max_new_tokens=100, max_iters=100), +) -> None: + print(locals()) + precision = precision or get_default_supported_precision(training=True) + + if devices > 1: + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy={Block}, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + logger = CSVLogger( + io.out_dir.parent, + io.out_dir.name, + flush_logs_every_n_steps=train.log_interval, + ) + fabric = L.Fabric( + devices=devices, strategy=strategy, precision=precision, loggers=logger + ) + fabric.launch( + main, + devices, + resume, + Config.from_name(name=io.checkpoint_dir.name), + io, + train, + eval, + ) + + +def main( + fabric: L.Fabric, + devices: int, + resume: Union[bool, Path], + config: Config, + io: IOArgs, + train: TrainArgs, + eval: EvalArgs, +) -> None: + validate_args(io, train, eval) + + steps_per_epoch = train.epoch_size // devices // train.batch_size(devices) + lr_max_steps = train.epochs * steps_per_epoch + + check_valid_checkpoint_dir(io.checkpoint_dir) + + fabric.seed_everything( + 1337 + ) # same seed for every process to init model (FSDP) + + if fabric.global_rank == 0: + os.makedirs(io.out_dir, exist_ok=True) + + train_data = torch.load(io.train_data_dir / "train.pt") + val_data = torch.load(io.val_data_dir / "test.pt") + + checkpoint_path = io.checkpoint_dir / "lit_model.pth" + fabric.print( + f"Loading model {str(checkpoint_path)!r} with {config.__dict__}" + ) + with fabric.init_module(empty_init=(devices > 1)): + model = GPT(config) + + fabric.print( + f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}" + ) + + model = fabric.setup(model) + optimizer = torch.optim.AdamW( + model.parameters(), + lr=train.learning_rate, + weight_decay=train.weight_decay, + betas=(train.beta1, train.beta2), + ) + optimizer = fabric.setup_optimizers(optimizer) + scheduler = get_lr_scheduler( + optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps + ) + state = { + "model": model, + "optimizer": optimizer, + "scheduler": scheduler, + "iter_num": 0, + "step_count": 0, + } + + if resume is True: + resume = max( + io.out_dir.glob("*.pth"), key=(lambda p: int(p.name.split("-")[1])) + ) + if resume: + fabric.print(f"Resuming training from {resume}") + fabric.load(resume, state) + else: + load_checkpoint(fabric, state["model"], checkpoint_path) + + fabric.seed_everything(1337 + fabric.global_rank) + + train_time = time.perf_counter() + fit(fabric, state, train_data, val_data, devices, resume, io, train, eval) + fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") + if fabric.device.type == "cuda": + fabric.print( + f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB" + ) + + # Save the final checkpoint at the end of training + fabric.save( + io.out_dir / "lit_model_finetuned.pth", {"model": state["model"]} + ) + + +def fit( + fabric: L.Fabric, + state: Dict, + train_data: List[Dict], + val_data: List[Dict], + devices: int, + resume: Union[bool, Path], + io: IOArgs, + train: TrainArgs, + eval: EvalArgs, +) -> None: + model = state["model"] + optimizer = state["optimizer"] + scheduler = state["scheduler"] + tokenizer = Tokenizer(io.checkpoint_dir) + longest_seq_length, longest_seq_ix = get_longest_seq_length(train_data) + model.max_seq_length = min( + longest_seq_length, train.max_seq_length or float("inf") + ) + fabric.print( + f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" + f" {model.max_seq_length} and context length is {model.config.block_size}" + ) + + validate( + fabric, + model, + val_data, + tokenizer, + dataclasses.replace(eval, max_iters=2), + train, + ) # sanity check + initial_iter = state["iter_num"] + + # resume data loader state by fast-forwarding through all seen batches + if resume: + resume_t0 = time.perf_counter() + for resume_iter in range(initial_iter): + get_batch(fabric, train_data, None) + if resume_iter % 1000 == 0: + fabric.print( + f"Resuming dataset: {resume_iter} / {initial_iter}" + ) + fabric.barrier() + fabric.print( + f"Resuming data loader finished. Took {time.perf_counter() - resume_t0:.1f} seconds to reach iteration" + f" {initial_iter}." + ) + + running_loss = RunningMean( + window=train.gradient_accumulation_iters(devices), + sync_on_compute=False, + ).to(fabric.device) + fabric.barrier() + + for state["iter_num"] in range( + state["iter_num"] + 1, train.max_iters(devices) + 1 + ): + iter_t0 = time.perf_counter() + + input_ids, targets = get_batch( + fabric, + train_data, + train.micro_batch_size, + train.max_seq_length, + longest_seq_ix if state["iter_num"] == 1 else None, + ) + + is_accumulating = ( + state["iter_num"] % train.gradient_accumulation_iters(devices) != 0 + ) + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids) + # shift the targets such that output n predicts token n+1 + loss = chunked_cross_entropy(logits[..., :-1, :], targets[..., 1:]) + fabric.backward(loss / train.gradient_accumulation_iters(devices)) + + running_loss.update(loss.detach()) + + if not is_accumulating: + optimizer.step() + optimizer.zero_grad() + scheduler.step() + state["step_count"] += 1 + + if state["iter_num"] % train.log_interval == 0: + loss = ( + running_loss.compute().item() + ) # expensive device-to-host synchronization + t1 = time.perf_counter() + metrics = { + "loss": loss, + "iter": state["iter_num"], + "step": state["step_count"], + "iter_time": t1 - iter_t0, + "tokens": state["iter_num"] + * train.micro_batch_size + * model.config.block_size, + "total_tokens": ( + state["iter_num"] + * train.micro_batch_size + * model.config.block_size + * fabric.world_size + ), + # TODO: log learning rate + } + fabric.print( + f"iter {metrics['iter']} | step {metrics['step']}: loss {metrics['loss']:.4f}, iter time:" + f" {metrics['iter_time'] * 1000:.2f} ms{' (optimizer.step)' if not is_accumulating else ''}" + ) + fabric.log_dict(metrics, step=state["iter_num"]) + + if not is_accumulating and state["step_count"] % eval.interval == 0: + t0 = time.perf_counter() + val_loss = validate( + fabric, model, val_data, tokenizer, eval, train + ) + t1 = time.perf_counter() - t0 + fabric.print( + f"iter {state['iter_num']}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms" + ) + metrics = {"val_loss": val_loss, "val_ppl": math.exp(val_loss)} + fabric.log_dict(metrics, step=state["iter_num"]) + fabric.barrier() + if ( + not is_accumulating + and state["step_count"] % train.save_interval == 0 + ): + checkpoint_path = ( + io.out_dir / f"step-{state['step_count']:06d}.pth" + ) + fabric.print(f"Saving checkpoint to {str(checkpoint_path)!r}") + fabric.save(checkpoint_path, state) + + +# FSDP has issues with `inference_mode` +@torch.no_grad() +def validate( + fabric: L.Fabric, + model: GPT, + val_data: List[Dict], + tokenizer: Tokenizer, + eval: EvalArgs, + train: TrainArgs, +) -> torch.Tensor: + fabric.print("Validating ...") + model.eval() + losses = torch.zeros(eval.max_iters) + for k in range(eval.max_iters): + input_ids, targets = get_batch( + fabric, val_data, train.micro_batch_size, train.max_seq_length + ) + logits = model(input_ids) + losses[k] = chunked_cross_entropy( + logits[..., :-1, :], targets[..., 1:], chunk_size=0 + ) + val_loss = losses.mean() + + # produce an example: + instruction = "Recommend a movie for me to watch during the weekend and explain the reason." + fabric.print(instruction) + sample = {"instruction": instruction, "input": ""} + prompt = generate_prompt(sample) + encoded = tokenizer.encode(prompt, device=fabric.device) + with fabric.init_tensor(): + # do not set `max_seq_length=max_returned_token` because memory is not a concern here + model.set_kv_cache(batch_size=1) + output = generate( + model, + encoded, + max_returned_tokens=len(encoded) + eval.max_new_tokens, + temperature=0.8, + eos_id=tokenizer.eos_id, + ) + model.clear_kv_cache() + output = tokenizer.decode(output) + fabric.print(output) + + model.train() + return val_loss + + +def get_batch( + fabric: L.Fabric, + data: List[Dict], + micro_batch_size: int, + max_seq_length: Optional[int], + longest_seq_ix: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + ix = torch.randint(len(data), (micro_batch_size,)) + if longest_seq_ix is not None: + # force the longest sample at the beginning so potential OOMs happen right away + ix[0] = longest_seq_ix + + input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix] + labels = [data[i]["labels"].type(torch.int64) for i in ix] + + # this could be `longest_seq_length` to have a fixed size for all batches + max_len = max(len(s) for s in input_ids) + + def pad_right(x, pad_id): + # pad right based on the longest sequence + n = max_len - len(x) + return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype))) + + x = torch.stack([pad_right(x, pad_id=0) for x in input_ids]) + y = torch.stack([pad_right(x, pad_id=-1) for x in labels]) + + # Truncate if needed + if max_seq_length: + x = x[:, :max_seq_length] + y = y[:, :max_seq_length] + + if fabric.device.type == "cuda" and x.device.type == "cpu": + x, y = fabric.to_device((x.pin_memory(), y.pin_memory())) + else: + x, y = fabric.to_device((x, y)) + return x, y + + +def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): + # linear warmup followed by cosine annealing + scheduler1 = torch.optim.lr_scheduler.LambdaLR( + optimizer, lambda step: step / warmup_steps + ) + scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=(max_steps - warmup_steps) + ) + return torch.optim.lr_scheduler.SequentialLR( + optimizer, [scheduler1, scheduler2], milestones=[warmup_steps] + ) + + +def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: + # find out the minimum max_seq_length required during fine-tuning (saves memory!) + lengths = [len(d["input_ids"]) for d in data] + longest_seq_length = max(lengths) + longest_seq_ix = lengths.index(longest_seq_length) + return longest_seq_length, longest_seq_ix + + +def validate_args(io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None: + issues = [] + unsupported = [(train, ["max_tokens", "max_norm"])] + for args, names in unsupported: + for name in names: + if getattr(args, name) is not None: + issues.append( + f"{__file__} doesn't support the {name!r} argument. This is set in {args}" + ) + required = [ + (io, ["checkpoint_dir", "train_data_dir", "val_data_dir"]), + (train, ["epoch_size", "epochs"]), + (eval, ["max_new_tokens"]), + ] + for args, names in required: + for name in names: + if getattr(args, name) is None: + issues.append( + f"{__file__} requires the {name!r} argument. This is set in {args}" + ) + if issues: + raise ValueError("\n".join(issues)) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + CLI(setup) diff --git a/examples/llm_finetuning/finetune/lora.py b/examples/llm_finetuning/finetune/lora.py new file mode 100644 index 00000000000..39caa06eeb7 --- /dev/null +++ b/examples/llm_finetuning/finetune/lora.py @@ -0,0 +1,483 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. +import dataclasses +import os +import sys +import time +from pathlib import Path +from typing import Dict, List, Literal, Optional, Tuple + +import lightning as L +import torch +from lightning.fabric.loggers import CSVLogger +from lightning.fabric.plugins import BitsandbytesPrecision +from lightning.fabric.strategies import FSDPStrategy +from lightning.fabric.utilities import ThroughputMonitor + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from generate.base import generate +from lit_gpt.args import EvalArgs, IOArgs, TrainArgs +from lit_gpt.lora import ( + GPT, + Block, + Config, + lora_filter, + mark_only_lora_as_trainable, +) +from lit_gpt.tokenizer import Tokenizer +from lit_gpt.utils import ( + CLI, + check_valid_checkpoint_dir, + chunked_cross_entropy, + get_default_supported_precision, + load_checkpoint, + num_parameters, +) + +from scripts.prepare_alpaca import generate_prompt + + +def setup( + precision: Optional[str] = None, + quantize: Optional[ + Literal[ + "bnb.nf4", + "bnb.nf4-dq", + "bnb.fp4", + "bnb.fp4-dq", + "bnb.int8-training", + ] + ] = None, + devices: int = 1, + lora_r: int = 8, + lora_alpha: int = 16, + lora_dropout: float = 0.05, + lora_query: bool = True, + lora_key: bool = False, + lora_value: bool = True, + lora_projection: bool = False, + lora_mlp: bool = False, + lora_head: bool = False, + io: IOArgs = IOArgs( + train_data_dir=Path("data/alpaca"), + val_data_dir=Path("data/alpaca"), + checkpoint_dir=Path("checkpoints/stabilityai/stablelm-base-alpha-3b"), + out_dir=Path("out/lora/alpaca"), + ), + train: TrainArgs = TrainArgs( + save_interval=1000, + log_interval=1, + global_batch_size=128, + micro_batch_size=4, + lr_warmup_steps=100, + epochs=5, + epoch_size=50000, + learning_rate=3e-4, + max_seq_length=None, + ), + eval: EvalArgs = EvalArgs(interval=100, max_new_tokens=100, max_iters=100), +) -> None: + print(locals()) + precision = precision or get_default_supported_precision(training=True) + + plugins = None + if quantize is not None and quantize.startswith("bnb."): + if "mixed" in precision: + raise ValueError( + "Quantization and mixed precision is not supported." + ) + dtype = { + "16-true": torch.float16, + "bf16-true": torch.bfloat16, + "32-true": torch.float32, + }[precision] + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + + if devices > 1: + if quantize: + raise NotImplementedError( + "Quantization is currently not supported for multi-GPU training. Please set devices=1 when using the" + " --quantize flag." + ) + strategy = FSDPStrategy( + auto_wrap_policy={Block}, + activation_checkpointing_policy={Block}, + state_dict_type="full", + limit_all_gathers=True, + cpu_offload=False, + ) + else: + strategy = "auto" + + logger = CSVLogger( + io.out_dir.parent, + io.out_dir.name, + flush_logs_every_n_steps=train.log_interval, + ) + fabric = L.Fabric( + devices=devices, + strategy=strategy, + precision=precision, + loggers=logger, + plugins=plugins, + ) + + if not any( + ( + lora_query, + lora_key, + lora_value, + lora_projection, + lora_mlp, + lora_head, + ) + ): + fabric.print("Warning: all LoRA layers are disabled!") + fabric.launch( + main, + devices, + Config.from_name( + name=io.checkpoint_dir.name, + r=lora_r, + alpha=lora_alpha, + dropout=lora_dropout, + to_query=lora_query, + to_key=lora_key, + to_value=lora_value, + to_projection=lora_projection, + to_mlp=lora_mlp, + to_head=lora_head, + ), + io, + train, + eval, + ) + + +def main( + fabric: L.Fabric, + devices: int, + config: Config, + io: IOArgs, + train: TrainArgs, + eval: EvalArgs, +) -> None: + validate_args(io, train, eval) + + steps_per_epoch = train.epoch_size // devices // train.batch_size(devices) + lr_max_steps = train.epochs * steps_per_epoch + + check_valid_checkpoint_dir(io.checkpoint_dir) + + fabric.seed_everything( + 1337 + ) # same seed for every process to init model (FSDP) + + if fabric.global_rank == 0: + os.makedirs(io.out_dir, exist_ok=True) + + train_data = torch.load(io.train_data_dir / "train.pt") + val_data = torch.load(io.val_data_dir / "test.pt") + + checkpoint_path = io.checkpoint_dir / "lit_model.pth" + fabric.print( + f"Loading model {str(checkpoint_path)!r} with {config.__dict__}" + ) + with fabric.init_module(empty_init=(devices > 1)): + model = GPT(config) + mark_only_lora_as_trainable(model) + + fabric.print( + f"Number of trainable parameters: {num_parameters(model, requires_grad=True):,}" + ) + fabric.print( + f"Number of non trainable parameters: {num_parameters(model, requires_grad=False):,}" + ) + + model = fabric.setup_module(model) + + trainable_params = [p for p in model.parameters() if p.requires_grad] + if isinstance(fabric.strategy.precision, BitsandbytesPrecision): + import bitsandbytes as bnb + + optimizer_cls = bnb.optim.PagedAdamW + else: + optimizer_cls = torch.optim.AdamW + optimizer = optimizer_cls( + trainable_params, + lr=train.learning_rate, + weight_decay=train.weight_decay, + betas=(train.beta1, train.beta2), + ) + optimizer = fabric.setup_optimizers(optimizer) + scheduler = get_lr_scheduler( + optimizer, warmup_steps=train.lr_warmup_steps, max_steps=lr_max_steps + ) + + # strict=False because missing keys due to LoRA weights not contained in state dict + load_checkpoint(fabric, model, checkpoint_path, strict=False) + + fabric.seed_everything(1337 + fabric.global_rank) + + train_time = time.perf_counter() + fit( + fabric, + model, + optimizer, + scheduler, + train_data, + val_data, + devices, + io, + train, + eval, + ) + fabric.print(f"Training time: {(time.perf_counter()-train_time):.2f}s") + if fabric.device.type == "cuda": + fabric.print( + f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB" + ) + + # Save the final LoRA checkpoint at the end of training + save_path = io.out_dir / "lit_model_lora_finetuned.pth" + save_lora_checkpoint(fabric, model, save_path) + + +def fit( + fabric: L.Fabric, + model: GPT, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler, + train_data: List[Dict], + val_data: List[Dict], + devices: int, + io: IOArgs, + train: TrainArgs, + eval: EvalArgs, +) -> None: + tokenizer = Tokenizer(io.checkpoint_dir) + longest_seq_length, longest_seq_ix = get_longest_seq_length(train_data) + model.max_seq_length = min( + longest_seq_length, train.max_seq_length or float("inf") + ) + fabric.print( + f"The longest sequence length in the train data is {longest_seq_length}, the model's maximum sequence length is" + f" {model.max_seq_length} and context length is {model.config.block_size}" + ) + + validate( + fabric, + model, + val_data, + tokenizer, + dataclasses.replace(eval, max_iters=2), + train, + ) # sanity check + + throughput = ThroughputMonitor(fabric, window_size=50) + step_count = 0 + total_lengths = 0 + total_t0 = time.perf_counter() + + for iter_num in range(1, train.max_iters(devices) + 1): + iter_t0 = time.perf_counter() + + input_ids, targets = get_batch( + fabric, + train_data, + train.micro_batch_size, + train.max_seq_length, + longest_seq_ix if iter_num == 1 else None, + ) + + is_accumulating = ( + iter_num % train.gradient_accumulation_iters(devices) != 0 + ) + with fabric.no_backward_sync(model, enabled=is_accumulating): + logits = model(input_ids, lm_head_chunk_size=128) + # shift the targets such that output n predicts token n+1 + logits[-1] = logits[-1][..., :-1, :] + loss = chunked_cross_entropy(logits, targets[..., 1:]) + fabric.backward(loss / train.gradient_accumulation_iters(devices)) + + if not is_accumulating: + optimizer.step() + optimizer.zero_grad() + scheduler.step() + step_count += 1 + + total_lengths += input_ids.numel() + if iter_num % train.log_interval == 0: + loss_item = loss.item() # expensive device-to-host synchronization + t1 = time.perf_counter() + throughput.update( + time=t1 - total_t0, + batches=iter_num, + samples=iter_num * train.micro_batch_size, + lengths=total_lengths, + ) + throughput.compute_and_log(step=iter_num) + fabric.print( + f"iter {iter_num} | step {step_count}: loss {loss_item:.4f}, iter time:" + f" {(t1 - iter_t0) * 1000:.2f} ms{' (optimizer.step)' if not is_accumulating else ''}" + ) + + if not is_accumulating and step_count % eval.interval == 0: + t0 = time.perf_counter() + val_loss = validate( + fabric, model, val_data, tokenizer, eval, train + ) + t1 = time.perf_counter() - t0 + fabric.print( + f"iter {iter_num}: val loss {val_loss.item():.4f}, val time: {t1 * 1000:.2f} ms" + ) + fabric.barrier() + if not is_accumulating and step_count % train.save_interval == 0: + checkpoint_path = io.out_dir / f"iter-{iter_num:06d}-ckpt.pth" + save_lora_checkpoint(fabric, model, checkpoint_path) + + +# FSDP has issues with `inference_mode` +@torch.no_grad() +def validate( + fabric: L.Fabric, + model: GPT, + val_data: List[Dict], + tokenizer: Tokenizer, + eval: EvalArgs, + train: TrainArgs, +) -> torch.Tensor: + fabric.print("Validating ...") + model.eval() + losses = torch.zeros(eval.max_iters) + for k in range(eval.max_iters): + input_ids, targets = get_batch( + fabric, val_data, train.micro_batch_size, train.max_seq_length + ) + logits = model(input_ids) + losses[k] = chunked_cross_entropy( + logits[..., :-1, :], targets[..., 1:], chunk_size=0 + ) + val_loss = losses.mean() + + # produce an example: + instruction = "Recommend a movie for me to watch during the weekend and explain the reason." + fabric.print(instruction) + sample = {"instruction": instruction, "input": ""} + prompt = generate_prompt(sample) + encoded = tokenizer.encode(prompt, device=fabric.device) + with fabric.init_tensor(): + # do not set `max_seq_length=max_returned_token` because memory is not a concern here + model.set_kv_cache(batch_size=1) + output = generate( + model, + encoded, + max_returned_tokens=len(encoded) + eval.max_new_tokens, + temperature=0.8, + eos_id=tokenizer.eos_id, + ) + model.clear_kv_cache() + output = tokenizer.decode(output) + fabric.print(output) + + model.train() + return val_loss + + +def get_batch( + fabric: L.Fabric, + data: List[Dict], + micro_batch_size: int, + max_seq_length: Optional[int], + longest_seq_ix: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + ix = torch.randint(len(data), (micro_batch_size,)) + if longest_seq_ix is not None: + # force the longest sample at the beginning so potential OOMs happen right away + ix[0] = longest_seq_ix + + input_ids = [data[i]["input_ids"].type(torch.int64) for i in ix] + labels = [data[i]["labels"].type(torch.int64) for i in ix] + + # this could be `longest_seq_length` to have a fixed size for all batches + max_len = max(len(s) for s in input_ids) + + def pad_right(x, pad_id): + # pad right based on the longest sequence + n = max_len - len(x) + return torch.cat((x, torch.full((n,), pad_id, dtype=x.dtype))) + + x = torch.stack([pad_right(x, pad_id=0) for x in input_ids]) + y = torch.stack([pad_right(x, pad_id=-1) for x in labels]) + + # Truncate if needed + if max_seq_length: + x = x[:, :max_seq_length] + y = y[:, :max_seq_length] + + if fabric.device.type == "cuda" and x.device.type == "cpu": + x, y = fabric.to_device((x.pin_memory(), y.pin_memory())) + else: + x, y = fabric.to_device((x, y)) + return x, y + + +def get_lr_scheduler(optimizer, warmup_steps: int, max_steps: int): + # linear warmup followed by cosine annealing + scheduler1 = torch.optim.lr_scheduler.LambdaLR( + optimizer, lambda step: step / warmup_steps + ) + scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, T_max=(max_steps - warmup_steps) + ) + return torch.optim.lr_scheduler.SequentialLR( + optimizer, [scheduler1, scheduler2], milestones=[warmup_steps] + ) + + +def get_longest_seq_length(data: List[Dict]) -> Tuple[int, int]: + # find out the minimum max_seq_length required during fine-tuning (saves memory!) + lengths = [len(d["input_ids"]) for d in data] + longest_seq_length = max(lengths) + longest_seq_ix = lengths.index(longest_seq_length) + return longest_seq_length, longest_seq_ix + + +def save_lora_checkpoint( + fabric: L.Fabric, model: torch.nn.Module, file_path: Path +) -> None: + fabric.print(f"Saving LoRA weights to {str(file_path)!r}") + fabric.save(file_path, {"model": model}, filter={"model": lora_filter}) + + +def validate_args(io: IOArgs, train: TrainArgs, eval: EvalArgs) -> None: + issues = [] + unsupported = [(train, ["max_tokens", "max_norm"])] + for args, names in unsupported: + for name in names: + if getattr(args, name) is not None: + issues.append( + f"{__file__} doesn't support the {name!r} argument. This is set in {args}" + ) + required = [ + (io, ["checkpoint_dir", "train_data_dir", "val_data_dir"]), + (train, ["epoch_size", "epochs"]), + (eval, ["max_new_tokens"]), + ] + for args, names in required: + for name in names: + if getattr(args, name) is None: + issues.append( + f"{__file__} requires the {name!r} argument. This is set in {args}" + ) + if issues: + raise ValueError("\n".join(issues)) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + CLI(setup) diff --git a/examples/llm_finetuning/generate/adapter.py b/examples/llm_finetuning/generate/adapter.py new file mode 100644 index 00000000000..3daa88362b6 --- /dev/null +++ b/examples/llm_finetuning/generate/adapter.py @@ -0,0 +1,159 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import sys +import time +from pathlib import Path +from typing import Literal, Optional + +import lightning as L +import torch +from lightning.fabric.plugins import BitsandbytesPrecision + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from generate.base import generate +from lit_gpt import Tokenizer +from lit_gpt.adapter import GPT, Config +from lit_gpt.utils import ( + CLI, + check_valid_checkpoint_dir, + get_default_supported_precision, + lazy_load, +) + +from scripts.prepare_alpaca import generate_prompt + + +def main( + prompt: str = "What food do llamas eat?", + input: str = "", + adapter_path: Path = Path( + "out/adapter/alpaca/lit_model_adapter_finetuned.pth" + ), + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + quantize: Optional[ + Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"] + ] = None, + max_new_tokens: int = 100, + top_k: Optional[int] = 200, + temperature: float = 0.8, + precision: Optional[str] = None, +) -> None: + """Generates a response based on a given instruction and an optional input. + This script will only work with checkpoints from the instruction-tuned GPT-Adapter model. + See `finetune/adapter.py`. + + Args: + prompt: The prompt/instruction (Alpaca style). + input: Optional input (Alpaca style). + adapter_path: Path to the checkpoint with trained adapter weights, which are the output of + `finetune/adapter.py`. + checkpoint_dir: The path to the checkpoint folder with pretrained GPT weights. + quantize: Whether to quantize the model and using which method: + - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes + - bnb.int8: 8-bit quantization from bitsandbytes + for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md + max_new_tokens: The number of generation steps to take. + top_k: The number of top most probable tokens to consider in the sampling process. + temperature: A value controlling the randomness of the sampling process. Higher values result in more random + samples. + precision: Indicates the Fabric precision setting to use. + """ + precision = precision or get_default_supported_precision(training=False) + + plugins = None + if quantize is not None and quantize.startswith("bnb."): + if "mixed" in precision: + raise ValueError( + "Quantization and mixed precision is not supported." + ) + dtype = { + "16-true": torch.float16, + "bf16-true": torch.bfloat16, + "32-true": torch.float32, + }[precision] + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + + fabric = L.Fabric(devices=1, precision=precision, plugins=plugins) + fabric.launch() + + check_valid_checkpoint_dir(checkpoint_dir) + + config = Config.from_json(checkpoint_dir / "lit_config.json") + + checkpoint_path = checkpoint_dir / "lit_model.pth" + + tokenizer = Tokenizer(checkpoint_dir) + sample = {"instruction": prompt, "input": input} + prompt = generate_prompt(sample) + encoded = tokenizer.encode(prompt, device=fabric.device) + prompt_length = encoded.size(0) + max_returned_tokens = prompt_length + max_new_tokens + + fabric.print( + f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", + file=sys.stderr, + ) + t0 = time.perf_counter() + with fabric.init_module(empty_init=True): + model = GPT(config) + fabric.print( + f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", + file=sys.stderr, + ) + with fabric.init_tensor(): + # set the max_seq_length to limit the memory usage to what we need + model.max_seq_length = max_returned_tokens + # enable the kv cache + model.set_kv_cache(batch_size=1) + model.eval() + + t0 = time.perf_counter() + checkpoint = lazy_load(checkpoint_path) + adapter_checkpoint = lazy_load(adapter_path) + checkpoint.update(adapter_checkpoint.get("model", adapter_checkpoint)) + model.load_state_dict(checkpoint) + fabric.print( + f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", + file=sys.stderr, + ) + + model = fabric.setup(model) + + L.seed_everything(1234) + t0 = time.perf_counter() + y = generate( + model, + encoded, + max_returned_tokens, + temperature=temperature, + top_k=top_k, + eos_id=tokenizer.eos_id, + ) + t = time.perf_counter() - t0 + + output = tokenizer.decode(y) + output = output.split("### Response:")[1].strip() + fabric.print(output) + + tokens_generated = y.size(0) - prompt_length + fabric.print( + f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", + file=sys.stderr, + ) + if fabric.device.type == "cuda": + fabric.print( + f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", + file=sys.stderr, + ) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + CLI(main) diff --git a/examples/llm_finetuning/generate/adapter_v2.py b/examples/llm_finetuning/generate/adapter_v2.py new file mode 100644 index 00000000000..6f9d76d4c72 --- /dev/null +++ b/examples/llm_finetuning/generate/adapter_v2.py @@ -0,0 +1,159 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import sys +import time +from pathlib import Path +from typing import Literal, Optional + +import lightning as L +import torch +from lightning.fabric.plugins import BitsandbytesPrecision + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from generate.base import generate +from lit_gpt import Tokenizer +from lit_gpt.adapter_v2 import GPT, Config +from lit_gpt.utils import ( + CLI, + check_valid_checkpoint_dir, + get_default_supported_precision, + lazy_load, +) + +from scripts.prepare_alpaca import generate_prompt + + +def main( + prompt: str = "What food do llamas eat?", + input: str = "", + adapter_path: Path = Path( + "out/adapter_v2/alpaca/lit_model_adapter_finetuned.pth" + ), + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + quantize: Optional[ + Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"] + ] = None, + max_new_tokens: int = 100, + top_k: Optional[int] = 200, + temperature: float = 0.8, + precision: Optional[str] = None, +) -> None: + """Generates a response based on a given instruction and an optional input. + This script will only work with checkpoints from the instruction-tuned GPT-AdapterV2 model. + See `finetune/adapter_v2.py`. + + Args: + prompt: The prompt/instruction (Alpaca style). + input: Optional input (Alpaca style). + adapter_path: Path to the checkpoint with trained adapter weights, which are the output of + `finetune/adapter_v2.py`. + checkpoint_dir: The path to the checkpoint folder with pretrained GPT weights. + quantize: Whether to quantize the model and using which method: + - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes + - bnb.int8: 8-bit quantization from bitsandbytes + for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md + max_new_tokens: The number of generation steps to take. + top_k: The number of top most probable tokens to consider in the sampling process. + temperature: A value controlling the randomness of the sampling process. Higher values result in more random + samples. + precision: Indicates the Fabric precision setting to use. + """ + precision = precision or get_default_supported_precision(training=False) + + plugins = None + if quantize is not None and quantize.startswith("bnb."): + if "mixed" in precision: + raise ValueError( + "Quantization and mixed precision is not supported." + ) + dtype = { + "16-true": torch.float16, + "bf16-true": torch.bfloat16, + "32-true": torch.float32, + }[precision] + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + + fabric = L.Fabric(devices=1, precision=precision, plugins=plugins) + fabric.launch() + + check_valid_checkpoint_dir(checkpoint_dir) + + config = Config.from_json(checkpoint_dir / "lit_config.json") + + checkpoint_path = checkpoint_dir / "lit_model.pth" + + tokenizer = Tokenizer(checkpoint_dir) + sample = {"instruction": prompt, "input": input} + prompt = generate_prompt(sample) + encoded = tokenizer.encode(prompt, device=fabric.device) + prompt_length = encoded.size(0) + max_returned_tokens = prompt_length + max_new_tokens + + fabric.print( + f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", + file=sys.stderr, + ) + t0 = time.perf_counter() + with fabric.init_module(empty_init=True): + model = GPT(config) + fabric.print( + f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", + file=sys.stderr, + ) + with fabric.init_tensor(): + # set the max_seq_length to limit the memory usage to what we need + model.max_seq_length = max_returned_tokens + # enable the kv cache + model.set_kv_cache(batch_size=1) + model.eval() + + t0 = time.perf_counter() + checkpoint = lazy_load(checkpoint_path) + adapter_checkpoint = lazy_load(adapter_path) + checkpoint.update(adapter_checkpoint.get("model", adapter_checkpoint)) + model.load_state_dict(checkpoint) + fabric.print( + f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", + file=sys.stderr, + ) + + model = fabric.setup(model) + + L.seed_everything(1234) + t0 = time.perf_counter() + y = generate( + model, + encoded, + max_returned_tokens, + temperature=temperature, + top_k=top_k, + eos_id=tokenizer.eos_id, + ) + t = time.perf_counter() - t0 + + output = tokenizer.decode(y) + output = output.split("### Response:")[1].strip() + fabric.print(output) + + tokens_generated = y.size(0) - prompt_length + fabric.print( + f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", + file=sys.stderr, + ) + if fabric.device.type == "cuda": + fabric.print( + f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", + file=sys.stderr, + ) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + CLI(main) diff --git a/examples/llm_finetuning/generate/base.py b/examples/llm_finetuning/generate/base.py new file mode 100644 index 00000000000..f8cfa7bd54c --- /dev/null +++ b/examples/llm_finetuning/generate/base.py @@ -0,0 +1,244 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import sys +import time +from pathlib import Path +from typing import Any, Literal, Optional + +import lightning as L +import torch +import torch._dynamo.config +import torch._inductor.config +from lightning.fabric.plugins import BitsandbytesPrecision + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt import GPT, Config, Tokenizer +from lit_gpt.utils import ( + CLI, + check_valid_checkpoint_dir, + get_default_supported_precision, + load_checkpoint, +) + + +def multinomial_num_samples_1(probs: torch.Tensor) -> torch.Tensor: + if torch._dynamo.is_compiling(): + # Faster alternative to `torch.multinomial(probs, num_samples=1)` that is also CUDAGraph friendly + distribution = torch.empty_like(probs).exponential_(1) + return torch.argmax(probs / distribution, dim=-1, keepdim=True) + return torch.multinomial(probs, num_samples=1) + + +def sample( + logits: torch.Tensor, temperature: float = 1.0, top_k: Optional[int] = None +) -> torch.Tensor: + logits = logits[0, -1] + # optionally crop the logits to only the top k options + if top_k is not None: + v, i = torch.topk(logits, min(top_k, logits.size(-1))) + # do not use `torch.where` as in nanogpt because it will repeat top-k collisions + logits = torch.full_like(logits, float("-inf")).scatter_(-1, i, v) + # optionally scale the logits and sample from a probability distribution + if temperature > 0.0: + probs = torch.nn.functional.softmax(logits / temperature, dim=-1) + return multinomial_num_samples_1(probs) + return torch.argmax(logits, dim=-1, keepdim=True) + + +def next_token( + model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any +) -> torch.Tensor: + logits = model(x, input_pos) + next = sample(logits, **kwargs) + return next.to(dtype=x.dtype) + + +@torch.inference_mode() +def generate( + model: GPT, + prompt: torch.Tensor, + max_returned_tokens: int, + *, + temperature: float = 1.0, + top_k: Optional[int] = None, + eos_id: Optional[int] = None, +) -> torch.Tensor: + """Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested. + + The implementation of this function is modified from A. Karpathy's nanoGPT. + + Args: + model: The model to use. + prompt: Tensor of shape (T) with indices of the prompt sequence. + max_returned_tokens: The maximum number of tokens to return (given plus generated). + temperature: Scales the predicted logits by 1 / temperature. + top_k: If specified, only sample among the tokens with the k highest probabilities. + eos_id: If specified, stop generating any more token once the token is triggered. + """ + T = prompt.size(0) + assert max_returned_tokens > T + if model.max_seq_length < max_returned_tokens - 1: + # rolling the kv cache based on the `input_pos` value would be necessary. However, doing so would introduce a + # data dependency on the `input_pos` tensor and impact model compilation. Since this setting is uncommon, we do + # not support it to avoid negatively impacting the overall speed + raise NotImplementedError( + f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}" + ) + + device = prompt.device + tokens = [prompt] + input_pos = torch.tensor([T], device=device) + token = next_token( + model, + torch.arange(0, T, device=device), + prompt.view(1, -1), + temperature=temperature, + top_k=top_k, + ).clone() + tokens.append(token) + for _ in range(2, max_returned_tokens - T + 1): + token = next_token( + model, + input_pos, + token.view(1, -1), + temperature=temperature, + top_k=top_k, + ).clone() + tokens.append(token) + if token == eos_id: + break + input_pos = input_pos.add_(1) + return torch.cat(tokens) + + +@torch.inference_mode() +def main( + prompt: str = "What food do llamas eat?", + *, + num_samples: int = 1, + max_new_tokens: int = 50, + top_k: Optional[int] = 200, + temperature: float = 0.8, + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + quantize: Optional[ + Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"] + ] = None, + precision: Optional[str] = None, + compile: bool = False, +) -> None: + """Generates text samples based on a pre-trained model and tokenizer. + + Args: + prompt: The prompt string to use for generating the samples. + num_samples: The number of text samples to generate. + max_new_tokens: The number of generation steps to take. + top_k: The number of top most probable tokens to consider in the sampling process. + temperature: A value controlling the randomness of the sampling process. Higher values result in more random + samples. + checkpoint_dir: The checkpoint directory to load. + quantize: Whether to quantize the model and using which method: + - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes + - bnb.int8: 8-bit quantization from bitsandbytes + for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md + precision: Indicates the Fabric precision setting to use. + compile: Whether to compile the model. + """ + precision = precision or get_default_supported_precision(training=False) + + plugins = None + if quantize is not None and quantize.startswith("bnb."): + if "mixed" in precision: + raise ValueError( + "Quantization and mixed precision is not supported." + ) + dtype = { + "16-true": torch.float16, + "bf16-true": torch.bfloat16, + "32-true": torch.float32, + }[precision] + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + + fabric = L.Fabric(devices=1, precision=precision, plugins=plugins) + + check_valid_checkpoint_dir(checkpoint_dir) + + config = Config.from_json(checkpoint_dir / "lit_config.json") + + checkpoint_path = checkpoint_dir / "lit_model.pth" + + tokenizer = Tokenizer(checkpoint_dir) + encoded = tokenizer.encode(prompt, device=fabric.device) + prompt_length = encoded.size(0) + max_returned_tokens = prompt_length + max_new_tokens + + fabric.print( + f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", + file=sys.stderr, + ) + t0 = time.perf_counter() + with fabric.init_module(empty_init=True): + model = GPT(config) + fabric.print( + f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", + file=sys.stderr, + ) + with fabric.init_tensor(): + # set the max_seq_length to limit the memory usage to what we need + model.max_seq_length = max_returned_tokens + # enable the kv cache + model.set_kv_cache(batch_size=1) + model.eval() + + if compile: + torch._dynamo.config.automatic_dynamic_shapes = True + torch._inductor.config.triton.unique_kernel_names = True + torch._inductor.config.coordinate_descent_tuning = True + global next_token + next_token = torch.compile(next_token, mode="reduce-overhead") + + model = fabric.setup_module(model) + + t0 = time.perf_counter() + load_checkpoint(fabric, model, checkpoint_path) + fabric.print( + f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", + file=sys.stderr, + ) + + L.seed_everything(1234) + for i in range(num_samples): + t0 = time.perf_counter() + y = generate( + model, + encoded, + max_returned_tokens, + temperature=temperature, + top_k=top_k, + eos_id=tokenizer.eos_id, + ) + t = time.perf_counter() - t0 + for block in model.transformer.h: + block.attn.kv_cache.reset_parameters() + fabric.print(tokenizer.decode(y)) + tokens_generated = y.size(0) - prompt_length + fabric.print( + f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", + file=sys.stderr, + ) + if fabric.device.type == "cuda": + fabric.print( + f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", + file=sys.stderr, + ) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + CLI(main) diff --git a/examples/llm_finetuning/generate/full.py b/examples/llm_finetuning/generate/full.py new file mode 100644 index 00000000000..cc1da495ef4 --- /dev/null +++ b/examples/llm_finetuning/generate/full.py @@ -0,0 +1,153 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import sys +import time +from pathlib import Path +from typing import Literal, Optional + +import lightning as L +import torch +from lightning.fabric.plugins import BitsandbytesPrecision + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from generate.base import generate +from lit_gpt import GPT, Config, Tokenizer +from lit_gpt.utils import ( + CLI, + check_valid_checkpoint_dir, + get_default_supported_precision, + load_checkpoint, +) + +from scripts.prepare_alpaca import generate_prompt + + +def main( + prompt: str = "What food do llamas eat?", + input: str = "", + finetuned_path: Path = Path("out/full/alpaca/lit_model_finetuned.pth"), + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + quantize: Optional[ + Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"] + ] = None, + max_new_tokens: int = 100, + top_k: Optional[int] = 200, + temperature: float = 0.8, + precision: Optional[str] = None, +) -> None: + """Generates a response based on a given instruction and an optional input. + This script will only work with checkpoints from the instruction-tuned GPT model. + See `finetune/full.py`. + + Args: + prompt: The prompt/instruction (Alpaca style). + input: Optional input (Alpaca style). + finetuned_path: Path to the checkpoint with trained weights, which are the output of + `finetune/full.py`. + checkpoint_dir: The path to the checkpoint folder with pretrained GPT weights. + quantize: Whether to quantize the model and using which method: + - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes + - bnb.int8: 8-bit quantization from bitsandbytes + for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md + max_new_tokens: The number of generation steps to take. + top_k: The number of top most probable tokens to consider in the sampling process. + temperature: A value controlling the randomness of the sampling process. Higher values result in more random + samples. + precision: Indicates the Fabric precision setting to use. + """ + precision = precision or get_default_supported_precision(training=False) + + plugins = None + if quantize is not None and quantize.startswith("bnb."): + if "mixed" in precision: + raise ValueError( + "Quantization and mixed precision is not supported." + ) + dtype = { + "16-true": torch.float16, + "bf16-true": torch.bfloat16, + "32-true": torch.float32, + }[precision] + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + + fabric = L.Fabric(devices=1, precision=precision, plugins=plugins) + fabric.launch() + + check_valid_checkpoint_dir(checkpoint_dir) + + config = Config.from_json(checkpoint_dir / "lit_config.json") + + checkpoint_path = finetuned_path + + tokenizer = Tokenizer(checkpoint_dir) + sample = {"instruction": prompt, "input": input} + prompt = generate_prompt(sample) + encoded = tokenizer.encode(prompt, device=fabric.device) + prompt_length = encoded.size(0) + max_returned_tokens = prompt_length + max_new_tokens + + fabric.print( + f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", + file=sys.stderr, + ) + t0 = time.perf_counter() + with fabric.init_module(empty_init=True): + model = GPT(config) + fabric.print( + f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", + file=sys.stderr, + ) + with fabric.init_tensor(): + # set the max_seq_length to limit the memory usage to what we need + model.max_seq_length = max_returned_tokens + # enable the kv cache + model.set_kv_cache(batch_size=1) + model.eval() + + model = fabric.setup(model) + + t0 = time.perf_counter() + load_checkpoint(fabric, model, checkpoint_path) + fabric.print( + f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", + file=sys.stderr, + ) + + L.seed_everything(1234) + t0 = time.perf_counter() + y = generate( + model, + encoded, + max_returned_tokens, + temperature=temperature, + top_k=top_k, + eos_id=tokenizer.eos_id, + ) + t = time.perf_counter() - t0 + + output = tokenizer.decode(y) + output = output.split("### Response:")[1].strip() + fabric.print(output) + + tokens_generated = y.size(0) - prompt_length + fabric.print( + f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", + file=sys.stderr, + ) + if fabric.device.type == "cuda": + fabric.print( + f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", + file=sys.stderr, + ) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + CLI(main) diff --git a/examples/llm_finetuning/generate/lora.py b/examples/llm_finetuning/generate/lora.py new file mode 100644 index 00000000000..0b30b701ef2 --- /dev/null +++ b/examples/llm_finetuning/generate/lora.py @@ -0,0 +1,178 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import sys +import time +from pathlib import Path +from typing import Literal, Optional + +import lightning as L +import torch +from lightning.fabric.plugins import BitsandbytesPrecision + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from generate.base import generate +from lit_gpt import Tokenizer +from lit_gpt.lora import GPT, Config, merge_lora_weights +from lit_gpt.utils import ( + CLI, + check_valid_checkpoint_dir, + get_default_supported_precision, + lazy_load, +) + +from scripts.prepare_alpaca import generate_prompt + + +def main( + prompt: str = "What food do llamas eat?", + input: str = "", + lora_path: Path = Path("out/lora/alpaca/lit_model_lora_finetuned.pth"), + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + quantize: Optional[ + Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8"] + ] = None, + max_new_tokens: int = 100, + top_k: Optional[int] = 200, + temperature: float = 0.8, + precision: Optional[str] = None, + lora_r: int = 8, + lora_alpha: int = 16, + lora_dropout: float = 0.05, + lora_query: bool = True, + lora_key: bool = False, + lora_value: bool = True, + lora_projection: bool = False, + lora_mlp: bool = False, + lora_head: bool = False, +) -> None: + """Generates a response based on a given instruction and an optional input. + This script will only work with checkpoints from the instruction-tuned GPT-LoRA model. + See `finetune/lora.py`. + + Args: + prompt: The prompt/instruction (Alpaca style). + input: Optional input (Alpaca style). + lora_path: Path to the checkpoint with trained adapter weights, which are the output of + `finetune/lora.py`. + checkpoint_dir: The path to the checkpoint folder with pretrained GPT weights. + quantize: Whether to quantize the model and using which method: + - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes + - bnb.int8: 8-bit quantization from bitsandbytes + for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md + max_new_tokens: The number of generation steps to take. + top_k: The number of top most probable tokens to consider in the sampling process. + temperature: A value controlling the randomness of the sampling process. Higher values result in more random + samples. + precision: Indicates the Fabric precision setting to use. + """ + precision = precision or get_default_supported_precision(training=False) + + plugins = None + if quantize is not None and quantize.startswith("bnb."): + if "mixed" in precision: + raise ValueError( + "Quantization and mixed precision is not supported." + ) + dtype = { + "16-true": torch.float16, + "bf16-true": torch.bfloat16, + "32-true": torch.float32, + }[precision] + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + + fabric = L.Fabric(devices=1, precision=precision, plugins=plugins) + fabric.launch() + + check_valid_checkpoint_dir(checkpoint_dir) + + config = Config.from_json( + checkpoint_dir / "lit_config.json", + r=lora_r, + alpha=lora_alpha, + dropout=lora_dropout, + to_query=lora_query, + to_key=lora_key, + to_value=lora_value, + to_projection=lora_projection, + to_mlp=lora_mlp, + to_head=lora_head, + ) + + checkpoint_path = checkpoint_dir / "lit_model.pth" + + tokenizer = Tokenizer(checkpoint_dir) + sample = {"instruction": prompt, "input": input} + prompt = generate_prompt(sample) + encoded = tokenizer.encode(prompt, device=fabric.device) + prompt_length = encoded.size(0) + max_returned_tokens = prompt_length + max_new_tokens + + fabric.print( + f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", + file=sys.stderr, + ) + t0 = time.perf_counter() + with fabric.init_module(empty_init=True): + model = GPT(config) + fabric.print( + f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", + file=sys.stderr, + ) + with fabric.init_tensor(): + # set the max_seq_length to limit the memory usage to what we need + model.max_seq_length = max_returned_tokens + # enable the kv cache + model.set_kv_cache(batch_size=1) + model.eval() + + t0 = time.perf_counter() + checkpoint = lazy_load(checkpoint_path) + lora_checkpoint = lazy_load(lora_path) + checkpoint.update(lora_checkpoint.get("model", lora_checkpoint)) + model.load_state_dict(checkpoint) + fabric.print( + f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", + file=sys.stderr, + ) + + merge_lora_weights(model) + model = fabric.setup(model) + + L.seed_everything(1234) + t0 = time.perf_counter() + y = generate( + model, + encoded, + max_returned_tokens, + temperature=temperature, + top_k=top_k, + eos_id=tokenizer.eos_id, + ) + t = time.perf_counter() - t0 + + output = tokenizer.decode(y) + output = output.split("### Response:")[1].strip() + fabric.print(output) + + tokens_generated = y.size(0) - prompt_length + fabric.print( + f"\n\nTime for inference: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", + file=sys.stderr, + ) + if fabric.device.type == "cuda": + fabric.print( + f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", + file=sys.stderr, + ) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + CLI(main) diff --git a/examples/llm_finetuning/generate/sequentially.py b/examples/llm_finetuning/generate/sequentially.py new file mode 100644 index 00000000000..d2dde4bb843 --- /dev/null +++ b/examples/llm_finetuning/generate/sequentially.py @@ -0,0 +1,301 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import itertools +import logging +import re +import sys +import time +from collections import OrderedDict +from functools import partial +from pathlib import Path +from typing import Literal, Optional + +import lightning as L +import torch +from lightning.fabric.accelerators import CUDAAccelerator +from lightning.fabric.plugins import BitsandbytesPrecision +from lightning.fabric.utilities.init import _materialize_meta_tensors +from typing_extensions import Type + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +import generate.base as generate_base +from lit_gpt import GPT, Config, Tokenizer +from lit_gpt.model import Block, build_mask_cache +from lit_gpt.utils import ( + CLI, + check_valid_checkpoint_dir, + get_default_supported_precision, +) + + +@torch.inference_mode() +def sequential( + model: GPT, root: torch.device, max_seq_length: int, devices: int +): + if model.config.n_layer % devices: + # TODO: support smarter partitioning schemes + raise NotImplementedError( + f"Only balanced partitioning is implemented: n_layer={model.config.n_layer}, devices {devices}" + ) + layers_per_rank = model.config.n_layer // devices + # dictates where each block should be instantiated + mapping = layer_to_device( + model, chunk_on=Block, chunk_size=layers_per_rank + ) + + # materialize each block on the appropriate device + for path, target_index in mapping.items(): + submodule = model.get_submodule(path) + target_device = torch.device(root.type, target_index) + print(f"Moving {path!r} to {target_device}", file=sys.stderr) + # submodules loaded by the checkpoint will be on CPU (if no quantization). move them + replace_device( + submodule, replace=torch.device("cpu"), by=target_device + ) + # in case the checkpoint was partial, materialize leftover metas + _materialize_meta_tensors(submodule, target_device) + # and build the kv cache + submodule.attn.kv_cache = submodule.attn.build_kv_cache( + 1, max_seq_length, model.cos.size(-1), target_device + ) + # rebuild odd ends + with root: + model.max_seq_length = max_seq_length + # the rope cache which is on meta device + model.cos, model.sin = model.rope_cache() + # the mask cache which cannot be created with `set_kv_cache` because that will set it for all layers + model.mask_cache = build_mask_cache(max_seq_length) + # and everything that is not a block in the root + _materialize_meta_tensors(model, root) + replace_device(model, replace=torch.device("cpu"), by=root) + + if devices > 1: + # install hooks to move layer inputs/output between devices + for layer_num, (path, target_index) in enumerate(mapping.items()): + submodule = model.get_submodule(path) + if layer_num >= layers_per_rank: + # we need to move the block input on the boundaries between devices + # and also on every non-root device because the RoPE and mask cache is shared + # TODO: the second case could be optimized and then we would only need this hook for + # `layer_num in [layers_per_rank * i - 1 for i in range(1, devices + 1)]` + target_device = torch.device(root.type, target_index) + submodule.register_forward_pre_hook( + partial(move_block_input, target_device) + ) + if layer_num == model.config.n_layer - 1: + submodule.register_forward_hook( + partial(move_block_output, root) + ) + + return model + + +def layer_to_device( + module: torch.nn.Module, chunk_on: Type[torch.nn.Module], chunk_size: int +) -> "OrderedDict[str, int]": + """Create a mapping from layer (block) to device.""" + # this assumes that the definition order is the same as the execution order + hits = [ + name + for name, submodule in module.named_modules() + if isinstance(submodule, chunk_on) + ] + return OrderedDict((name, i // chunk_size) for i, name in enumerate(hits)) + + +def move_block_input(device: torch.device, module: torch.nn.Module, ins): + """``forward_pre_hook`` to move a Block's input before forward.""" + # during inference, none of the inputs are None: x, cos, sin, mask, input_pos + return tuple(t.to(device) for t in ins) + + +def move_block_output( + device: torch.device, module: torch.nn.Module, ins, outs +) -> torch.Tensor: + """``forward_hook`` to move a Block's output after forward.""" + return outs.to(device) + + +def replace_device( + module: torch.nn.Module, replace: torch.device, by: torch.device +) -> torch.nn.Module: + for name, submodule in module.named_modules(): + tensors = dict( + itertools.chain( + submodule.named_parameters(recurse=False), + submodule.named_buffers(recurse=False), + ) + ) + if not tensors: + continue + devices = {t.device for t in tensors.values()} + if len(devices) != 1: + # since this is using `submodule.to`, different devices in the same submodule is a problem + path_to_device = { + f"{name}.{p}": t.device for p, t in tensors.items() + } + raise ValueError(f"Found multiple devices: {path_to_device}") + if devices.pop() == replace: + submodule.to(by) + return module + + +@torch.inference_mode() +def main( + prompt: str = "What food do llamas eat?", + *, + num_samples: int = 1, + max_new_tokens: int = 50, + top_k: Optional[int] = 200, + temperature: float = 0.8, + checkpoint_dir: Path = Path( + "checkpoints/mistralai/Mistral-7B-Instruct-v0.1" + ), + quantize: Optional[ + Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq"] + ] = None, + precision: Optional[str] = None, + compile: bool = False, +) -> None: + """Generates text samples based on a pre-trained model and tokenizer. + + Args: + prompt: The prompt string to use for generating the samples. + num_samples: The number of text samples to generate. + max_new_tokens: The number of generation steps to take. + top_k: The number of top most probable tokens to consider in the sampling process. + temperature: A value controlling the randomness of the sampling process. Higher values result in more random + samples. + checkpoint_dir: The checkpoint directory to load. + quantize: Whether to quantize the model and using which method: + - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes + for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md + precision: Indicates the Fabric precision setting to use. + compile: Whether to compile the model. + """ + precision = precision or get_default_supported_precision(training=False) + + plugins = None + if quantize is not None: + if compile: + raise NotImplementedError # untested + if "mixed" in precision: + raise ValueError( + "Quantization and mixed precision is not supported." + ) + dtype = { + "16-true": torch.float16, + "bf16-true": torch.bfloat16, + "32-true": torch.float32, + }[precision] + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + + fabric = L.Fabric( + devices=1, precision=precision, accelerator="cuda", plugins=plugins + ) + + total_devices = CUDAAccelerator.auto_device_count() + print(f"Using {total_devices} devices", file=sys.stderr) + + check_valid_checkpoint_dir(checkpoint_dir) + + config = Config.from_json(checkpoint_dir / "lit_config.json") + + checkpoint_path = checkpoint_dir / "lit_model.pth" + + tokenizer = Tokenizer(checkpoint_dir) + encoded = tokenizer.encode(prompt, device=fabric.device) + prompt_length = encoded.size(0) + max_returned_tokens = prompt_length + max_new_tokens + + print( + f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", + file=sys.stderr, + ) + t0 = time.perf_counter() + # cannot use `init_module` because if bitsandbytes is used, the Linear layers will be replaced + # which means that the weights will get quantized on cuda:0 on checkpoint load. we need to load and then convert + # still, use init_tensor for the precision + with fabric.init_tensor(), torch.device("meta"): + model = GPT(config) + print( + f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", + file=sys.stderr, + ) + + t0 = time.perf_counter() + state_dict = torch.load( + str(checkpoint_path), mmap=True, map_location="cpu" + ) + # TODO: this assumes that the model fits on CPU. Use lazy_load and make the materialization checkpoint aware + model.load_state_dict(state_dict, assign=True) + print( + f"Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", + file=sys.stderr, + ) + + model = fabric.setup_module(model, move_to_device=False) + + t0 = time.perf_counter() + model = sequential( + model, fabric.device, max_returned_tokens, total_devices + ) + print( + f"Time to sequential-ize the model: {time.perf_counter() - t0:.02f} seconds.", + file=sys.stderr, + ) + + if compile: + # TODO: raises an internal compile AssertionError caused by fabric.strategy.precision.forward_context + raise NotImplementedError + # silence developer warning on nightly builds + # https://github.com/pytorch/pytorch/blob/v2.2.0-rc5/torch/_inductor/ir.py#L4166 + pattern = re.compile(".*DeviceCopy in input program.*") + logging.getLogger("torch._inductor.utils").addFilter( + lambda record: not pattern.search(record.getMessage()) + ) + torch._dynamo.config.automatic_dynamic_shapes = True + torch._inductor.config.triton.unique_kernel_names = True + torch._inductor.config.coordinate_descent_tuning = True + # cannot use cudagraphs because it doesn't support multiple device indices + # https://github.com/pytorch/pytorch/blob/v2.2.0-rc5/torch/_inductor/compile_fx.py#L371-L375 + generate_base.next_token = torch.compile(generate_base.next_token) + + L.seed_everything(1234) + for i in range(num_samples): + t0 = time.perf_counter() + y = generate_base.generate( + model, + encoded, + max_returned_tokens, + temperature=temperature, + top_k=top_k, + eos_id=tokenizer.eos_id, + ) + t = time.perf_counter() - t0 + for block in model.transformer.h: + block.attn.kv_cache.reset_parameters() + print(tokenizer.decode(y)) + tokens_generated = y.size(0) - prompt_length + print( + f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", + file=sys.stderr, + ) + print( + f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", + file=sys.stderr, + ) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + logging.getLogger( + "lightning.fabric.plugins.precision.bitsandbytes" + ).setLevel(logging.DEBUG) + + CLI(main) diff --git a/examples/llm_finetuning/generate/tp.py b/examples/llm_finetuning/generate/tp.py new file mode 100644 index 00000000000..e8c7e1efc6b --- /dev/null +++ b/examples/llm_finetuning/generate/tp.py @@ -0,0 +1,287 @@ +"""Tensor-parallel implementation adapted from https://github.com/pytorch-labs/gpt-fast/blob/14df27/tp.py""" + +import logging +import sys +import time +from functools import partial +from pathlib import Path +from typing import Literal, Optional, Union + +import lightning as L +import torch +import torch._dynamo.config +import torch._inductor.config +from lightning.fabric.plugins import BitsandbytesPrecision +from lightning.fabric.utilities import rank_zero_only +from torch.distributed._functional_collectives import all_reduce + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +import generate.base as generate_base +from lit_gpt import GPT, Config, Tokenizer +from lit_gpt.model import CausalSelfAttention, GptNeoxMLP, LLaMAMLP, LLaMAMoE +from lit_gpt.utils import ( + CLI, + check_valid_checkpoint_dir, + get_default_supported_precision, +) + + +def tensor_parallel_linear( + fabric: L.Fabric, linear: torch.nn.Linear, style: str +) -> None: + world_size = fabric.world_size + dim, attr = { + "colwise": (0, "out_features"), + "rowwise": (1, "in_features"), + }[style] + size = getattr(linear, attr) + if size % world_size != 0: + raise ValueError( + f"This linear's {attr} value ({size}) is not evenly divisible by the world size ({world_size})" + ) + + shard = torch.tensor_split(linear.weight, world_size, dim=dim)[ + fabric.global_rank + ] + # overwrite `.data` instead of recreating the parameter for quantization (bitsandbytes) support. + # the bitsandbytes linear classes use custom `torch.nn.Parameter` subclasses + linear.weight.data = shard + setattr(linear, attr, shard.size(dim)) + + if linear.bias is not None and dim == 0: + shard = torch.tensor_split(linear.bias, world_size)[fabric.global_rank] + linear.bias = torch.nn.Parameter( + shard, requires_grad=linear.bias.requires_grad + ) + + +def tensor_parallel_mlp( + fabric: L.Fabric, mlp: Union[GptNeoxMLP, LLaMAMLP, LLaMAMoE] +) -> None: + if isinstance(mlp, LLaMAMLP): + tensor_parallel_linear(fabric, mlp.fc_1, "colwise") + tensor_parallel_linear(fabric, mlp.fc_2, "colwise") + tensor_parallel_linear(fabric, mlp.proj, "rowwise") + mlp.register_forward_hook( + partial(all_reduce_output, fabric.world_size) + ) + elif isinstance(mlp, GptNeoxMLP): + tensor_parallel_linear(fabric, mlp.fc, "colwise") + tensor_parallel_linear(fabric, mlp.proj, "rowwise") + mlp.register_forward_hook( + partial(all_reduce_output, fabric.world_size) + ) + elif isinstance(mlp, LLaMAMoE): + # we use expert slicing across ranks, alternatively, we could create a expert parallelism group + # when the number of experts is a multiple of the world size + for expert in mlp.experts: + tensor_parallel_mlp(fabric, expert) + else: + raise NotImplementedError + + +def tensor_parallel_attn(fabric: L.Fabric, attn: CausalSelfAttention) -> None: + tensor_parallel_linear(fabric, attn.attn, "colwise") + tensor_parallel_linear(fabric, attn.proj, "rowwise") + attn.register_forward_hook(partial(all_reduce_output, fabric.world_size)) + + +def all_reduce_output( + world_size: int, module: torch.nn.Module, ins, outs +) -> torch.Tensor: + return all_reduce(outs, "sum", list(range(world_size))) + + +def tensor_parallel(fabric: L.Fabric, model: GPT) -> GPT: + for block in model.transformer.h: + tensor_parallel_mlp(fabric, block.mlp) + tensor_parallel_attn(fabric, block.attn) + + # update the config values to the shard sizes + # this is only relevant for `tensor_parallel_attn`, but it needs to run only once + world_size = fabric.world_size + attrs = ["n_head", "n_embd", "n_query_groups"] + for attr in attrs: + size = getattr(model.config, attr) + if size % world_size != 0: + raise ValueError( + f"This {attr} value ({size}) is not evenly divisible by the world size ({world_size})" + ) + setattr(model.config, attr, size // world_size) + + return model + + +@torch.inference_mode() +def main( + prompt: str = "What food do llamas eat?", + *, + num_samples: int = 1, + max_new_tokens: int = 50, + top_k: Optional[int] = 200, + temperature: float = 0.8, + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + quantize: Optional[ + Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq"] + ] = None, + precision: Optional[str] = None, + compile: bool = False, +) -> None: + """Generates text samples based on a pre-trained model and tokenizer. + + Args: + prompt: The prompt string to use for generating the samples. + num_samples: The number of text samples to generate. + max_new_tokens: The number of generation steps to take. + top_k: The number of top most probable tokens to consider in the sampling process. + temperature: A value controlling the randomness of the sampling process. Higher values result in more random + samples. + checkpoint_dir: The checkpoint directory to load. + quantize: Whether to quantize the model and using which method: + - bnb.nf4, bnb.nf4-dq, bnb.fp4, bnb.fp4-dq: 4-bit quantization from bitsandbytes + for more details, see https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials/quantize.md + precision: Indicates the Fabric precision setting to use. + compile: Whether to compile the model. + """ + precision = precision or get_default_supported_precision(training=False) + + plugins = None + if quantize is not None: + if compile: + raise NotImplementedError # untested + if "mixed" in precision: + raise ValueError( + "Quantization and mixed precision is not supported." + ) + dtype = { + "16-true": torch.float16, + "bf16-true": torch.bfloat16, + "32-true": torch.float32, + }[precision] + plugins = BitsandbytesPrecision(quantize[4:], dtype) + precision = None + + # set "ddp" as the strategy for the launching functionality, but there's no data-parallelism + fabric = L.Fabric( + devices="auto", strategy="ddp", precision=precision, plugins=plugins + ) + fabric.launch() + + check_valid_checkpoint_dir(checkpoint_dir) + + config = Config.from_json(checkpoint_dir / "lit_config.json") + + model_file = "lit_model.pth" + checkpoint_path = checkpoint_dir / model_file + + tokenizer = Tokenizer(checkpoint_dir) + encoded = tokenizer.encode(prompt, device=fabric.device) + prompt_length = encoded.size(0) + max_returned_tokens = prompt_length + max_new_tokens + + fabric.print( + f"Loading model {str(checkpoint_path)!r} with {config.__dict__}", + file=sys.stderr, + ) + t0 = time.perf_counter() + # cannot use `init_module` because if bitsandbytes is used, the Linear layers will be replaced + # which means that the weights will get quantized on cuda:0 on checkpoint load. we need to load and then convert + # still, use init_tensor for the precision + with fabric.init_tensor(), torch.device("meta"): + model = GPT(config) + fabric.print( + f"Time to instantiate model: {time.perf_counter() - t0:.02f} seconds.", + file=sys.stderr, + ) + + # sequentially do: load the checkpoint on CPU -> quantize -> apply tp -> move to device + # so that the CPU RAM doesn't OOM with larger models + for rank in range(fabric.world_size): + if fabric.global_rank == rank: + t0 = time.perf_counter() + state_dict = torch.load( + str(checkpoint_path), mmap=True, map_location="cpu" + ) + model.load_state_dict(state_dict, assign=True) + print( + f"[{rank}] Time to load the model weights: {time.perf_counter() - t0:.02f} seconds.", + file=sys.stderr, + ) + + # cannot use `.setup_module` because it will wrap with DDP + model = fabric._precision.convert_module(model) + + t0 = time.perf_counter() + model = tensor_parallel(fabric, model) + print( + f"[{rank}] Time to tensor-parallelize the model: {time.perf_counter() - t0:.02f} seconds.", + file=sys.stderr, + ) + + with fabric.init_tensor(): + # set the max_seq_length to limit the memory usage to what we need + model.max_seq_length = max_returned_tokens + # the rope cache which is on meta device + model.cos, model.sin = model.rope_cache() + # enable the kv cache + model.set_kv_cache(batch_size=1) + model.eval() + + t0 = time.perf_counter() + model = fabric.to_device(model) + print( + f"[{rank}] Time to move the model: {time.perf_counter() - t0:.02f} seconds.", + file=sys.stderr, + ) + fabric.barrier() + + if compile: + torch._dynamo.config.automatic_dynamic_shapes = True + torch._inductor.config.triton.unique_kernel_names = True + torch._inductor.config.coordinate_descent_tuning = True + generate_base.next_token = torch.compile( + generate_base.next_token, mode="reduce-overhead" + ) + + L.seed_everything(1234) + for i in range(num_samples): + t0 = time.perf_counter() + y = generate_base.generate( + model, + encoded, + max_returned_tokens, + temperature=temperature, + top_k=top_k, + eos_id=tokenizer.eos_id, + ) + t = time.perf_counter() - t0 + for block in model.transformer.h: + block.attn.kv_cache.reset_parameters() + fabric.print(tokenizer.decode(y)) + tokens_generated = y.size(0) - prompt_length + fabric.print( + f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_generated / t:.02f} tokens/sec", + file=sys.stderr, + ) + if fabric.device.type == "cuda": + fabric.print( + f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB", + file=sys.stderr, + ) + + +if __name__ == "__main__": + torch.set_float32_matmul_precision("high") + + bnb_logger = logging.getLogger( + "lightning.fabric.plugins.precision.bitsandbytes" + ) + bnb_logger.setLevel(logging.DEBUG) + bnb_logger.debug = rank_zero_only(bnb_logger.debug) + + CLI(main) diff --git a/examples/llm_finetuning/lit_gpt/__init__.py b/examples/llm_finetuning/lit_gpt/__init__.py new file mode 100644 index 00000000000..f3974ec0a1b --- /dev/null +++ b/examples/llm_finetuning/lit_gpt/__init__.py @@ -0,0 +1,29 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import logging +import re + +from lightning_utilities.core.imports import RequirementCache + +from lit_gpt.model import GPT # isort: skip +from lit_gpt.config import Config # isort: skip +from lit_gpt.tokenizer import Tokenizer + +_LIGHTNING_AVAILABLE = RequirementCache("lightning>=2.2.0.dev0") +if not bool(_LIGHTNING_AVAILABLE): + raise ImportError( + "Lit-GPT requires lightning nightly. Please run:\n" + f" pip uninstall -y lightning; pip install -r requirements.txt\n{str(_LIGHTNING_AVAILABLE)}" + ) + +# Suppress excessive warnings, see https://github.com/pytorch/pytorch/issues/111632 +pattern = re.compile(".*Profiler function .* will be ignored") +logging.getLogger("torch._dynamo.variables.torch").addFilter( + lambda record: not pattern.search(record.getMessage()) +) + +# Avoid printing state-dict profiling output at the WARNING level when saving a checkpoint +logging.getLogger("torch.distributed.fsdp._optim_utils").disabled = True +logging.getLogger("torch.distributed.fsdp._debug_utils").disabled = True + +__all__ = ["GPT", "Config", "Tokenizer"] diff --git a/examples/llm_finetuning/lit_gpt/adapter.py b/examples/llm_finetuning/lit_gpt/adapter.py new file mode 100644 index 00000000000..61744419e4a --- /dev/null +++ b/examples/llm_finetuning/lit_gpt/adapter.py @@ -0,0 +1,206 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +"""Implementation of the paper: + +LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention +https://arxiv.org/abs/2303.16199 + +Port for Lit-GPT +""" + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from typing_extensions import Self + +from lit_gpt.config import Config as BaseConfig +from lit_gpt.model import GPT as BaseModel +from lit_gpt.model import Block as BaseBlock +from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention + + +@dataclass +class Config(BaseConfig): + adapter_prompt_length: int = 10 + adapter_start_layer: int = 2 + + +class GPT(BaseModel): + """The implementation is identical to `lit_gpt.model.GPT` with the exception that + the `Block` saves the layer index and passes it down to the attention layer. + """ + + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + assert config.padded_vocab_size is not None + self.config = config + + self.lm_head = nn.Linear( + config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias + ) + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.padded_vocab_size, config.n_embd), + h=nn.ModuleList( + Block(config, i) for i in range(config.n_layer) + ), + ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), + ) + ) + self.max_seq_length = self.config.block_size + self.mask_cache: Optional[torch.Tensor] = None + + def forward( + self, + idx: torch.Tensor, + input_pos: Optional[torch.Tensor] = None, + lm_head_chunk_size: int = 0, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + T = idx.size(1) + if self.max_seq_length < T: + raise ValueError( + f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}." + ) + + if input_pos is not None: # use the kv cache + cos = self.cos.index_select(0, input_pos) + sin = self.sin.index_select(0, input_pos) + if self.mask_cache is None: + raise TypeError("You need to call `gpt.set_kv_cache()`") + mask = self.mask_cache.index_select(2, input_pos) + else: + cos = self.cos[:T] + sin = self.sin[:T] + mask = None + + x = self.transformer.wte( + idx + ) # token embeddings of shape (b, t, n_embd) + for block in self.transformer.h: + x = block(x, cos, sin, mask, input_pos) + x = self.transformer.ln_f(x) + if lm_head_chunk_size > 0: + # chunk the lm head logits to reduce the peak memory used by autograd + return [ + self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1) + ] + return self.lm_head(x) # (b, t, vocab_size) + + @classmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + return cls(Config.from_name(name, **kwargs)) + + def _init_weights(self, module: nn.Module) -> None: + """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness.""" + super()._init_weights(module) + if isinstance(module, CausalSelfAttention): + module.reset_parameters() + + +class Block(BaseBlock): + """The implementation is identical to `lit_gpt.model.Block` with the exception that + we replace the attention layer where adaption is implemented.""" + + def __init__(self, config: Config, block_idx: int) -> None: + # Skip the parent class __init__ altogether and replace it to avoid useless allocations + nn.Module.__init__(self) + self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.attn = CausalSelfAttention(config, block_idx) + if not config.shared_attention_norm: + self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.mlp = config.mlp_class(config) + + self.config = config + + +class CausalSelfAttention(BaseCausalSelfAttention): + """A modification of `lit_gpt.model.CausalSelfAttention` that adds the attention + over the adaption prompt.""" + + def __init__(self, config: Config, block_idx: int) -> None: + super().__init__(config) + if block_idx >= config.adapter_start_layer: + # adapter embedding layer + self.adapter_wte = nn.Embedding( + config.adapter_prompt_length, config.n_embd + ) + # gate for adaption + self.gating_factor = torch.nn.Parameter( + torch.zeros(1, 1, config.n_head, 1) + ) + # kv cache for inference + self.adapter_kv_cache: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None + self.block_idx = block_idx + + def scaled_dot_product_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + y = super().scaled_dot_product_attention(q, k, v, mask) + if self.block_idx < self.config.adapter_start_layer: + return y + + aT = self.config.adapter_prompt_length + if self.adapter_kv_cache is not None: + # since this uses the wte weights as the prefix and the kv cache is only used during inference, ak and av + # are the same every call + ak, av = self.adapter_kv_cache + else: + prefix = self.adapter_wte.weight.reshape(1, aT, self.config.n_embd) + aqkv = self.attn(prefix) + q_per_kv = self.config.n_head // self.config.n_query_groups + aqkv = aqkv.view( + 1, + aT, + self.config.n_query_groups, + q_per_kv + 2, + self.config.head_size, + ) + aqkv = aqkv.permute(0, 2, 3, 1, 4) + _, ak, av = aqkv.split((q_per_kv, 1, 1), dim=2) + if self.config.n_query_groups != 1: + # for MHA this is a no-op + ak = ak.repeat_interleave(q_per_kv, dim=2) + av = av.repeat_interleave(q_per_kv, dim=2) + ak = ak.view( + 1, -1, aT, self.config.head_size + ) # (1, nh_ak, aT, hs) + av = av.view( + 1, -1, aT, self.config.head_size + ) # (1, nh_av, aT, hs) + self.adapter_kv_cache = (ak, av) + + T = q.size(2) + amask = torch.ones(T, aT, dtype=torch.bool, device=q.device) + ay = super().scaled_dot_product_attention(q, ak, av, amask) + return y + self.gating_factor * ay + + def reset_parameters(self) -> None: + torch.nn.init.zeros_(self.gating_factor) + + def _load_from_state_dict( + self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any + ) -> None: + """For compatibility with older checkpoints.""" + if (key := prefix + "gating_factor") in state_dict and state_dict[ + key + ].size(1) == self.config.n_head: + state_dict[key] = state_dict[key].permute(0, 2, 1, 3) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +def mark_only_adapter_as_trainable(model: GPT) -> None: + """Sets `requires_grad=False` for all non-adapter weights.""" + for name, param in model.named_parameters(): + param.requires_grad = adapter_filter(name, param) + + +def adapter_filter(key: str, value: Any) -> bool: + return "adapter_wte" in key or "gating_factor" in key diff --git a/examples/llm_finetuning/lit_gpt/adapter_v2.py b/examples/llm_finetuning/lit_gpt/adapter_v2.py new file mode 100644 index 00000000000..5d389471b55 --- /dev/null +++ b/examples/llm_finetuning/lit_gpt/adapter_v2.py @@ -0,0 +1,269 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +"""Implementation of the paper: + +LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model +https://arxiv.org/abs/2304.15010 + +Port for Lit-GPT +""" + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple, Type + +import torch +import torch.nn as nn +from typing_extensions import Self + +import lit_gpt +from lit_gpt.adapter import GPT as BaseModel +from lit_gpt.adapter import Block as BaseBlock +from lit_gpt.adapter import CausalSelfAttention as BaseCausalSelfAttention +from lit_gpt.adapter import Config as BaseConfig +from lit_gpt.model import KVCache +from lit_gpt.utils import map_old_state_dict_weights + + +@dataclass +class Config(BaseConfig): + @property + def mlp_class(self) -> Type: + return getattr(lit_gpt.adapter_v2, self._mlp_class) + + +def adapter_filter(key: str, value: Any) -> bool: + adapter_substrings = ( + # regular adapter v1 parameters + "adapter_wte", + "gating_factor", + # adapter v2: new bias and scale used in Linear + "adapter_scale", + "adapter_bias", + # adapter v2: Norm parameters are now trainable + "norm_1", + "norm_2", + "ln_f", + ) + return any(s in key for s in adapter_substrings) + + +class AdapterV2Linear(torch.nn.Module): + def __init__(self, in_features: int, out_features: int, **kwargs) -> None: + super().__init__() + self.linear = torch.nn.Linear(in_features, out_features, **kwargs) + self.adapter_bias = torch.nn.Parameter( + torch.zeros(out_features), requires_grad=False + ) + self.adapter_scale = torch.nn.Parameter( + torch.ones(out_features), requires_grad=False + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.adapter_scale * (self.linear(x) + self.adapter_bias) + + def reset_parameters(self) -> None: + nn.init.zeros_(self.adapter_bias) + nn.init.ones_(self.adapter_scale) + + +class GPT(BaseModel): + def __init__(self, config: Config) -> None: + # Skip the parent class __init__ altogether and replace it to avoid useless allocations + nn.Module.__init__(self) + assert config.padded_vocab_size is not None + self.config = config + + self.lm_head = AdapterV2Linear( + config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias + ) + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.padded_vocab_size, config.n_embd), + h=nn.ModuleList( + Block(config, i) for i in range(config.n_layer) + ), + ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), + ) + ) + self.max_seq_length = self.config.block_size + self.mask_cache: Optional[torch.Tensor] = None + + @classmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + return cls(Config.from_name(name, **kwargs)) + + def _init_weights(self, module: nn.Module) -> None: + """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness.""" + super()._init_weights(module) + if isinstance(module, AdapterV2Linear): + module.reset_parameters() + + def _load_from_state_dict( + self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any + ) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "lm_head.weight": "lm_head.linear.weight", + "lm_head.bias": "lm_head.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class Block(BaseBlock): + """The implementation is identical to `lit_gpt.model.Block` with the exception that + we replace the attention layer where adaption is implemented.""" + + def __init__(self, config: Config, block_idx: int) -> None: + # Skip the parent class __init__ altogether and replace it to avoid useless allocations + nn.Module.__init__(self) + self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.attn = CausalSelfAttention(config, block_idx) + if not config.shared_attention_norm: + self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.mlp = config.mlp_class(config) + + self.config = config + + +class CausalSelfAttention(BaseCausalSelfAttention): + """A modification of `lit_gpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class""" + + def __init__(self, config: Config, block_idx: int) -> None: + # Skip the parent class __init__ altogether and replace it to avoid useless allocations + nn.Module.__init__(self) + shape = (config.n_head + 2 * config.n_query_groups) * config.head_size + # key, query, value projections for all heads, but in a batch + self.attn = AdapterV2Linear( + in_features=config.n_embd, out_features=shape, bias=config.bias + ) + # output projection + # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` + self.proj = AdapterV2Linear( + config.head_size * config.n_head, config.n_embd, bias=config.bias + ) + # disabled by default + self.kv_cache: Optional[KVCache] = None + + if block_idx >= config.adapter_start_layer: + # adapter embedding layer + self.adapter_wte = nn.Embedding( + config.adapter_prompt_length, config.n_embd + ) + # gate for adaption + self.gating_factor = torch.nn.Parameter( + torch.zeros(1, 1, config.n_head, 1) + ) + # kv cache for inference + self.adapter_kv_cache: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None + self.block_idx = block_idx + + self.config = config + + def _load_from_state_dict( + self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any + ) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "attn.weight": "attn.linear.weight", + "attn.bias": "attn.linear.bias", + "proj.weight": "proj.linear.weight", + "proj.bias": "proj.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + # For compatibility with older checkpoints + if (key := prefix + "gating_factor") in state_dict and state_dict[ + key + ].size(1) == self.config.n_head: + state_dict[key] = state_dict[key].permute(0, 2, 1, 3) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class GptNeoxMLP(lit_gpt.model.GptNeoxMLP): + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + self.fc = AdapterV2Linear( + config.n_embd, config.intermediate_size, bias=config.bias + ) + self.proj = AdapterV2Linear( + config.intermediate_size, config.n_embd, bias=config.bias + ) + + self.config = config + + def _load_from_state_dict( + self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any + ) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "fc.weight": "fc.linear.weight", + "fc.bias": "fc.linear.bias", + "proj.weight": "proj.linear.weight", + "proj.bias": "proj.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class LLaMAMLP(lit_gpt.model.LLaMAMLP): + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + self.fc_1 = AdapterV2Linear( + config.n_embd, config.intermediate_size, bias=config.bias + ) + self.fc_2 = AdapterV2Linear( + config.n_embd, config.intermediate_size, bias=config.bias + ) + self.proj = AdapterV2Linear( + config.intermediate_size, config.n_embd, bias=config.bias + ) + + def _load_from_state_dict( + self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any + ) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "fc_1.weight": "fc_1.linear.weight", + "fc_1.bias": "fc_1.linear.bias", + "fc_2.weight": "fc_2.linear.weight", + "fc_2.bias": "fc_2.linear.bias", + "proj.weight": "proj.linear.weight", + "proj.bias": "proj.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class GemmaMLP(LLaMAMLP): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_fc_1 = self.fc_1(x) + x_fc_2 = self.fc_2(x) + x = torch.nn.functional.gelu(x_fc_1) * x_fc_2 + return self.proj(x) + + +class LLaMAMoE(lit_gpt.model.LLaMAMoE): + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + self.gate = AdapterV2Linear(config.n_embd, config.n_expert, bias=False) + self.experts = nn.ModuleList( + LLaMAMLP(config) for _ in range(config.n_expert) + ) + + self.config = config + + def _load_from_state_dict( + self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any + ) -> None: + """For compatibility with base checkpoints.""" + mapping = {"gate.weight": "gate.linear.weight"} + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +def mark_only_adapter_v2_as_trainable(model: GPT) -> None: + """Sets requires_grad=False for all non-adapter weights""" + for name, param in model.named_parameters(): + param.requires_grad = adapter_filter(name, param) diff --git a/examples/llm_finetuning/lit_gpt/args.py b/examples/llm_finetuning/lit_gpt/args.py new file mode 100644 index 00000000000..264c8f511ee --- /dev/null +++ b/examples/llm_finetuning/lit_gpt/args.py @@ -0,0 +1,85 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + + +@dataclass +class TrainArgs: + """Training related arguments""" + + save_interval: int = 1000 + """Number of optimizer steps between checkpoints""" + log_interval: int = 1 + """Number of iterations between logging calls""" + global_batch_size: int = 64 + """Number of samples between optimizer steps across data-parallel ranks""" + micro_batch_size: int = 4 + """Number of samples per data-parallel rank""" + lr_warmup_steps: int = 100 + """Number of iterations with learning rate warmup active""" + epochs: Optional[int] = None + """Number of epochs to run""" + epoch_size: Optional[int] = None + """Size of the epoch""" + # TODO: pretrain/tinyllama is the only script using `max_tokens` explicitly. replace it with epoch_size*epochs? + max_tokens: Optional[int] = None + """Total number of tokens to train on""" + max_seq_length: Optional[int] = None + """Limits the length of samples. Off by default""" + + # Optimization args + learning_rate: float = 1e-3 + weight_decay: float = 0.02 + beta1: float = 0.9 + beta2: float = 0.95 + max_norm: Optional[float] = None + min_lr: float = 6e-5 + + def max_iters(self, devices: int) -> int: + """Number of iterations""" + max_iters = ( + self.epochs * self.epoch_size // devices // self.micro_batch_size + ) + assert max_iters > 0 + return max_iters + + def gradient_accumulation_iters(self, devices: int) -> int: + """Number of iterations between gradient synchronizations""" + gradient_accumulation_iters = ( + self.batch_size(devices) // self.micro_batch_size + ) + assert gradient_accumulation_iters > 0 + return gradient_accumulation_iters + + def batch_size(self, devices: int) -> int: + """Number of samples between optimizer steps per data-parallel rank""" + batch_size = self.global_batch_size // devices + assert batch_size > 0 + return batch_size + + +@dataclass +class EvalArgs: + """Evaluation related arguments""" + + interval: int = 600 + """Number of optimizer steps between evaluation calls""" + max_new_tokens: Optional[int] = None + """Number of tokens to generate""" + max_iters: int = 100 + """Number of iterations""" + + +@dataclass +class IOArgs: + """Inputs and outputs related arguments""" + + # Optional because pretrain/tinyllama hardcodes the path + train_data_dir: Optional[Path] = Path("data/alpaca") + """Where to read training data from""" + val_data_dir: Optional[Path] = None + """Where to read validation data from""" + checkpoint_dir: Optional[Path] = None + """Where to read weights and tokenizer data from""" + out_dir: Path = Path("out/adapter/alpaca") + """Where to save artifacts""" diff --git a/examples/llm_finetuning/lit_gpt/config.py b/examples/llm_finetuning/lit_gpt/config.py new file mode 100644 index 00000000000..dab1523ba53 --- /dev/null +++ b/examples/llm_finetuning/lit_gpt/config.py @@ -0,0 +1,1487 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import json +from copy import deepcopy +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Literal, Optional, Type, Union + +import torch +from typing_extensions import Self + +import lit_gpt.model +from lit_gpt.utils import find_multiple + + +@dataclass +class Config: + name: str = "" + hf_config: dict = field(default_factory=dict) + scale_embeddings: bool = False + block_size: int = 4096 + vocab_size: int = 50254 + padding_multiple: int = 512 + padded_vocab_size: Optional[int] = None + n_layer: int = 16 + n_head: int = 32 + head_size: Optional[int] = None + n_embd: int = 4096 + rotary_percentage: float = 0.25 + parallel_residual: bool = True + bias: bool = True + lm_head_bias: bool = False + # to use multi-head attention (MHA), set this to `n_head` (default) + # to use multi-query attention (MQA), set this to 1 + # to use grouped-query attention (GQA), set this to a value in between + # Example with `n_head=4` + # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ + # │ v ││ v ││ v ││ v │ │ v │ │ v │ │ v │ + # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ + # │ │ │ │ │ │ │ + # ┌───┐┌───┐┌───┐┌───┐ ┌───┐ ┌───┐ ┌───┐ + # │ k ││ k ││ k ││ k │ │ k │ │ k │ │ k │ + # └───┘└───┘└───┘└───┘ └───┘ └───┘ └───┘ + # │ │ │ │ ┌──┴──┐ ┌──┴──┐ ┌────┬──┴─┬────┐ + # ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ ┌───┐┌───┐┌───┐┌───┐ + # │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ │ q ││ q ││ q ││ q │ + # └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ └───┘└───┘└───┘└───┘ + # ◀──────────────────▶ ◀──────────────────▶ ◀──────────────────▶ + # MHA GQA MQA + # n_query_groups=4 n_query_groups=2 n_query_groups=1 + # + # credit https://arxiv.org/pdf/2305.13245.pdf + n_query_groups: Optional[int] = None + shared_attention_norm: bool = False + _norm_class: Literal["LayerNorm", "RMSNorm"] = "LayerNorm" + norm_eps: float = 1e-5 + _mlp_class: Literal[ + "GptNeoxMLP", "LLaMAMLP", "GemmaMLP", "LLaMAMoE" + ] = "GptNeoxMLP" + gelu_approximate: str = "none" + intermediate_size: Optional[int] = None + rope_condense_ratio: int = 1 + rope_base: int = 10000 + n_expert: int = 0 + n_expert_per_token: int = 0 + + def __post_init__(self): + if not self.name: + self.name = self.hf_config.get("name", self.name) + + if self.head_size is None: + assert self.n_embd % self.n_head == 0 + self.head_size = self.n_embd // self.n_head + + # vocab size should be a power of 2 to be optimal on hardware. compute the closest value + if self.padded_vocab_size is None: + self.padded_vocab_size = find_multiple( + self.vocab_size, self.padding_multiple + ) + else: + # vocab size shouldn't be larger than padded vocab size + self.vocab_size = min(self.vocab_size, self.padded_vocab_size) + + # compute the number of query groups + if self.n_query_groups is not None: + assert self.n_head % self.n_query_groups == 0 + else: + self.n_query_groups = self.n_head + + # compute the intermediate size for MLP if not set + if self.intermediate_size is None: + if self._mlp_class == "LLaMAMLP": + raise ValueError( + "The config needs to set the `intermediate_size`" + ) + self.intermediate_size = 4 * self.n_embd + + self.rope_n_elem = int(self.rotary_percentage * self.head_size) + + @classmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + if name not in name_to_config: + # search through all `config['hf_config']['name']` + try: + conf_dict = next( + config + for config in configs + if name == config["hf_config"]["name"] + ) + except StopIteration: + raise ValueError(f"{name!r} is not a supported config name") + else: + conf_dict = name_to_config[name] + + conf_dict = conf_dict.copy() + if "condense_ratio" in kwargs: # legacy name + kwargs["rope_condense_ratio"] = kwargs.pop("condense_ratio") + conf_dict.update(kwargs) + return cls(**conf_dict) + + @classmethod + def from_json(cls, path: Union[str, Path], **kwargs: Any) -> Self: + with open(path, encoding="utf-8") as fp: + json_kwargs = json.load(fp) + if "condense_ratio" in json_kwargs: # legacy name + json_kwargs["rope_condense_ratio"] = json_kwargs.pop( + "condense_ratio" + ) + if "condense_ratio" in kwargs: # legacy name + kwargs["rope_condense_ratio"] = kwargs.pop("condense_ratio") + if "org" in json_kwargs: # legacy name + json_kwargs["hf_config"] = { + "name": json_kwargs["name"], + "org": json_kwargs.pop("org"), + } + if "org" in kwargs: # legacy name + kwargs["hf_config"] = { + "name": kwargs.get("name", json_kwargs["name"]), + "org": kwargs.pop("org"), + } + json_kwargs.update(kwargs) + return cls(**json_kwargs) + + @classmethod + def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self: + """Automatically load `lit_config.json` and if it doesn't exist - a matching config from `lit_gpt/config.py`.""" + if (config_path := path / "lit_config.json").is_file(): + return cls.from_json(config_path, **kwargs) + if (model_name := path.name) in name_to_config: + return cls.from_name(model_name, **kwargs) + raise FileNotFoundError( + f"For {str(path)!r} neither 'lit_config.json' nor matching config exists." + ) + + @property + def mlp_class(self) -> Type: + # `self._mlp_class` cannot be the type to keep the config json serializable + return getattr(lit_gpt.model, self._mlp_class) + + @property + def norm_class(self) -> Type: + # `self._norm_class` cannot be the type to keep the config json serializable + if self._norm_class == "RMSNorm": + from functools import partial + + from lit_gpt.rmsnorm import RMSNorm + + return partial(RMSNorm, add_unit_offset="Gemma" in self.name) + return getattr(torch.nn, self._norm_class) + + +######################## +# Stability AI StableLM +######################## +configs = [ + # https://huggingface.co/stabilityai/stablelm-base-alpha-3b/blob/main/config.json + dict( + name="stablelm-base-alpha-3b", + hf_config=dict(org="stabilityai", name="stablelm-base-alpha-3b"), + ), + # https://huggingface.co/stabilityai/stablelm-base-alpha-7b/blob/main/config.json + dict( + name="stablelm-base-alpha-7b", + hf_config=dict(org="stabilityai", name="stablelm-base-alpha-7b"), + n_head=48, + n_embd=6144, + padding_multiple=256, + ), + # https://huggingface.co/stabilityai/stablelm-tuned-alpha-3b/blob/main/config.json + dict( + name="stablelm-tuned-alpha-3b", + hf_config=dict(org="stabilityai", name="stablelm-tuned-alpha-3b"), + n_head=32, + ), + # https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b/blob/main/config.json + dict( + name="stablelm-tuned-alpha-7b", + hf_config=dict(org="stabilityai", name="stablelm-tuned-alpha-7b"), + n_head=48, + n_embd=6144, + padding_multiple=256, + ), + # https://huggingface.co/stabilityai/stablelm-zephyr-3b/blob/main/config.json + dict( + name="stablelm-zephyr-3b", + hf_config=dict(org="stabilityai", name="stablelm-zephyr-3b"), + padded_vocab_size=50304, + n_layer=32, + n_head=32, + n_embd=2560, + parallel_residual=False, + bias=False, + _mlp_class="LLaMAMLP", + intermediate_size=6912, + ), +] + +#################### +# EleutherAI Pythia +#################### +pythia = [ + # https://huggingface.co/EleutherAI/pythia-14m/blob/main/config.json + dict( + name="pythia-14m", + hf_config=dict(org="EleutherAI", name="pythia-14m"), + block_size=512, + n_layer=6, + n_embd=128, + n_head=4, + padding_multiple=128, + ), + # https://huggingface.co/EleutherAI/pythia-31m/blob/main/config.json + dict( + name="pythia-31m", + hf_config=dict(org="EleutherAI", name="pythia-31m"), + block_size=1024, + n_layer=6, + n_embd=256, + n_head=8, + padding_multiple=128, + ), + # https://huggingface.co/EleutherAI/pythia-70m/blob/main/config.json + dict( + name="pythia-70m", + hf_config=dict(org="EleutherAI", name="pythia-70m"), + block_size=2048, + n_layer=6, + n_embd=512, + n_head=8, + padding_multiple=128, + ), + # https://huggingface.co/EleutherAI/pythia-160m/blob/main/config.json + dict( + name="pythia-160m", + hf_config=dict(org="EleutherAI", name="pythia-160m"), + block_size=2048, + n_layer=12, + n_embd=768, + n_head=12, + padding_multiple=128, + ), + # https://huggingface.co/EleutherAI/pythia-410m/blob/main/config.json + dict( + name="pythia-410m", + hf_config=dict(org="EleutherAI", name="pythia-410m"), + block_size=2048, + n_layer=24, + n_embd=1024, + n_head=16, + padding_multiple=128, + ), + # https://huggingface.co/EleutherAI/pythia-1b/blob/main/config.json + dict( + name="pythia-1b", + hf_config=dict(org="EleutherAI", name="pythia-1b"), + block_size=2048, + n_embd=2048, + n_head=8, + padding_multiple=128, + ), + # https://huggingface.co/EleutherAI/pythia-1.4b/blob/main/config.json + dict( + name="pythia-1.4b", + hf_config=dict(org="EleutherAI", name="pythia-1.4b"), + block_size=2048, + n_layer=24, + n_embd=2048, + n_head=16, + padding_multiple=128, + ), + # https://huggingface.co/EleutherAI/pythia-2.8b/blob/main/config.json + dict( + name="pythia-2.8b", + hf_config=dict(org="EleutherAI", name="pythia-2.8b"), + block_size=2048, + n_layer=32, + n_embd=2560, + padding_multiple=128, + ), + # https://huggingface.co/EleutherAI/pythia-6.9b/blob/main/config.json + dict( + name="pythia-6.9b", + hf_config=dict(org="EleutherAI", name="pythia-6.9b"), + block_size=2048, + n_layer=32, + padding_multiple=256, + ), + # https://huggingface.co/EleutherAI/pythia-12b/blob/main/config.json + dict( + name="pythia-12b", + hf_config=dict(org="EleutherAI", name="pythia-12b"), + block_size=2048, + n_layer=36, + n_embd=5120, + n_head=40, + ), +] +configs.extend(pythia) +for c in pythia: + # "pythia-14m" and "pythia-31m" don't have deduped version + if c["name"] in ("pythia-14m", "pythia-31m"): + continue + copy = deepcopy(c) + copy["name"] = f"{c['name']}-deduped" + copy["hf_config"]["name"] = f"{c['hf_config']['name']}-deduped" + configs.append(copy) + + +################### +# databricks Dolly +################### +dolly = [ + # https://huggingface.co/databricks/dolly-v2-3b/blob/main/config.json + dict( + name="dolly-v2-3b", + hf_config=dict(org="databricks", name="dolly-v2-3b"), + block_size=2048, + n_layer=32, + n_embd=2560, + padded_vocab_size=50280, + ), + # https://huggingface.co/databricks/dolly-v2-7b/blob/main/config.json + dict( + name="dolly-v2-7b", + hf_config=dict(org="databricks", name="dolly-v2-7b"), + block_size=2048, + n_layer=32, + padded_vocab_size=50280, + ), + # https://huggingface.co/databricks/dolly-v2-12b/blob/main/config.json + dict( + name="dolly-v2-12b", + hf_config=dict(org="databricks", name="dolly-v2-12b"), + block_size=2048, + n_layer=36, + n_embd=5120, + n_head=40, + padded_vocab_size=50280, + ), +] +configs.extend(dolly) + + +#################################### +# togethercomputer RedPajama INCITE +#################################### +redpajama_incite = [ + # https://huggingface.co/togethercomputer/RedPajama-INCITE-Base-3B-v1/blob/main/config.json + dict( + name="RedPajama-INCITE-{}-3B-v1", + hf_config=dict( + org="togethercomputer", name="RedPajama-INCITE-{}-3B-v1" + ), + block_size=2048, + n_layer=32, + n_embd=2560, + padding_multiple=256, + rotary_percentage=1.0, + parallel_residual=False, + ), + # https://huggingface.co/togethercomputer/RedPajama-INCITE-7B-Base/blob/main/config.json + dict( + name="RedPajama-INCITE-7B-{}", + hf_config=dict(org="togethercomputer", name="RedPajama-INCITE-7B-{}"), + block_size=2048, + n_layer=32, + padding_multiple=256, + rotary_percentage=1.0, + parallel_residual=False, + ), + # this redirects to the checkpoint above. kept for those who had the old weights already downloaded + dict( + name="RedPajama-INCITE-{}-7B-v0.1", + hf_config=dict( + org="togethercomputer", name="RedPajama-INCITE-{}-7B-v0.1" + ), + block_size=2048, + n_layer=32, + padding_multiple=256, + rotary_percentage=1.0, + parallel_residual=False, + ), +] +for c in redpajama_incite: + for kind in ("Base", "Chat", "Instruct"): + copy = deepcopy(c) + copy["name"] = c["name"].format(kind) + copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) + configs.append(copy) + + +################# +# TII UAE Falcon +################# +falcon = [ + # https://huggingface.co/tiiuae/falcon-7b/blob/main/config.json + dict( + name="falcon-7b{}", + hf_config=dict(org="tiiuae", name="falcon-7b{}"), + block_size=2048, + vocab_size=65024, + padded_vocab_size=65024, + n_layer=32, + n_head=71, + n_embd=4544, + rotary_percentage=1.0, + n_query_groups=1, + bias=False, + # this is not in the config, but in the original model implementation, only for this config + shared_attention_norm=True, + ), + # https://huggingface.co/tiiuae/falcon-40b/blob/main/config.json + dict( + name="falcon-40b{}", + hf_config=dict(org="tiiuae", name="falcon-40b{}"), + block_size=2048, + vocab_size=65024, + padded_vocab_size=65024, + n_layer=60, + n_head=128, + n_embd=8192, + rotary_percentage=1.0, + n_query_groups=8, + bias=False, + ), +] +for c in falcon: + for kind in ("", "-instruct"): + copy = deepcopy(c) + copy["name"] = c["name"].format(kind) + copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) + configs.append(copy) + +# https://huggingface.co/tiiuae/falcon-180b/blob/main/config.json +falcon180b = dict( + name="falcon-180B{}", + hf_config=dict(org="tiiuae", name="falcon-180B{}"), + block_size=2048, + vocab_size=65024, + padded_vocab_size=65024, + n_layer=80, + n_head=232, + n_embd=14848, + rotary_percentage=1.0, + n_query_groups=8, + bias=False, +) + +for kind in ("", "-chat"): + copy = deepcopy(falcon180b) + copy["name"] = falcon180b["name"].format(kind) + copy["hf_config"]["name"] = falcon180b["hf_config"]["name"].format(kind) + configs.append(copy) + + +############################# +# OpenLM Research Open LLaMA +############################# +open_LLaMA = [ + # https://huggingface.co/openlm-research/open_llama_3b/blob/main/config.json + dict( + name="open_llama_3b", + hf_config=dict(org="openlm-research", name="open_llama_3b"), + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=26, + n_embd=3200, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=8640, + ), + # https://huggingface.co/openlm-research/open_llama_7b/blob/main/config.json + dict( + name="open_llama_7b", + hf_config=dict(org="openlm-research", name="open_llama_7b"), + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + ), + # https://huggingface.co/openlm-research/open_llama_13b/blob/main/config.json + dict( + name="open_llama_13b", + hf_config=dict(org="openlm-research", name="open_llama_13b"), + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), +] +configs.extend(open_LLaMA) + + +############### +# LMSYS Vicuna +############### +vicuna = [ + # https://huggingface.co/lmsys/vicuna-7b-v1.3/blob/main/config.json + dict( + name="vicuna-7b-v1.3", + hf_config=dict(org="lmsys", name="vicuna-7b-v1.3"), + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + ), + # https://huggingface.co/lmsys/vicuna-13b-v1.3/blob/main/config.json + dict( + name="vicuna-13b-v1.3", + hf_config=dict(org="lmsys", name="vicuna-13b-v1.3"), + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), + # https://huggingface.co/lmsys/vicuna-33b-v1.3/blob/main/config.json + dict( + name="vicuna-33b-v1.3", + hf_config=dict(org="lmsys", name="vicuna-33b-v1.3"), + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=60, + n_head=52, + n_embd=6656, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=17920, + ), + # https://huggingface.co/lmsys/vicuna-7b-v1.5/blob/main/config.json + dict( + name="vicuna-7b-v1.5", + hf_config=dict(org="lmsys", name="vicuna-7b-v1.5"), + vocab_size=32000, + padding_multiple=64, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=11008, + ), + # https://huggingface.co/lmsys/vicuna-7b-v1.5-16k/blob/main/config.json + dict( + name="vicuna-7b-v1.5-16k", + hf_config=dict(org="lmsys", name="vicuna-7b-v1.5-16k"), + block_size=16384, + vocab_size=32000, + padding_multiple=64, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=11008, + rope_condense_ratio=4, + ), + # https://huggingface.co/lmsys/vicuna-13b-v1.5/blob/main/config.json + dict( + name="vicuna-13b-v1.5", + hf_config=dict(org="lmsys", name="vicuna-13b-v1.5"), + vocab_size=32000, + padding_multiple=64, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), + # https://huggingface.co/lmsys/vicuna-13b-v1.5-16k/blob/main/config.json + dict( + name="vicuna-13b-v1.5-16k", + hf_config=dict(org="lmsys", name="vicuna-13b-v1.5-16k"), + block_size=16384, + vocab_size=32000, + padding_multiple=64, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=13824, + rope_condense_ratio=4, + ), +] +configs.extend(vicuna) + + +################# +# LMSYS LongChat +################# +long_chat = [ + # https://huggingface.co/lmsys/longchat-7b-16k/blob/main/config.json + dict( + name="longchat-7b-16k", + hf_config=dict(org="lmsys", name="longchat-7b-16k"), + block_size=16384, + vocab_size=32000, + padding_multiple=64, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + rope_condense_ratio=8, + ), + # https://huggingface.co/lmsys/longchat-13b-16k/blob/main/config.json + dict( + name="longchat-13b-16k", + hf_config=dict(org="lmsys", name="longchat-13b-16k"), + block_size=16384, + vocab_size=32000, + padding_multiple=64, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + rope_condense_ratio=8, + ), +] +configs.extend(long_chat) + + +###################### +# NousResearch Hermes +###################### +nous_research = [ + # https://huggingface.co/NousResearch/Nous-Hermes-llama-2-7b/blob/main/config.json + dict( + name="Nous-Hermes-llama-2-7b", + hf_config=dict(org="NousResearch", name="Nous-Hermes-llama-2-7b"), + padded_vocab_size=32000, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + ), + # https://huggingface.co/NousResearch/Nous-Hermes-13B/blob/main/config.json + dict( + name="Nous-Hermes-13b", + hf_config=dict(org="NousResearch", name="Nous-Hermes-13b"), + block_size=2048, + vocab_size=32000, + padded_vocab_size=32001, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-6, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), + # https://huggingface.co/NousResearch/Nous-Hermes-Llama2-13b + dict( + name="Nous-Hermes-Llama2-13b", + hf_config=dict(org="NousResearch", name="Nous-Hermes-Llama2-13b"), + vocab_size=32000, + padded_vocab_size=32032, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), +] +configs.extend(nous_research) + + +############### +# Meta LLaMA 2 +############### +llama_2 = [ + # https://huggingface.co/meta-llama/Llama-2-7b-hf/blob/main/config.json + dict( + name="Llama-2-7b{}-hf", + hf_config=dict(org="meta-llama", name="Llama-2-7b{}-hf"), + vocab_size=32000, + padding_multiple=64, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=11008, + ), + # https://huggingface.co/meta-llama/Llama-2-13b-hf/blob/main/config.json + dict( + name="Llama-2-13b{}-hf", + hf_config=dict(org="meta-llama", name="Llama-2-13b{}-hf"), + vocab_size=32000, + padding_multiple=64, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), + # https://huggingface.co/meta-llama/Llama-2-70b-hf/blob/main/config.json + dict( + name="Llama-2-70b{}-hf", + hf_config=dict(org="meta-llama", name="Llama-2-70b{}-hf"), + vocab_size=32000, + padding_multiple=64, + n_layer=80, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=28672, + ), +] +for c in llama_2: + for kind in ("", "-chat"): + copy = deepcopy(c) + copy["name"] = c["name"].format(kind) + copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) + configs.append(copy) + + +############### +# Google Gemma +############### +gemma = [ + # https://huggingface.co/google/gemma-2b/blob/main/config.json + dict( + name="Gemma-2b", + hf_config=dict(org="google", name="gemma-2b"), + scale_embeddings=True, + vocab_size=256000, + padding_multiple=64, + n_embd=2048, + n_layer=18, + n_head=8, + n_query_groups=1, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="GemmaMLP", + intermediate_size=16384, + ), + # https://huggingface.co/google/gemma-7b/blob/main/config.json + dict( + name="Gemma-7b", + hf_config=dict(org="google", name="gemma-7b"), + scale_embeddings=True, + vocab_size=256000, + padding_multiple=64, + n_embd=3072, + n_layer=28, + n_head=16, + head_size=256, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="GemmaMLP", + intermediate_size=24576, + ), +] +configs.extend(gemma) +for c in gemma: + copy = deepcopy(c) + copy["name"] = f"{c['name']}-it" + copy["hf_config"]["name"] = f"{c['hf_config']['name']}-it" + configs.append(copy) + + +########################## +# Stability AI FreeWilly2 +########################## +freewilly_2 = [ + # https://huggingface.co/stabilityai/FreeWilly2/blob/main/config.json + dict( + name="FreeWilly2", + hf_config=dict(org="stabilityai", name="FreeWilly2"), + vocab_size=32000, + padding_multiple=64, + n_layer=80, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=28672, + ) +] +configs.extend(freewilly_2) + + +################## +# Meta Code Llama +################## +code_llama = [ + # https://huggingface.co/codellama/CodeLlama-7b-hf/blob/main/config.json + dict( + name="CodeLlama-7b-hf", + hf_config=dict(org="codellama", name="CodeLlama-7b-hf"), + block_size=16384, + vocab_size=32016, + padding_multiple=16, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + rope_base=1000000, + ), + # https://huggingface.co/codellama/CodeLlama-13b-hf/blob/main/config.json + dict( + name="CodeLlama-13b-hf", + hf_config=dict(org="codellama", name="CodeLlama-13b-hf"), + block_size=16384, + vocab_size=32016, + padding_multiple=16, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + rope_base=1000000, + ), + # https://huggingface.co/codellama/CodeLlama-34b-hf/blob/main/config.json + dict( + name="CodeLlama-34b-hf", + hf_config=dict(org="codellama", name="CodeLlama-34b-hf"), + block_size=16384, + vocab_size=32000, + padded_vocab_size=32000, + n_layer=48, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=22016, + rope_base=1000000, + ), + # https://huggingface.co/codellama/CodeLlama-70b-hf/blob/main/config.json + dict( + name="CodeLlama-70b-hf", + hf_config=dict(org="codellama", name="CodeLlama-70b-hf"), + block_size=16384, + vocab_size=32016, + padding_multiple=16, + n_layer=80, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=28672, + rope_base=1000000, + ), + # https://huggingface.co/codellama/CodeLlama-7b-Python-hf/blob/main/config.json + dict( + name="CodeLlama-7b-Python-hf", + hf_config=dict(org="codellama", name="CodeLlama-7b-Python-hf"), + block_size=16384, + vocab_size=32000, + padded_vocab_size=32000, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + rope_base=1000000, + ), + # https://huggingface.co/codellama/CodeLlama-13b-Python-hf/blob/main/config.json + dict( + name="CodeLlama-13b-Python-hf", + hf_config=dict(org="codellama", name="CodeLlama-13b-Python-hf"), + block_size=16384, + vocab_size=32000, + padded_vocab_size=32000, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + rope_base=1000000, + ), + # https://huggingface.co/codellama/CodeLlama-34b-Python-hf/blob/main/config.json + dict( + name="CodeLlama-34b-Python-hf", + hf_config=dict(org="codellama", name="CodeLlama-34b-Python-hf"), + block_size=16384, + vocab_size=32000, + padded_vocab_size=32000, + n_layer=48, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=22016, + rope_base=1000000, + ), + # https://huggingface.co/codellama/CodeLlama-70b-Python-hf/blob/main/config.json + dict( + name="CodeLlama-70b-Python-hf", + hf_config=dict(org="codellama", name="CodeLlama-70b-Python-hf"), + block_size=16384, + vocab_size=32016, + padding_multiple=16, + n_layer=80, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=28672, + rope_base=1000000, + ), + # https://huggingface.co/codellama/CodeLlama-7b-Instruct-hf/blob/main/config.json + dict( + name="CodeLlama-7b-Instruct-hf", + hf_config=dict(org="codellama", name="CodeLlama-7b-Instruct-hf"), + block_size=16384, + vocab_size=32016, + padding_multiple=16, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + rope_base=1000000, + ), + # https://huggingface.co/codellama/CodeLlama-13b-Instruct-hf/blob/main/config.json + dict( + name="CodeLlama-13b-Instruct-hf", + hf_config=dict(org="codellama", name="CodeLlama-13b-Instruct-hf"), + block_size=2048, + vocab_size=32016, + padding_multiple=16, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + rope_base=1000000, + ), + # https://huggingface.co/codellama/CodeLlama-34b-Instruct-hf/blob/main/config.json + dict( + name="CodeLlama-34b-Instruct-hf", + hf_config=dict(org="codellama", name="CodeLlama-34b-Instruct-hf"), + block_size=16384, + vocab_size=32000, + padded_vocab_size=32000, + n_layer=48, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=22016, + rope_base=1000000, + ), + # https://huggingface.co/codellama/CodeLlama-70b-Instruct-hf/blob/main/config.json + dict( + name="CodeLlama-70b-Instruct-hf", + hf_config=dict(org="codellama", name="CodeLlama-70b-Instruct-hf"), + block_size=16384, + vocab_size=32016, + padding_multiple=16, + n_layer=80, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=28672, + rope_base=1000000, + ), +] +configs.extend(code_llama) + + +######################## +# garage-bAInd Platypus +######################## +platypus = [ + # https://huggingface.co/garage-bAInd/Platypus-30B/blob/main/config.json + dict( + name="Platypus-30B", + hf_config=dict(org="garage-bAInd", name="Platypus-30B"), + block_size=2048, + padded_vocab_size=32000, + n_layer=60, + n_head=52, + n_embd=6656, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-06, + _mlp_class="LLaMAMLP", + intermediate_size=17920, + ), + # https://huggingface.co/garage-bAInd/Platypus2-7B/blob/main/config.json + dict( + name="Platypus2-7B", + hf_config=dict(org="garage-bAInd", name="Platypus2-7B"), + padded_vocab_size=32000, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=11008, + ), + # https://huggingface.co/garage-bAInd/Platypus2-13B/blob/main/config.json + dict( + name="Platypus2-13B", + hf_config=dict(org="garage-bAInd", name="Platypus2-13B"), + padded_vocab_size=32000, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), + # https://huggingface.co/garage-bAInd/Platypus2-70B/blob/main/config.json + dict( + name="Platypus2-70B", + hf_config=dict(org="garage-bAInd", name="Platypus2-70B"), + padded_vocab_size=32000, + n_layer=80, + n_head=64, + n_embd=8192, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=28672, + ), + # https://huggingface.co/garage-bAInd/Camel-Platypus2-13B/blob/main/config.json + dict( + name="Camel-Platypus2-13B", + hf_config=dict(org="garage-bAInd", name="Camel-Platypus2-13B"), + padded_vocab_size=32000, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), + # https://huggingface.co/garage-bAInd/Camel-Platypus2-70B/blob/main/config.json + dict( + name="Camel-Platypus2-70B", + hf_config=dict(org="garage-bAInd", name="Camel-Platypus2-70B"), + padded_vocab_size=32000, + n_layer=80, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=28672, + ), + # https://huggingface.co/garage-bAInd/Stable-Platypus2-13B/blob/main/config.json + dict( + name="Stable-Platypus2-13B", + hf_config=dict(org="garage-bAInd", name="Stable-Platypus2-13B"), + padded_vocab_size=32000, + n_layer=40, + n_head=40, + n_embd=5120, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=13824, + ), + # https://huggingface.co/garage-bAInd/Platypus2-70B-instruct/blob/main/config.json + dict( + name="Platypus2-70B-instruct", + hf_config=dict(org="garage-bAInd", name="Platypus2-70B-instruct"), + padded_vocab_size=32000, + n_layer=80, + n_head=64, + n_embd=8192, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=28672, + ), +] +configs.extend(platypus) + + +########################## +# Stability AI StableCode +########################## +stablecode = [ + # https://huggingface.co/stabilityai/stablecode-completion-alpha-3b/blob/main/config.json + dict( + name="stablecode-completion-alpha-3b", + hf_config=dict( + org="stabilityai", name="stablecode-completion-alpha-3b" + ), + block_size=16384, + vocab_size=49152, + n_layer=32, + n_embd=2560, + ), + # https://huggingface.co/stabilityai/stablecode-completion-alpha-3b-4k/blob/main/config.json + dict( + name="stablecode-completion-alpha-3b-4k", + hf_config=dict( + org="stabilityai", name="stablecode-completion-alpha-3b-4k" + ), + vocab_size=49152, + n_layer=32, + n_embd=2560, + ), + # https://huggingface.co/stabilityai/stablecode-instruct-alpha-3b/blob/main/config.json + dict( + name="stablecode-instruct-alpha-3b", + hf_config=dict(org="stabilityai", name="stablecode-instruct-alpha-3b"), + vocab_size=49152, + n_layer=32, + n_embd=2560, + ), +] +configs.extend(stablecode) + + +################################## +# togethercomputer LLaMA-2-7B-32K +################################## +together_llama2_32k = [ + # https://huggingface.co/togethercomputer/LLaMA-2-7B-32K/blob/main/config.json + dict( + name="LLaMA-2-7B-32K", + hf_config=dict(org="togethercomputer", name="LLaMA-2-7B-32K"), + vocab_size=32000, + padding_multiple=64, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=11008, + rope_condense_ratio=8, + ) +] +configs.extend(together_llama2_32k) + + +################ +# Microsoft Phi +################ +phi = [ + # https://huggingface.co/microsoft/phi-1_5/blob/main/config.json + dict( + name="phi-1_5", + hf_config=dict(org="microsoft", name="phi-1_5"), + vocab_size=50257, + padded_vocab_size=51200, + block_size=2048, + n_embd=2048, + n_layer=24, + rotary_percentage=0.5, # 32 / (n_embd / n_head) = 32 / 64 + shared_attention_norm=True, + lm_head_bias=True, + gelu_approximate="tanh", + ), + # https://huggingface.co/microsoft/phi-2/blob/main/config.json + dict( + name="phi-2", + hf_config=dict(org="microsoft", name="phi-2"), + vocab_size=50257, + padded_vocab_size=51200, + block_size=2048, + n_embd=2560, + n_layer=32, + rotary_percentage=0.4, # 32 / (n_embd / n_head) = 32 / 80 + shared_attention_norm=True, + lm_head_bias=True, + gelu_approximate="tanh", + ), +] +configs.extend(phi) + + +############# +# Mistral AI +############# +mistral = [ + # https://huggingface.co/mistralai/Mistral-7B-v0.1/blob/main/config.json + dict( + name="Mistral-7B-{}v0.1", + hf_config=dict(org="mistralai", name="Mistral-7B-{}v0.1"), + padded_vocab_size=32000, + block_size=4096, # should be 32768 but sliding window attention is not implemented + n_layer=32, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=14336, + ), + # https://huggingface.co/mistralai/Mixtral-8x7B-v0.1/blob/main/config.json + dict( + name="Mixtral-8x7B-{}v0.1", + hf_config=dict(org="mistralai", name="Mixtral-8x7B-{}v0.1"), + padded_vocab_size=32000, + block_size=32768, + n_layer=32, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMoE", + intermediate_size=14336, + rope_base=1000000, + n_expert=8, + n_expert_per_token=2, + ), +] +for c in mistral: + for kind in ("", "Instruct-"): + copy = deepcopy(c) + copy["name"] = c["name"].format(kind) + copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind) + configs.append(copy) +configs.append( + # https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/blob/main/config.json + dict( + name="Mistral-7B-Instruct-v0.2", + hf_config=dict(org="mistralai", name="Mistral-7B-Instruct-v0.2"), + padded_vocab_size=32000, + block_size=32768, + n_layer=32, + n_query_groups=8, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + norm_eps=1e-05, + _mlp_class="LLaMAMLP", + intermediate_size=14336, + ) +) + + +############ +# TinyLlama +############ +tiny_llama = [ + dict( + name="tiny-llama-1.1b{}", + hf_config=dict(org="TinyLlama", name="TinyLlama-1.1B{}"), + block_size=2048, + vocab_size=32000, + padding_multiple=64, + n_layer=22, + n_head=32, + n_embd=2048, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", # original TinyLlama uses FusedRMSNorm + norm_eps=1e-5, + _mlp_class="LLaMAMLP", + intermediate_size=5632, + n_query_groups=4, + ) +] +for c in tiny_llama: + for kind, hf_postfix in ( + ("", "-intermediate-step-1431k-3T"), + ("-chat", "-Chat-v1.0"), + ): + copy = deepcopy(c) + copy["name"] = c["name"].format(kind) + copy["hf_config"]["name"] = c["hf_config"]["name"].format(hf_postfix) + configs.append(copy) + + +########################## +# Trelis Function Calling +########################## +llama_2_function_calling = [ + # https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2/blob/main/config.json + dict( + name="Llama-2-7b-chat-hf-function-calling-v2", + hf_config=dict( + org="Trelis", name="Llama-2-7b-chat-hf-function-calling-v2" + ), + padding_multiple=64, + n_layer=32, + rotary_percentage=1.0, + parallel_residual=False, + bias=False, + _norm_class="RMSNorm", + _mlp_class="LLaMAMLP", + intermediate_size=11008, + norm_eps=1e-6, + block_size=4096, + vocab_size=32000, + n_head=32, + n_embd=4096, + rope_base=10000, + ) +] + +configs.extend(llama_2_function_calling) + +name_to_config = {config["name"]: config for config in configs} diff --git a/examples/llm_finetuning/lit_gpt/lora.py b/examples/llm_finetuning/lit_gpt/lora.py new file mode 100644 index 00000000000..84d42543e73 --- /dev/null +++ b/examples/llm_finetuning/lit_gpt/lora.py @@ -0,0 +1,816 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +# Derived from https://github.com/microsoft/LoRA +# ------------------------------------------------------------------------------------------ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information. +# ------------------------------------------------------------------------------------------ + +r""" + Low Ranking Adaptation for LLMs scheme. + + ┌───────────────────┐ + ┆ h ┆ + └───────────────────┘ + ▲ + | + + + / \ + ┌─────────────────┐ ╭───────────────╮ Matrix initialization: + ┆ ┆ \ B / B = 0 + ┆ pretrained ┆ \ r*d / A = N(0, sigma^2) + ┆ weights ┆ ╰─────────╯ + ┆ ┆ | r | r - rank + ┆ W e R^(d*d) ┆ | ◀─────▶ | + ┆ ┆ ╭─────────╮ + └─────────────────┘ / A \ + ▲ / d*r \ + \ ╰───────────────╯ + \ ▲ + \ / + \ / + ┌───────────────────┐ + ┆ x ┆ + └───────────────────┘ + +With LoRA (Low Ranking Adaptation: https://arxiv.org/abs/2106.09685) instead of learning weights of size d*d, +we can freeze the pretrained weights and instead learn two matrices of size d*r and r*d (they will store weight updates +for the pretrained weights): the number of parameters in this case will be reduced drastically (depending on the rank of +course) yet after multiplication of matrices d*r and r*d we will get a matrix d*d which we can sum with frozen +pretrained weights and thus fine-tune the model. + +The goal of this approach is to move weight updates into a separate matrix which is decomposed with +two matrices of a lower rank. +""" + +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import torch +import torch.nn as nn +from torch.nn import functional as F +from typing_extensions import Self + +import lit_gpt +from lit_gpt.config import Config as BaseConfig +from lit_gpt.model import GPT as BaseModel +from lit_gpt.model import Block as BaseBlock +from lit_gpt.model import CausalSelfAttention as BaseCausalSelfAttention +from lit_gpt.model import KVCache +from lit_gpt.utils import map_old_state_dict_weights + + +class LoRALayer(nn.Module): + def __init__(self, r: int, lora_alpha: int, lora_dropout: float): + """Store LoRA specific attributes in a class. + + Args: + r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of + the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2) + lora_alpha: alpha is needed for scaling updates as alpha/r + "This scaling helps to reduce the need to retune hyperparameters when we vary r" + https://arxiv.org/pdf/2106.09685.pdf (section 4.1) + lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) + """ + super().__init__() + assert r >= 0 + self.r = r + self.lora_alpha = lora_alpha + # Optional dropout + if lora_dropout > 0.0: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = lambda x: x + # Mark the weight as unmerged + self.merged = False + + +class LoRALinear(LoRALayer): + # LoRA implemented in a dense layer + def __init__( + self, + # ↓ this part is for pretrained weights + in_features: int, + out_features: int, + # ↓ the remaining part is for LoRA + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + **kwargs: Any, + ): + """LoRA wrapper around linear class. + + This class has three weight matrices: + 1. Pretrained weights are stored as `self.linear.weight` + 2. LoRA A matrix as `self.lora_A` + 3. LoRA B matrix as `self.lora_B` + Only LoRA's A and B matrices are updated, pretrained weights stay frozen. + + Args: + in_features: number of input features of the pretrained weights + out_features: number of output features of the pretrained weights + r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of + the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2) + lora_alpha: alpha is needed for scaling updates as alpha/r + "This scaling helps to reduce the need to retune hyperparameters when we vary r" + https://arxiv.org/pdf/2106.09685.pdf (section 4.1) + lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) + """ + super().__init__(r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout) + self.linear = torch.nn.Linear(in_features, out_features, **kwargs) + + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter(torch.zeros((r, in_features))) + self.lora_B = nn.Parameter(torch.zeros((out_features, r))) + self.scaling = self.lora_alpha / self.r + self.reset_parameters() + + def reset_parameters(self) -> None: + """Reset all the weights, even including pretrained ones.""" + if hasattr(self, "lora_A"): + # initialize A the same way as the default for nn.Linear and B to zero + # Wondering why 'a' is equal to math.sqrt(5)?: https://github.com/pytorch/pytorch/issues/15314 + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + def get_lora_AB(self) -> torch.Tensor: + """Return merged lora_A and lora_B matrices with the same shape as the pretrained weights.""" + return (self.lora_B @ self.lora_A) * self.scaling + + def merge(self) -> None: + """Merges the LoRA weights into the full-rank weights (W = W + delta_W).""" + if self.r > 0 and not self.merged: + pretrained_dtype = self.linear.weight.data.dtype + lora_data = self.get_lora_AB() + # if the pretrained weights and LoRA weights are of the same dtype - simply sum them + if pretrained_dtype == lora_data.dtype: + self.linear.weight.data += lora_data + # if only the pretrained are in quantized form - dequantize, sum with LoRA and quantize the result + elif pretrained_dtype == torch.uint8: + import bitsandbytes as bnb + + weight = self.linear.weight + # dequantize the pretrained weights + weight_data = bnb.functional.dequantize_4bit( + weight.data, weight.quant_state + ).to(lora_data.dtype) + # add pretrained and LoRA weights + weight_data += lora_data + # assign updated weights and quantize by moving to CUDA device + self.linear.weight = bnb.nn.Params4bit( + weight_data, requires_grad=False, **weight.__dict__ + ) + self.linear.weight.cuda(weight.device) + else: + raise NotImplementedError( + f"Cannot merge the pretrained weights of type {pretrained_dtype}" + f" and LoRA weights of type {lora_data.dtype}" + ) + + self.merged = True + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # if weights are merged or rank is less or equal to zero (LoRA is disabled) - it's only a regular nn.Linear forward pass; + # otherwise in addition do the forward pass with LoRA weights and add it's output to the output from pretrained weights + pretrained = self.linear(x) + if self.r == 0 or self.merged: + return pretrained + lora = ( + self.lora_dropout(x) + @ self.lora_A.transpose(0, 1) + @ self.lora_B.transpose(0, 1) + ) * self.scaling + return pretrained + lora + + +class LoRAQKVLinear(LoRALinear): + # LoRA implemented in a dense layer + def __init__( + self, + # ↓ this part is for pretrained weights + in_features: int, + out_features: int, + # ↓ the remaining part is for LoRA + n_head: int, + n_query_groups: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + enable_lora: Union[bool, Tuple[bool, bool, bool]] = False, + **kwargs: Any, + ): + """LoRA wrapper around linear class that is used for calculation of q, k and v matrices. + + This class has three weight matrices: + 1. Pretrained weights are stored as `self.linear.weight` + 2. LoRA A matrix as `self.lora_A` + 3. LoRA B matrix as `self.lora_B` + Only LoRA's A and B matrices are updated, pretrained weights stay frozen. + + Args: + in_features: number of input features of the pretrained weights + out_features: number of output features of the pretrained weights + n_head: number of attention heads + n_query_groups: number of query groups (see diagram in `lit_gpt/config.py`) + r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of + the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2) + lora_alpha: alpha is needed for scaling updates as alpha/r + "This scaling helps to reduce the need to retune hyperparameters when we vary r" + https://arxiv.org/pdf/2106.09685.pdf (section 4.1) + lora_dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) + enable_lora: MergeLinear class is for attention mechanism where qkv are calculated with a single weight matrix. If we + don't want to apply LoRA we can set it as False. For example if we want to apply LoRA only to `query` + and `value` but keep `key` without weight updates we should pass `[True, False, True]` + """ + super(LoRALinear, self).__init__( + r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout + ) + self.linear = torch.nn.Linear(in_features, out_features, **kwargs) + self.n_head = n_head + self.n_query_groups = n_query_groups + if isinstance(enable_lora, bool): + enable_lora = [enable_lora] * 3 + assert len(enable_lora) == 3 + self.enable_lora = enable_lora + + # Actual trainable parameters + # To better understand initialization let's imagine that we have such parameters: + # ⚬ in_features: 128 (embeddings_size) + # ⚬ out_features: 384 (3 * embedding_size) + # ⚬ r: 2 + # ⚬ enable_lora: [True, False, True] + if r > 0 and any(enable_lora): + self.lora_A = nn.Parameter( + torch.zeros((r * sum(enable_lora), in_features)) + ) # (4, 128) + enable_q, enable_k, enable_v = enable_lora + self.kv_embd_size = self.linear.in_features // ( + n_head // n_query_groups + ) + # qkv_shapes will be used to split a tensor with weights correctly + qkv_shapes = ( + self.linear.in_features * enable_q, + self.kv_embd_size * enable_k, + self.kv_embd_size * enable_v, + ) + self.qkv_shapes = [s for s in qkv_shapes if s] + self.lora_B = nn.Parameter( + torch.zeros(sum(self.qkv_shapes), r) + ) # (256, 2)) + # Notes about shapes above + # - self.lora_A has shape (4, 128): 4 because rank is 2 and LoRA is applied only to two matrices; + # 128 is the input size of the x (embedding size). (4, 128) and not (128, 4) because later on in + # F.linear function weights are automatically transposed. In addition conv1d requires channels to + # be before seq length + # - self.lora_B has shape (256, 2): 256 because LoRA is applied only to two matrices, so the output is + # 128*2; 2 tells to have two channels per group for group convolution + + # Scaling: + # This balances the pretrained model`s knowledge and the new task-specific adaptation + # https://lightning.ai/pages/community/tutorial/lora-llm/ + # So, set alpha to 1.0 to fully add LoRA. If the LoRA seems to have too much effect (i.e., overfitted), set + # alpha to lower value. If the LoRA seems to have too little effect, set alpha to higher than 1.0. You can + # tune these values to your needs. This value can be even slightly greater than 1.0! + # https://github.com/cloneofsimo/lora + self.scaling = self.lora_alpha / self.r + + # Compute the indices + # Indices are needed to properly pad weight updates with zeros in `zero_pad` method. + q_per_kv = self.n_head // self.n_query_groups + total_qkv = q_per_kv + 2 + head_size = out_features // (self.n_query_groups * total_qkv) + ind = range(out_features) + self.lora_ind = [] + if enable_q: + q_ind = [ + x + for x in ind + if (x // head_size) % total_qkv < total_qkv - 2 + ] + self.lora_ind.extend(q_ind) + if enable_k: + k_ind = [ + x + for x in ind + if (x // head_size) % total_qkv == total_qkv - 2 + ] + self.lora_ind.extend(k_ind) + if enable_v: + v_ind = [ + x + for x in ind + if (x // head_size) % total_qkv == total_qkv - 1 + ] + self.lora_ind.extend(v_ind) + self.reset_parameters() + + def zero_pad(self, x: torch.Tensor) -> torch.Tensor: + """Properly pad weight updates with zeros. + + If, based on `self.enable_lora`, we want to fine-tune queries and values, but not keys, + then the weights update should be: + + [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,], + [....................................], + [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW,]] + ↑ ↑ ↑ + ________________________________________ + | query | key | value | + ---------------------------------------- + For Llama2's GQA support, Q, K, and V weights are interleaved, so that weights for grouped + queries are adjacent to their associated key and value weights. + For example, suppose we have n_head = 12 with 3 query groups. + Then along the embedding dimension the interleaved weights would look like + + [Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V, Q, Q, Q, Q, K, V], + + where each Q, K, and V has size head_size. + + In this case, the previously-described weight update applies separately to each + individual block, so the update will take the form + + [[ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...], + [.............................................................................], + [ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ΔW,ΔW,ΔW, ..., 0,0,0, ..., ΔW,ΔW,ΔW, ...]] + ↑ ↑ ↑ ↑ ↑ ↑ + ________________________________________________________________________________ + | q block 1 | k block 1 | v block 1 | q block 2 | k block 2 | v block 2 | ... + -------------------------------------------------------------------------------- + Note that in the above diagram, the size of each q block will equal q_per_kv + times the size of each k and v block. + + Args: + x: tensor with weights update that will be padded with zeros if necessary + + Returns: + A tensor with weight updates and zeros for deselected q, k or v + """ + # we need to do zero padding only if LoRA is disabled for one of QKV matrices + if all(self.enable_lora): + return x + + # Let's image that: + # ⚬ input x has shape (64, 64, 256): (batch_size, sequence_length, embeddings_size) + # ⚬ embeddings_size: 128 + # ⚬ self.linear.out_features: 384 (3 * embeddings_size) + # ⚬ enable_lora: [True, False, True] + # Then x has embeddings_size of 256 (2 * 128 as enable_lora only for query and value, not keys) and expected + # embeddings_size is 384 (self.linear.out_features), so that means that we need to pad from 256 to 384 with zeros, but + # only for key updates (this is where self.lora_ind comes in handy) + # Note: double transpose (in the beginning and in the end) is basically a guard for two-dimensional tensors + # for example when we want to merge/unmerge LoRA weights and pretrained weights + x = x.transpose(0, 1) + result = x.new_zeros( + (*x.shape[:-1], self.linear.out_features) + ) # (64, 64, 384) + result = result.view(-1, self.linear.out_features) # (4096, 384) + result = result.index_copy( + 1, + torch.tensor(self.lora_ind, device=result.device), + x.reshape(-1, sum(self.qkv_shapes)), + ) # (4096, 256) + return result.view( + (*x.shape[:-1], self.linear.out_features) + ).transpose( + 0, 1 + ) # (64, 64, 384) + + def conv1d( + self, input: torch.Tensor, weight: torch.Tensor + ) -> torch.Tensor: + """An extension of the `torch.nn.functional.conv1d` function with a logic specific to grouped queries. + + If the number of heads is equal to the number of query groups - grouped queries are disabled + (see scheme in `lit_gpt/config.py:Config`). In this case the combined QKV matrix consists of equally sized + query, key and value parts, which means we can utilize `groups` argument from `conv1d`: with this argument the + input and weight matrices will be splitted in equally sized parts and applied separately (like having multiple + conv layers side by side). + + Otherwise QKV matrix consists of unequally sized parts and thus we have to split input and weight matrices manually, + apply each part of the weight matrix to the corresponding input's part and concatenate the result. + + Args: + input: input matrix of shape (B, C, T) + weight: weight matrix of shape (C_output, rank, 1). + "C_output" is defined as a sum of embedding sizes for each enabled LoRA layer (see init method of the class). + + Returns: + A tensor with a shape (B, C_output, T) + + """ + if self.n_head == self.n_query_groups: + return F.conv1d( + input, weight, groups=sum(self.enable_lora) + ) # (B, C_output, T) + + # Notation: + # ⚬ N: number of enabled LoRA layers (self.enable_lora) + # ⚬ C_output': embeddings size for each LoRA layer (not equal in size) + # ⚬ r: rank of all LoRA layers (equal in size) + + input_splitted = input.chunk( + sum(self.enable_lora), dim=1 + ) # N * (B, C // N, T) + weight_splitted = weight.split( + self.qkv_shapes + ) # N * (C_output', r, 1) + return torch.cat( + [F.conv1d(a, b) for a, b in zip(input_splitted, weight_splitted)], + dim=1, # (B, C_output', T) + ) # (B, C_output, T) + + def get_lora_AB(self) -> torch.Tensor: + """Return merged lora_A and lora_B matrices with the same shape as the pretrained weights.""" + # Let's assume that: + # ⚬ self.linear.weight.data: (384, 128) or (3 * embedding_size, embedding_size) + # ⚬ self.lora_A.data: (4, 128) + # ⚬ self.lora_B.data: (256, 2) + lora = self.conv1d( + self.lora_A.data.unsqueeze(0), # (4, 128) -> (1, 4, 128) + self.lora_B.data.unsqueeze(-1), # (256, 2) -> (256, 2, 1) + ).squeeze( + 0 + ) # (1, 4, 128) @ (256, 2, 1) -> (1, 256, 128) -> (256, 128) + return self.zero_pad( + lora * self.scaling + ) # (256, 128) after zero_pad (384, 128) + + def merge(self) -> None: + """Merges the LoRA weights into the full-rank weights (W = W + delta_W).""" + if self.r > 0 and any(self.enable_lora) and not self.merged: + super().merge() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Do the forward pass. + + If LoRA's weights are merged with pretrained ones then it's a simple matrix multiplication. + If not, then multiply pretrained weights with input, apply LoRA on input and do summation. + + Args: + x: input tensor of shape (batch_size, context_length, embedding_size) + + Returns: + Output tensor of shape (batch_size, context_length, 3 * embedding_size) + """ + + # Let's assume that: + # ⚬ x: (64, 64, 128) or (batch_size, context_length, embedding_size) + # ⚬ self.linear.weight: (384, 128) or (3 * embedding_size, embedding_size) + # ⚬ self.lora_A.data: (4, 128) + # ⚬ self.lora_B.data: (256, 2) + + # if weights are merged or LoRA is disabled (r <= 0 or all `enable_lora` are False) - it's only a regular nn.Linear forward pass; + # otherwise in addition do the forward pass with LoRA weights and add it's output to the output from pretrained weights + pretrained = self.linear(x) + if self.r == 0 or not any(self.enable_lora) or self.merged: + return pretrained + after_A = F.linear( + self.lora_dropout(x), self.lora_A + ) # (64, 64, 128) @ (4, 128) -> (64, 64, 4) + # For F.conv1d: + # ⚬ input: input tensor of shape (mini-batch, in_channels, iW) + # ⚬ weight: filters of shape (out_channels, in_channels/groups, kW) + after_B = self.conv1d( + after_A.transpose(-2, -1), # (64, 64, 4) -> (64, 4, 64) + self.lora_B.unsqueeze(-1), # (256, 2) -> (256, 2, 1) + ).transpose( + -2, -1 + ) # (64, 4, 64) @ (256, 2, 1) -> (64, 256, 64) -> (64, 64, 256) + lora = ( + self.zero_pad(after_B) * self.scaling + ) # (64, 64, 256) after zero_pad (64, 64, 384) + return pretrained + lora + + +def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None: + """Freeze all modules except LoRA's and depending on 'bias' value unfreezes bias weights. + + Args: + model: model with LoRA layers + bias: + ``"none"``: all bias weights will be frozen, + ``"lora_only"``: only bias weight for LoRA layers will be unfrozen, + ``"all"``: all bias weights will be unfrozen. + + Raises: + NotImplementedError: if `bias` not in ["none", "lora_only", "all"] + """ + # freeze all layers except LoRA's + for n, p in model.named_parameters(): + if "lora_" not in n: + p.requires_grad = False + + # depending on the `bias` value unfreeze bias weights + if bias == "none": + return + if bias == "all": + for n, p in model.named_parameters(): + if "bias" in n: + p.requires_grad = True + elif bias == "lora_only": + for m in model.modules(): + if ( + isinstance(m, LoRALayer) + and hasattr(m, "bias") + and m.bias is not None + ): + m.bias.requires_grad = True + else: + raise NotImplementedError + + +def lora_filter(key: str, value: Any) -> bool: + return "lora_" in key + + +@dataclass +class Config(BaseConfig): + """ + Args: + r: rank of the weight update matrices. To make sense of using LoRA the rank should be smaller than the rank of + the weights of the model. The rank can be as low as 1: https://arxiv.org/pdf/2106.09685.pdf (section 7.2) + alpha: alpha is needed for scaling updates as alpha/r + "This scaling helps to reduce the need to retune hyperparameters when we vary r" + https://arxiv.org/pdf/2106.09685.pdf (section 4.1) + dropout: dropout that is applied on the input in the LoRA branch (before multiplying by matrix A) + to_*: either apply LoRA to the specified weights or not + """ + + r: int = 0 + alpha: int = 1 + dropout: float = 0.0 + to_query: bool = False + to_key: bool = False + to_value: bool = False + to_projection: bool = False + to_mlp: bool = False + to_head: bool = False + + @property + def mlp_class(self) -> Type: + return getattr(lit_gpt.lora, self._mlp_class) + + +class GPT(BaseModel): + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + assert config.padded_vocab_size is not None + self.config = config + + self.lm_head = LoRALinear( + config.n_embd, + config.padded_vocab_size, + bias=config.lm_head_bias, + r=(config.r if config.to_head else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.padded_vocab_size, config.n_embd), + h=nn.ModuleList(Block(config) for _ in range(config.n_layer)), + ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), + ) + ) + self.max_seq_length = self.config.block_size + self.mask_cache: Optional[torch.Tensor] = None + + def forward( + self, + idx: torch.Tensor, + input_pos: Optional[torch.Tensor] = None, + lm_head_chunk_size: int = 0, + ) -> Union[torch.Tensor, List[torch.Tensor]]: + T = idx.size(1) + if self.max_seq_length < T: + raise ValueError( + f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}." + ) + + if input_pos is not None: # use the kv cache + cos = self.cos.index_select(0, input_pos) + sin = self.sin.index_select(0, input_pos) + if self.mask_cache is None: + raise TypeError("You need to call `gpt.set_kv_cache()`") + mask = self.mask_cache.index_select(2, input_pos) + else: + cos = self.cos[:T] + sin = self.sin[:T] + mask = None + + x = self.transformer.wte( + idx + ) # token embeddings of shape (b, t, n_embd) + for block in self.transformer.h: + x = block(x, cos, sin, mask, input_pos) + x = self.transformer.ln_f(x) + if lm_head_chunk_size > 0: + # chunk the lm head logits to reduce the peak memory used by autograd + return [ + self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1) + ] + return self.lm_head(x) # (B, T, vocab_size) + + @classmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + return cls(Config.from_name(name, **kwargs)) + + def _init_weights(self, module: nn.Module) -> None: + """Meant to be used with `gpt.apply(gpt._init_weights)`. Unused method left for completeness.""" + super()._init_weights(module) + if isinstance(module, LoRALinear): + module.reset_parameters() + + def _load_from_state_dict( + self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any + ) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "lm_head.weight": "lm_head.linear.weight", + "lm_head.bias": "lm_head.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class Block(BaseBlock): + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.attn = CausalSelfAttention(config) + if not config.shared_attention_norm: + self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.mlp = config.mlp_class(config) + + self.config = config + + +class CausalSelfAttention(BaseCausalSelfAttention): + def __init__(self, config: Config) -> None: + # Skip the parent class __init__ altogether and replace it to avoid + # useless allocations + nn.Module.__init__(self) + shape = (config.n_head + 2 * config.n_query_groups) * config.head_size + # key, query, value projections for all heads, but in a batch + self.attn = LoRAQKVLinear( + in_features=config.n_embd, + out_features=shape, + r=config.r, + lora_alpha=config.alpha, + lora_dropout=config.dropout, + enable_lora=(config.to_query, config.to_key, config.to_value), + bias=config.bias, + # for MQA/GQA support + n_head=config.n_head, + n_query_groups=config.n_query_groups, + ) + # output projection + # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` + self.proj = LoRALinear( + config.head_size * config.n_head, + config.n_embd, + bias=config.bias, + r=(config.r if config.to_projection else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + # disabled by default + self.kv_cache: Optional[KVCache] = None + + self.config = config + + def _load_from_state_dict( + self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any + ) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "attn.weight": "attn.linear.weight", + "attn.bias": "attn.linear.bias", + "proj.weight": "proj.linear.weight", + "proj.bias": "proj.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class GptNeoxMLP(lit_gpt.model.GptNeoxMLP): + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + self.fc = LoRALinear( + config.n_embd, + config.intermediate_size, + bias=config.bias, + r=(config.r if config.to_mlp else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + self.proj = LoRALinear( + config.intermediate_size, + config.n_embd, + bias=config.bias, + r=(config.r if config.to_mlp else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + + self.config = config + + def _load_from_state_dict( + self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any + ) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "fc.weight": "fc.linear.weight", + "fc.bias": "fc.linear.bias", + "proj.weight": "proj.linear.weight", + "proj.bias": "proj.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class LLaMAMLP(lit_gpt.model.LLaMAMLP): + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + self.fc_1 = LoRALinear( + config.n_embd, + config.intermediate_size, + bias=config.bias, + r=(config.r if config.to_mlp else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + self.fc_2 = LoRALinear( + config.n_embd, + config.intermediate_size, + bias=config.bias, + r=(config.r if config.to_mlp else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + self.proj = LoRALinear( + config.intermediate_size, + config.n_embd, + bias=config.bias, + r=(config.r if config.to_mlp else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + + def _load_from_state_dict( + self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any + ) -> None: + """For compatibility with base checkpoints.""" + mapping = { + "fc_1.weight": "fc_1.linear.weight", + "fc_1.bias": "fc_1.linear.bias", + "fc_2.weight": "fc_2.linear.weight", + "fc_2.bias": "fc_2.linear.bias", + "proj.weight": "proj.linear.weight", + "proj.bias": "proj.linear.bias", + } + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +class GemmaMLP(LLaMAMLP): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_fc_1 = self.fc_1(x) + x_fc_2 = self.fc_2(x) + x = torch.nn.functional.gelu(x_fc_1) * x_fc_2 + return self.proj(x) + + +class LLaMAMoE(lit_gpt.model.LLaMAMoE): + def __init__(self, config: Config) -> None: + nn.Module.__init__(self) + self.gate = LoRALinear( + config.n_embd, + config.n_expert, + bias=False, + r=(config.r if config.to_mlp else 0), + lora_alpha=config.alpha, + lora_dropout=config.dropout, + ) + self.experts = nn.ModuleList( + LLaMAMLP(config) for _ in range(config.n_expert) + ) + + self.config = config + + def _load_from_state_dict( + self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any + ) -> None: + """For compatibility with base checkpoints.""" + mapping = {"gate.weight": "gate.linear.weight"} + state_dict = map_old_state_dict_weights(state_dict, mapping, prefix) + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) + + +def merge_lora_weights(model: GPT) -> None: + """Merge LoRA weights into the full-rank weights to speed up inference.""" + for module in model.modules(): + if isinstance(module, LoRALinear): + module.merge() diff --git a/examples/llm_finetuning/lit_gpt/model.py b/examples/llm_finetuning/lit_gpt/model.py new file mode 100644 index 00000000000..1ff378fd419 --- /dev/null +++ b/examples/llm_finetuning/lit_gpt/model.py @@ -0,0 +1,501 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +"""Full definition of a decoder-only transformer-based language model, all of it in this single file. + +Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and +https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model. +""" + +import math +from typing import Any, Optional, Tuple + +import torch +import torch.nn as nn +from typing_extensions import Self + +from lit_gpt.config import Config + + +class GPT(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + assert config.padded_vocab_size is not None + self.config = config + + self.lm_head = nn.Linear( + config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias + ) + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.padded_vocab_size, config.n_embd), + h=nn.ModuleList(Block(config) for _ in range(config.n_layer)), + ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), + ) + ) + self.max_seq_length = self.config.block_size + self.mask_cache: Optional[torch.Tensor] = None + + @property + def max_seq_length(self) -> int: + return self._max_seq_length + + @max_seq_length.setter + def max_seq_length(self, value: int) -> None: + """ + When doing inference, the sequences used might be shorter than the model's context length. + This allows setting a smaller number to avoid allocating unused memory + """ + if value > self.config.block_size: + raise ValueError( + f"Cannot attend to {value}, block size is only {self.config.block_size}" + ) + self._max_seq_length = value + if not hasattr(self, "cos"): + # first call + cos, sin = self.rope_cache() + self.register_buffer("cos", cos, persistent=False) + self.register_buffer("sin", sin, persistent=False) + # override + elif value != self.cos.size(0): + self.cos, self.sin = self.rope_cache(device=self.cos.device) + # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know + # if the kv cache is expected + + def reset_parameters(self) -> None: + # Trigger resetting the rope-cache + self.cos, self.sin = self.rope_cache() + + def _init_weights(self, module: nn.Module) -> None: + """Meant to be used with `gpt.apply(gpt._init_weights)`.""" + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if module.bias is not None: + torch.nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + + def forward( + self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None + ) -> torch.Tensor: + T = idx.size(1) + if self.max_seq_length < T: + raise ValueError( + f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}." + ) + + if input_pos is not None: # use the kv cache + cos = self.cos.index_select(0, input_pos) + sin = self.sin.index_select(0, input_pos) + if self.mask_cache is None: + raise TypeError("You need to call `gpt.set_kv_cache()`") + mask = self.mask_cache.index_select(2, input_pos) + else: + cos = self.cos[:T] + sin = self.sin[:T] + mask = None + + x = self.transformer.wte( + idx + ) # token embeddings of shape (b, t, n_embd) + if self.config.scale_embeddings: + x = x * (self.config.n_embd**0.5) + + for block in self.transformer.h: + x = block(x, cos, sin, mask, input_pos) + x = self.transformer.ln_f(x) + return self.lm_head(x) # (b, t, vocab_size) + + @classmethod + def from_name(cls, name: str, **kwargs: Any) -> Self: + return cls(Config.from_name(name, **kwargs)) + + def rope_cache( + self, device: Optional[torch.device] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + return build_rope_cache( + seq_len=self.max_seq_length, + n_elem=self.config.rope_n_elem, + device=device, + condense_ratio=self.config.rope_condense_ratio, + base=self.config.rope_base, + ) + + def set_kv_cache( + self, + batch_size: int, + rope_cache_length: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + if rope_cache_length is None: + rope_cache_length = self.cos.size(-1) + max_seq_length = self.max_seq_length + + # initialize the kv cache for all blocks + for block in self.transformer.h: + block.attn.kv_cache = block.attn.build_kv_cache( + batch_size, max_seq_length, rope_cache_length, device, dtype + ) + + if ( + self.mask_cache is None + or self.mask_cache.size(3) != max_seq_length + ): + # passing `attn_mask` to SDPA disables the flash implementation. since we only need the mask + # for the kv-cache support (only during inference), we only create it in that situation + self.mask_cache = build_mask_cache(max_seq_length, device) + + def clear_kv_cache(self) -> None: + self.mask_cache = None + for block in self.transformer.h: + block.attn.kv_cache = None + + +class Block(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) + self.attn = CausalSelfAttention(config) + self.norm_2 = ( + None + if config.shared_attention_norm + else config.norm_class(config.n_embd, eps=config.norm_eps) + ) + self.mlp = config.mlp_class(config) + + self.config = config + + def forward( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + n_1 = self.norm_1(x) + h = self.attn(n_1, cos, sin, mask, input_pos) + if self.config.parallel_residual: + n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x) + x = self.mlp(n_2) + h + x + else: + if self.config.shared_attention_norm: + raise NotImplementedError( + "No checkpoint amongst the ones we support uses this configuration" + " (non-parallel residual and shared attention norm)." + ) + x = h + x + x = self.mlp(self.norm_2(x)) + x + return x + + +class CausalSelfAttention(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + shape = (config.n_head + 2 * config.n_query_groups) * config.head_size + # key, query, value projections for all heads, but in a batch + self.attn = nn.Linear(config.n_embd, shape, bias=config.bias) + # output projection + # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` + self.proj = nn.Linear( + config.head_size * config.n_head, config.n_embd, bias=config.bias + ) + # disabled by default + self.kv_cache: Optional[KVCache] = None + + self.config = config + + def forward( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + mask: Optional[torch.Tensor] = None, + input_pos: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + ( + B, + T, + C, + ) = ( + x.size() + ) # batch size, sequence length, embedding dimensionality (n_embd) + + qkv = self.attn(x) + + # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`) + q_per_kv = self.config.n_head // self.config.n_query_groups + total_qkv = ( + q_per_kv + 2 + ) # each group has 1+ queries, 1 key, and 1 value + qkv = qkv.view( + B, T, self.config.n_query_groups, total_qkv, self.config.head_size + ) + qkv = qkv.permute( + 0, 2, 3, 1, 4 + ) # (B, n_query_groups, total_qkv, T, hs) + + # split batched computation into three + q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) + + # maybe repeat k and v if for the non multi-head attention cases + # training: flash attention requires it + # inference: multi-query would require a full kv cache so avoid it to limit its memory usage + if self.config.n_query_groups != self.config.n_head and ( + input_pos is None or self.config.n_query_groups != 1 + ): + k = k.expand( + B, + self.config.n_query_groups, + q_per_kv, + T, + self.config.head_size, + ) + v = v.expand( + B, + self.config.n_query_groups, + q_per_kv, + T, + self.config.head_size, + ) + + q = q.reshape(B, -1, T, self.config.head_size) # (B, nh_q, T, hs) + k = k.reshape(B, -1, T, self.config.head_size) # (B, nh_k, T, hs) + v = v.reshape(B, -1, T, self.config.head_size) # (B, nh_v, T, hs) + + q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) + k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) + q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) + k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) + + if input_pos is not None: + if not isinstance(self.kv_cache, KVCache): + raise TypeError("You need to call `gpt.set_kv_cache()`") + k, v = self.kv_cache(input_pos, k, v) + + y = self.scaled_dot_product_attention(q, k, v, mask) + + y = y.reshape( + B, T, self.config.head_size * self.config.n_head + ) # re-assemble all head outputs side by side + + # output projection + return self.proj(y) + + def scaled_dot_product_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + scale = 1.0 / math.sqrt(self.config.head_size) + y = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=0.0, + scale=scale, + is_causal=mask is None, + ) + return y.transpose(1, 2) + + def build_kv_cache( + self, + batch_size: int, + max_seq_length: int, + rope_cache_length: Optional[int] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> "KVCache": + heads = 1 if self.config.n_query_groups == 1 else self.config.n_head + v_shape = (batch_size, heads, max_seq_length, self.config.head_size) + if rope_cache_length is None: + if self.config.rotary_percentage != 1.0: + raise TypeError( + "Please pass the `rope_cache_length=gpt.cos.size(-1)` value" + ) + k_shape = v_shape + else: + k_shape = ( + batch_size, + heads, + max_seq_length, + rope_cache_length + + self.config.head_size + - self.config.rope_n_elem, + ) + return KVCache(k_shape, v_shape, device=device, dtype=dtype) + + +class GptNeoxMLP(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.fc = nn.Linear( + config.n_embd, config.intermediate_size, bias=config.bias + ) + self.proj = nn.Linear( + config.intermediate_size, config.n_embd, bias=config.bias + ) + + self.config = config + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.fc(x) + x = torch.nn.functional.gelu( + x, approximate=self.config.gelu_approximate + ) + return self.proj(x) + + +class LLaMAMLP(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.fc_1 = nn.Linear( + config.n_embd, config.intermediate_size, bias=config.bias + ) + self.fc_2 = nn.Linear( + config.n_embd, config.intermediate_size, bias=config.bias + ) + self.proj = nn.Linear( + config.intermediate_size, config.n_embd, bias=config.bias + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_fc_1 = self.fc_1(x) + x_fc_2 = self.fc_2(x) + x = torch.nn.functional.silu(x_fc_1) * x_fc_2 + return self.proj(x) + + +class GemmaMLP(LLaMAMLP): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_fc_1 = self.fc_1(x) + x_fc_2 = self.fc_2(x) + x = torch.nn.functional.gelu(x_fc_1) * x_fc_2 + return self.proj(x) + + +class LLaMAMoE(nn.Module): + def __init__(self, config: Config) -> None: + super().__init__() + self.gate = nn.Linear(config.n_embd, config.n_expert, bias=False) + self.experts = nn.ModuleList( + LLaMAMLP(config) for _ in range(config.n_expert) + ) + + self.config = config + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Derived from: https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219 + See also figure 1 in https://arxiv.org/abs/2211.15841 + """ + ( + B, + T, + C, + ) = ( + x.size() + ) # batch size, sequence length, embedding dimensionality (n_embd) + x = x.view(-1, C) # (B*T, C) + router = self.gate(x) # (B*T, n_expert) + probs, indices = torch.topk( + router, self.config.n_expert_per_token + ) # (B*T, n_expert_per_token) + probs = probs.softmax(dim=1, dtype=torch.float).to(dtype=x.dtype) + masks = indices.unsqueeze(-1) == torch.arange( + self.config.n_expert, device=x.device + ) + masks = masks.permute(2, 0, 1) # (n_expert, B*T, n_expert_per_token) + y = torch.zeros_like(x) # (B*T, C) + for mask, expert in zip(masks, self.experts): + token_idx, expert_idx = torch.where(mask) + y[token_idx] += probs[token_idx, expert_idx, None] * expert( + x[token_idx] + ) + return y.view(B, T, C) + + +def build_rope_cache( + seq_len: int, + n_elem: int, + device: Optional[torch.device] = None, + base: int = 10000, + condense_ratio: int = 1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / ( + base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem) + ) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, device=device) / condense_ratio + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).repeat(1, 2) + + return torch.cos(idx_theta), torch.sin(idx_theta) + + +def apply_rope( + x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor +) -> torch.Tensor: + head_size = x.size(-1) + x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) + x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) + rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) + roped = (x * cos) + (rotated * sin) + return roped.to(dtype=x.dtype) + + +class KVCache(nn.Module): + def __init__( + self, + k_shape: Tuple[int, int, int, int], + v_shape: Tuple[int, int, int, int], + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__() + self.register_buffer( + "k", + torch.zeros(k_shape, device=device, dtype=dtype), + persistent=False, + ) + self.register_buffer( + "v", + torch.zeros(v_shape, device=device, dtype=dtype), + persistent=False, + ) + + def forward( + self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # move the buffer to the activation dtype for when AMP is used + self.k = self.k.to(k.dtype) + self.v = self.v.to(v.dtype) + # update the cache + k = self.k.index_copy_(2, input_pos, k) + v = self.v.index_copy_(2, input_pos, v) + return k, v + + def reset_parameters(self) -> None: + torch.nn.init.zeros_(self.k) + torch.nn.init.zeros_(self.v) + + +def build_mask_cache( + max_seq_length: int, device: Optional[torch.device] = None +) -> torch.Tensor: + ones = torch.ones( + (max_seq_length, max_seq_length), device=device, dtype=torch.bool + ) + return torch.tril(ones).unsqueeze(0).unsqueeze(0) diff --git a/examples/llm_finetuning/lit_gpt/packed_dataset.py b/examples/llm_finetuning/lit_gpt/packed_dataset.py new file mode 100644 index 00000000000..a183d4c2423 --- /dev/null +++ b/examples/llm_finetuning/lit_gpt/packed_dataset.py @@ -0,0 +1,274 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +# Very loosely inspired by indexed_dataset in Fairseq, Megatron +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py + + +import os +import random +import struct + +import numpy as np +import torch +from torch.utils.data import IterableDataset, get_worker_info + +dtypes = { + 1: np.uint8, + 2: np.int8, + 3: np.int16, + 4: np.int32, + 5: np.int64, + 6: np.float32, + 7: np.float64, + 8: np.uint16, +} + + +def code(dtype): + for k in dtypes: + if dtypes[k] == dtype: + return k + raise ValueError(dtype) + + +HDR_MAGIC = b"LITPKDS" +HDR_SIZE = 24 # bytes + + +class PackedDataset(IterableDataset): + def __init__( + self, + filenames, + n_chunks, + block_size, + seed=12345, + shuffle=True, + wrap=False, + num_processes=1, + process_rank=0, + ): + self._filenames = filenames + self._n_chunks = n_chunks + self._block_size = block_size + self._seed = seed + self._shuffle = shuffle + self._wrap = wrap + self._num_processes = num_processes + self._process_rank = process_rank + + def __iter__(self): + worker_info = get_worker_info() + num_workers = worker_info.num_workers if worker_info is not None else 1 + worker_id = worker_info.id if worker_info is not None else 0 + num_shards = num_workers * self._num_processes + shard_id = self._process_rank * num_workers + worker_id + + max_num_files = len(self._filenames) // num_shards * num_shards + filenames = self._filenames[shard_id:max_num_files:num_shards] + + return PackedDatasetIterator( + filenames=filenames, + n_chunks=self._n_chunks, + block_size=self._block_size, + seed=self._seed, + shuffle=self._shuffle, + wrap=self._wrap, + ) + + +class PackedDatasetBuilder(object): + def __init__( + self, + outdir, + prefix, + chunk_size, + sep_token, + dtype="auto", + vocab_size=None, + ): + if dtype == "auto": + if vocab_size is None: + raise ValueError("vocab_size cannot be None when dtype='auto'") + if vocab_size is not None and vocab_size < 65500: + self._dtype = np.uint16 + else: + self._dtype = np.int32 + else: + self._dtype = dtype + self._counter = 0 + self._chunk_size = chunk_size + self._outdir = outdir + self._prefix = prefix + self._sep_token = sep_token + self._arr = np.zeros(self._chunk_size, dtype=self._dtype) + self._arr.fill(self._sep_token) + self._idx = 0 + self._version = 1 + self._filenames = [] + + def _write_chunk(self): + filename = f"{self._prefix}_{self._counter:010d}.bin" + filename = os.path.join(self._outdir, filename) + + with open(filename, "wb") as f: + f.write(HDR_MAGIC) + f.write(struct.pack(" self._chunk_size: + part_len = self._chunk_size - self._idx + self._arr[self._idx : self._idx + part_len] = arr[:part_len] + self._write_chunk() + arr = arr[part_len:] + + arr_len = arr.shape[0] + self._arr[self._idx : self._idx + arr_len] = arr + self._idx += arr_len + + def write_reminder(self): + self._write_chunk() + + +class PackedDatasetIterator: + def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap): + self._seed = seed + self._shuffle = shuffle + self._rng = np.random.default_rng(seed) if shuffle else None + self._block_idxs = None + + self._wrap = wrap + + # TODO: instead of filenames, we could have a single text stream + # (or text file) with the sequence of all files to be + # fetched/loaded. + self._filenames = filenames + self._file_idx = 0 + + self._n_chunks = n_chunks + + self._dtype = None + self._block_size = block_size + self._n_blocks = None + + self._mmaps = [] + self._buffers = [] + + self._block_idxs = [] + self._curr_idx = 0 + + self._load_n_chunks() + + def _read_header(self, path): + with open(path, "rb") as f: + magic = f.read(len(HDR_MAGIC)) + assert magic == HDR_MAGIC, "File doesn't match expected format." + version = struct.unpack(" len(self._filenames[self._file_idx :]): + if not self._wrap: + raise StopIteration + self._file_idx = 0 + + for i in range(self._n_chunks): + filename = self._filenames[self._file_idx + i] + if self._dtype is None: + self._dtype, self._chunk_size = self._read_header(filename) + self._n_blocks = self._chunk_size // self._block_size + # TODO: check header matches with previous files + mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE) + self._mmaps.append(mmap) + self._buffers.append(memoryview(mmap)) + + self._file_idx += self._n_chunks + n_all_blocks = self._n_chunks * self._n_blocks + + self._block_idxs = ( + self._rng.permutation(n_all_blocks) + if self._shuffle + else range(n_all_blocks) + ) + + self._curr_idx = 0 + + def __del__(self): + self._close_mmaps() + del self._mmaps + del self._buffers + + def __iter__(self): + return self + + def __next__(self): + if self._curr_idx >= len(self._block_idxs): + self._load_n_chunks() + # TODO: trigger fetching next next n_chunks if remote + block_idx = self._block_idxs[self._curr_idx] + chunk_id = block_idx // self._n_blocks + buffer = self._buffers[chunk_id] + elem_id = (block_idx % self._n_blocks) * self._block_size + offset = np.dtype(self._dtype).itemsize * elem_id + arr = np.frombuffer( + buffer, dtype=self._dtype, count=self._block_size, offset=offset + ) + self._curr_idx += 1 + return torch.from_numpy(arr.astype(np.int64)) + + +class CombinedDataset(IterableDataset): + def __init__(self, datasets, seed, weights=None): + self._seed = seed + self._datasets = datasets + self._weights = weights + n_datasets = len(datasets) + if weights is None: + self._weights = [1 / n_datasets] * n_datasets + else: + self._weights = [w / sum(weights) for w in weights] + + def __iter__(self): + return CombinedDatasetIterator( + self._datasets, self._seed, self._weights + ) + + +class CombinedDatasetIterator: + def __init__(self, datasets, seed, weights): + self._datasets = [iter(el) for el in datasets] + self._weights = weights + self._rng = random.Random(seed) + + def __next__(self): + (dataset,) = self._rng.choices( + self._datasets, weights=self._weights, k=1 + ) + return next(dataset) diff --git a/examples/llm_finetuning/lit_gpt/rmsnorm.py b/examples/llm_finetuning/lit_gpt/rmsnorm.py new file mode 100644 index 00000000000..108288128f7 --- /dev/null +++ b/examples/llm_finetuning/lit_gpt/rmsnorm.py @@ -0,0 +1,40 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import torch + + +class RMSNorm(torch.nn.Module): + """Root Mean Square Layer Normalization. + + Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License: + https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE. + """ + + def __init__( + self, + size: int, + dim: int = -1, + eps: float = 1e-6, + add_unit_offset: bool = False, + ) -> None: + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(size)) + self.eps = eps + self.dim = dim + self.add_unit_offset = add_unit_offset + + def forward(self, x: torch.Tensor) -> torch.Tensor: + dtype = x.dtype + x = x.float() + # NOTE: the original RMSNorm paper implementation is not equivalent + norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) + x_normed = x * torch.rsqrt(norm_x + self.eps) + x_normed = x_normed.to(dtype=dtype) + if self.add_unit_offset: + # Gemma model requires a unit offset + # https://github.com/google/gemma_pytorch/blob/main/gemma/model.py#L176 + return x_normed * (1 + self.weight) + return x_normed * self.weight + + def reset_parameters(self) -> None: + torch.nn.init.ones_(self.weight) diff --git a/examples/llm_finetuning/lit_gpt/tokenizer.py b/examples/llm_finetuning/lit_gpt/tokenizer.py new file mode 100644 index 00000000000..f2832ce61c2 --- /dev/null +++ b/examples/llm_finetuning/lit_gpt/tokenizer.py @@ -0,0 +1,136 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import json +from pathlib import Path +from typing import Optional, Union + +import torch + + +class Tokenizer: + def __init__(self, checkpoint_dir: Union[Path, str]) -> None: + checkpoint_dir = Path(checkpoint_dir) + if not checkpoint_dir.exists(): + raise NotADirectoryError( + f"The checkpoint directory does not exist: {str(checkpoint_dir)}" + ) + + self.use_bos = self.check_if_bos_token_used(checkpoint_dir) + self.bos_id = None + self.eos_id = None + + # some checkpoints have both files, `.model` takes precedence + if (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file(): + from sentencepiece import SentencePieceProcessor + + self.processor = SentencePieceProcessor( + model_file=str(vocabulary_path) + ) + self.backend = "sentencepiece" + self.bos_id = self.processor.bos_id() + self.eos_id = self.processor.eos_id() + + elif (vocabulary_path := checkpoint_dir / "tokenizer.json").is_file(): + from tokenizers import Tokenizer as HFTokenizer + + self.processor = HFTokenizer.from_file(str(vocabulary_path)) + self.backend = "huggingface" + + if ( + special_tokens_path := checkpoint_dir / "tokenizer_config.json" + ).is_file(): + with open(special_tokens_path) as fp: + config = json.load(fp) + bos_token = config.get("bos_token") + self.bos_id = ( + self.token_to_id(bos_token) + if bos_token is not None + else None + ) + eos_token = config.get("eos_token") + self.eos_id = ( + self.token_to_id(eos_token) + if eos_token is not None + else None + ) + if ( + special_tokens_path := checkpoint_dir + / "generation_config.json" + ).is_file(): + with open(special_tokens_path) as fp: + config = json.load(fp) + if self.bos_id is None: + self.bos_id = config.get("bos_token_id") + if self.eos_id is None: + self.eos_id = config.get("eos_token_id") + else: + raise NotImplementedError + + @property + def vocab_size(self) -> int: + if self.backend == "huggingface": + return self.processor.get_vocab_size(with_added_tokens=False) + if self.backend == "sentencepiece": + return self.processor.vocab_size() + raise RuntimeError + + def token_to_id(self, token: str) -> int: + if self.backend == "huggingface": + id_ = self.processor.token_to_id(token) + elif self.backend == "sentencepiece": + id_ = self.processor.piece_to_id(token) + else: + raise RuntimeError + if id_ is None: + raise ValueError(f"token {token!r} not found in the collection.") + return id_ + + def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool: + if not ( + tokenizer_config_path := checkpoint_dir / "tokenizer_config.json" + ).is_file(): + return False + with open(tokenizer_config_path) as fp: + config = json.load(fp) + if any( + config.get(check, False) + for check in ("add_bos_token", "add_prefix_space") + ): + return True + # for examples that also use the Llama tokenizer, but do not have or set add_bos_token to True. + # ex: https://huggingface.co/stabilityai/StableBeluga2/blob/main/tokenizer_config.json#L2 + return ( + config.get("add_bos_token") is None + and config.get("tokenizer_class") == "LlamaTokenizer" + ) + + def encode( + self, + string: str, + device: Optional[torch.device] = None, + bos: Optional[bool] = None, + eos: bool = False, + max_length: int = -1, + ) -> torch.Tensor: + if self.backend == "huggingface": + tokens = self.processor.encode(string).ids + elif self.backend == "sentencepiece": + tokens = self.processor.encode(string) + else: + raise RuntimeError + if bos or (bos is None and self.use_bos): + bos_id = self.bos_id + if bos_id is None: + raise NotImplementedError( + "This tokenizer does not have a defined a bos token" + ) + tokens = [bos_id] + tokens + if eos: + tokens = tokens + [self.eos_id] + if max_length > 0: + tokens = tokens[:max_length] + return torch.tensor(tokens, dtype=torch.int, device=device) + + def decode(self, tensor: torch.Tensor) -> str: + tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist() + return self.processor.decode(tokens) diff --git a/examples/llm_finetuning/lit_gpt/utils.py b/examples/llm_finetuning/lit_gpt/utils.py new file mode 100644 index 00000000000..ba4706ff473 --- /dev/null +++ b/examples/llm_finetuning/lit_gpt/utils.py @@ -0,0 +1,477 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +"""Utility functions for training and inference.""" + +import math +import pickle +import sys +from io import BytesIO +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Mapping, + Optional, + TypeVar, + Union, +) + +import lightning as L +import torch +import torch.nn as nn +import torch.utils._device +from lightning.fabric.strategies import FSDPStrategy +from lightning.fabric.utilities.load import _lazy_load as lazy_load +from torch.serialization import normalize_storage_type +from typing_extensions import Self + +if TYPE_CHECKING: + from lit_gpt import GPT + + +def find_multiple(n: int, k: int) -> int: + assert k > 0 + if n % k == 0: + return n + return n + k - (n % k) + + +def num_parameters( + module: nn.Module, requires_grad: Optional[bool] = None +) -> int: + total = 0 + for p in module.parameters(): + if requires_grad is None or p.requires_grad == requires_grad: + if hasattr(p, "quant_state"): + # bitsandbytes 4bit layer support + total += math.prod(p.quant_state[1]) + else: + total += p.numel() + return total + + +def check_valid_checkpoint_dir(checkpoint_dir: Path) -> None: + files = { + "lit_model.pth": (checkpoint_dir / "lit_model.pth").is_file(), + "lit_config.json": (checkpoint_dir / "lit_config.json").is_file(), + "tokenizer.json OR tokenizer.model": ( + checkpoint_dir / "tokenizer.json" + ).is_file() + or (checkpoint_dir / "tokenizer.model").is_file(), + "tokenizer_config.json": ( + checkpoint_dir / "tokenizer_config.json" + ).is_file(), + } + if checkpoint_dir.is_dir(): + if all(files.values()): + # we're good + return + problem = f" is missing the files: {[f for f, exists in files.items() if not exists]!r}" + else: + problem = " is not a checkpoint directory" + + # list locally available checkpoints + available = list(Path("checkpoints").glob("*/*")) + if available: + options = "\n --checkpoint_dir ".join( + [""] + [repr(str(p.resolve())) for p in available] + ) + extra = f"\nYou have downloaded locally:{options}\n" + else: + extra = "" + + error_message = ( + f"--checkpoint_dir {str(checkpoint_dir.absolute())!r}{problem}." + "\nFind download instructions at https://github.com/Lightning-AI/lit-gpt/blob/main/tutorials\n" + f"{extra}\nSee all download options by running:\n python scripts/download.py" + ) + print(error_message, file=sys.stderr) + raise SystemExit(1) + + +class SavingProxyForStorage: + def __init__(self, obj, saver, protocol_version=5): + self.protocol_version = protocol_version + self.saver = saver + if not ( + isinstance(obj, torch.storage.TypedStorage) + or torch.is_storage(obj) + ): + raise TypeError(f"expected storage, not {type(obj)}") + + # this logic is taken from PyTorch 2.0+ torch/serialization.py + if isinstance(obj, torch.storage.TypedStorage): + # PT upstream wants to deprecate this eventually... + storage = obj._untyped_storage + storage_type_str = obj._pickle_storage_type() + storage_type = getattr(torch, storage_type_str) + storage_numel = obj._size() + else: + storage = obj + storage_type = normalize_storage_type(type(obj)) + storage_numel = storage.nbytes() + + storage_key = saver._write_storage_and_return_key(storage) + location = torch.serialization.location_tag(storage) + + self.storage_info = ( + "storage", + storage_type, + storage_key, + location, + storage_numel, + ) + + def __reduce_ex__(self, protocol_version): + assert False, "this should be handled with out of band" + + +class SavingProxyForTensor: + def __init__(self, tensor, saver, protocol_version=5): + self.protocol_version = protocol_version + self.reduce_ret_fn, reduce_args = tensor.__reduce_ex__( + protocol_version + ) + if reduce_args[0] == torch._utils._rebuild_tensor_v2: + # for Tensors with Python attributes + (a0, a1, (storage, *a2_other), *other_reduce_args) = reduce_args + assert isinstance( + storage, torch.storage.TypedStorage + ), "Please check for updates" + storage_proxy = SavingProxyForStorage( + storage, saver, protocol_version=protocol_version + ) + self.reduce_args = ( + a0, + a1, + (storage_proxy, *a2_other), + *other_reduce_args, + ) + else: + (storage, *other_reduce_args) = reduce_args + assert isinstance( + storage, torch.storage.TypedStorage + ), "Please check for updates" + storage_proxy = SavingProxyForStorage( + storage, saver, protocol_version=protocol_version + ) + self.reduce_args = (storage_proxy, *other_reduce_args) + + def __reduce_ex__(self, protocol_version): + if protocol_version != self.protocol_version: + raise RuntimeError( + f"Unexpected protocol version: expected {self.protocol_version}, got {protocol_version}" + ) + return self.reduce_ret_fn, self.reduce_args + + +class IncrementalPyTorchPickler(pickle.Pickler): + def __init__(self, saver, *args, **kwargs): + super().__init__(*args, **kwargs) + self.storage_dtypes = {} + self.saver = saver + self.id_map = {} + + # this logic is taken from PyTorch 2.0+ torch/serialization.py + def persistent_id(self, obj): + # FIXME: the docs say that persistent_id should only return a string + # but torch store returns tuples. This works only in the binary protocol + # see + # https://docs.python.org/2/library/pickle.html#pickling-and-unpickling-external-objects + # https://github.com/python/cpython/blob/master/Lib/pickle.py#L527-L537 + if isinstance(obj, SavingProxyForStorage): + return obj.storage_info + + if isinstance(obj, torch.storage.TypedStorage) or torch.is_storage( + obj + ): + if isinstance(obj, torch.storage.TypedStorage): + # TODO: Once we decide to break serialization FC, this case + # can be deleted + storage = obj._untyped_storage + storage_dtype = obj.dtype + storage_type_str = obj._pickle_storage_type() + storage_type = getattr(torch, storage_type_str) + storage_numel = obj._size() + + else: + storage = obj + storage_dtype = torch.uint8 + storage_type = normalize_storage_type(type(obj)) + storage_numel = storage.nbytes() + + # If storage is allocated, ensure that any other saved storages + # pointing to the same data all have the same dtype. If storage is + # not allocated, don't perform this check + if storage.data_ptr() != 0: + if storage.data_ptr() in self.storage_dtypes: + if ( + storage_dtype + != self.storage_dtypes[storage.data_ptr()] + ): + raise RuntimeError( + "Cannot save multiple tensors or storages that view the same data as different types" + ) + else: + self.storage_dtypes[storage.data_ptr()] = storage_dtype + + storage_key = self.id_map.get(storage._cdata) + if storage_key is None: + storage_key = self.saver._write_storage_and_return_key(storage) + self.id_map[storage._cdata] = storage_key + location = torch.serialization.location_tag(storage) + + return ( + "storage", + storage_type, + storage_key, + location, + storage_numel, + ) + + return None + + +class incremental_save: + def __init__(self, name): + self.name = name + self.zipfile = torch._C.PyTorchFileWriter(str(name)) + self.has_saved = False + self.next_key = 0 + + def __enter__(self): + return self + + def store_early(self, tensor): + if isinstance(tensor, torch.Tensor): + return SavingProxyForTensor(tensor, self) + raise TypeError(f"can only store tensors early, not {type(tensor)}") + + def save(self, obj): + if self.has_saved: + raise RuntimeError("have already saved") + # Write the pickle data for `obj` + data_buf = BytesIO() + pickler = IncrementalPyTorchPickler(self, data_buf, protocol=5) + pickler.dump(obj) + data_value = data_buf.getvalue() + self.zipfile.write_record("data.pkl", data_value, len(data_value)) + self.has_saved = True + + def _write_storage_and_return_key(self, storage): + if self.has_saved: + raise RuntimeError("have already saved") + key = self.next_key + self.next_key += 1 + name = f"data/{key}" + if storage.device.type != "cpu": + storage = storage.cpu() + num_bytes = storage.nbytes() + self.zipfile.write_record(name, storage.data_ptr(), num_bytes) + return key + + def __exit__(self, type, value, traceback): + self.zipfile.write_end_of_file() + + +T = TypeVar("T") + + +def chunked_cross_entropy( + logits: Union[torch.Tensor, List[torch.Tensor]], + targets: torch.Tensor, + chunk_size: int = 128, + ignore_index: int = -1, +) -> torch.Tensor: + # with large max_sequence_lengths, the beginning of `backward` allocates a large memory chunk which can dominate + # the memory usage in fine-tuning settings with low number of parameters. + # as a workaround hack, the cross entropy computation is chunked to force it to deallocate on the go, reducing + # the memory spike's magnitude + + # lm_head was chunked (we are fine-tuning) + if isinstance(logits, list): + # don't want to chunk cross entropy + if chunk_size == 0: + logits = torch.cat(logits, dim=1) + logits = logits.reshape(-1, logits.size(-1)) + targets = targets.reshape(-1) + return torch.nn.functional.cross_entropy( + logits, targets, ignore_index=ignore_index + ) + + # chunk cross entropy + logit_chunks = [ + logit_chunk.reshape(-1, logit_chunk.size(-1)) + for logit_chunk in logits + ] + target_chunks = [ + target_chunk.reshape(-1) + for target_chunk in targets.split(logits[0].size(1), dim=1) + ] + loss_chunks = [ + torch.nn.functional.cross_entropy( + logit_chunk, + target_chunk, + ignore_index=ignore_index, + reduction="none", + ) + for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) + ] + non_masked_elems = (targets != ignore_index).sum() + return torch.cat(loss_chunks).sum() / max(1, non_masked_elems) + + # no chunking at all + logits = logits.reshape(-1, logits.size(-1)) + targets = targets.reshape(-1) + if chunk_size == 0: + return torch.nn.functional.cross_entropy( + logits, targets, ignore_index=ignore_index + ) + + # lm_head wasn't chunked, chunk cross entropy + logit_chunks = logits.split(chunk_size) + target_chunks = targets.split(chunk_size) + loss_chunks = [ + torch.nn.functional.cross_entropy( + logit_chunk, + target_chunk, + ignore_index=ignore_index, + reduction="none", + ) + for logit_chunk, target_chunk in zip(logit_chunks, target_chunks) + ] + non_masked_elems = (targets != ignore_index).sum() + return torch.cat(loss_chunks).sum() / max(1, non_masked_elems) + + +def map_old_state_dict_weights( + state_dict: Dict, mapping: Mapping, prefix: str +) -> Dict: + for checkpoint_name, attribute_name in mapping.items(): + full_checkpoint_name = prefix + checkpoint_name + if full_checkpoint_name in state_dict: + full_attribute_name = prefix + attribute_name + state_dict[full_attribute_name] = state_dict.pop( + full_checkpoint_name + ) + return state_dict + + +def get_default_supported_precision(training: bool) -> str: + """Return default precision that is supported by the hardware: either `bf16` or `16`. + + Args: + training: `-mixed` or `-true` version of the precision to use + + Returns: + default precision that is suitable for the task and is supported by the hardware + """ + from lightning.fabric.accelerators import MPSAccelerator + + if MPSAccelerator.is_available() or ( + torch.cuda.is_available() and not torch.cuda.is_bf16_supported() + ): + return "16-mixed" if training else "16-true" + return "bf16-mixed" if training else "bf16-true" + + +def load_checkpoint( + fabric: L.Fabric, + model: nn.Module, + checkpoint_path: Path, + strict: bool = True, +) -> None: + if isinstance(fabric.strategy, FSDPStrategy): + fabric.load_raw(checkpoint_path, model, strict=strict) + else: + state_dict = lazy_load(checkpoint_path) + state_dict = state_dict.get("model", state_dict) + model.load_state_dict(state_dict, strict=strict) + + +def flops_per_param( + max_seq_length: int, n_layer: int, n_embd: int, n_params: int +) -> int: + flops_per_token = ( + 2 * n_params + ) # each parameter is used for a MAC (2 FLOPS) per network operation + # this assumes that all samples have a fixed length equal to the block size + # which is most likely false during finetuning + flops_per_seq = flops_per_token * max_seq_length + attn_flops_per_seq = n_layer * 2 * 2 * (n_embd * (max_seq_length**2)) + return flops_per_seq + attn_flops_per_seq + + +def estimate_flops(model: "GPT", training: bool) -> int: + """Measures estimated FLOPs for MFU. + + Refs: + * https://ar5iv.labs.arxiv.org/html/2205.05198#A1 + * https://ar5iv.labs.arxiv.org/html/2204.02311#A2 + """ + # using all parameters for this is a naive over estimation because not all model parameters actually contribute to + # this FLOP computation (e.g. embedding, norm). For this reason, the result will be higher by a fixed percentage + # (~10%) compared to the measured FLOPs, making those lower but more realistic. + # For a proper estimate, this needs a more fine-grained calculation as in Appendix A of the paper. + n_trainable_params = num_parameters(model, requires_grad=True) + trainable_flops = flops_per_param( + model.max_seq_length, + model.config.n_layer, + model.config.n_embd, + n_trainable_params, + ) + # forward + backward + gradients (assumes no gradient accumulation) + ops_per_step = 3 if training else 1 + n_frozen_params = num_parameters(model, requires_grad=False) + frozen_flops = flops_per_param( + model.max_seq_length, + model.config.n_layer, + model.config.n_embd, + n_frozen_params, + ) + # forward + backward + frozen_ops_per_step = 2 if training else 1 + return ops_per_step * trainable_flops + frozen_ops_per_step * frozen_flops + + +class CycleIterator: + """An iterator that cycles through an iterable indefinitely. + + Example: + >>> iterator = CycleIterator([1, 2, 3]) + >>> [next(iterator) for _ in range(5)] + [1, 2, 3, 1, 2] + + Note: + Unlike ``itertools.cycle``, this iterator does not cache the values of the iterable. + """ + + def __init__(self, iterable: Iterable) -> None: + self.iterable = iterable + self.epoch = 0 + self._iterator = None + + def __next__(self) -> Any: + if self._iterator is None: + self._iterator = iter(self.iterable) + try: + return next(self._iterator) + except StopIteration: + self._iterator = iter(self.iterable) + self.epoch += 1 + return next(self._iterator) + + def __iter__(self) -> Self: + return self + + +def CLI(*args: Any, **kwargs: Any) -> Any: + from jsonargparse import CLI, set_docstring_parse_options + + set_docstring_parse_options(attribute_docstrings=True) + + kwargs.setdefault("as_positional", False) + return CLI(*args, **kwargs) diff --git a/examples/llm_finetuning/materializers/__init__.py b/examples/llm_finetuning/materializers/__init__.py new file mode 100644 index 00000000000..757bd8418a5 --- /dev/null +++ b/examples/llm_finetuning/materializers/__init__.py @@ -0,0 +1,16 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/examples/llm_finetuning/materializers/directory_materializer.py b/examples/llm_finetuning/materializers/directory_materializer.py new file mode 100644 index 00000000000..4adc7b4a10a --- /dev/null +++ b/examples/llm_finetuning/materializers/directory_materializer.py @@ -0,0 +1,71 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +from pathlib import Path +from tempfile import mkdtemp +from typing import Any, ClassVar, Tuple, Type + +from zenml.enums import ArtifactType +from zenml.io import fileio +from zenml.materializers.base_materializer import BaseMaterializer + + +class DirectoryMaterializer(BaseMaterializer): + """Materializer to store local directories in the artifact store.""" + + ASSOCIATED_TYPES: ClassVar[Tuple[Type[Any], ...]] = (Path,) + ASSOCIATED_ARTIFACT_TYPE: ClassVar[ArtifactType] = ArtifactType.DATA + + def load(self, data_type: Type[Any]) -> Any: + """Copy the artifact files to a local temp directory. + + Args: + data_type: Unused. + + Returns: + Path to the local directory that contains the artifact files. + """ + directory = mkdtemp(prefix="zenml-artifact") + self._copy_directory(src=self.uri, dst=directory) + return Path(directory) + + def save(self, data: Any) -> None: + """Store the directory in the artifact store. + + Args: + data: Path to a local directory to store. + """ + assert isinstance(data, Path) + self._copy_directory(src=str(data), dst=self.uri) + + @staticmethod + def _copy_directory(src: str, dst: str) -> None: + """Recursively copy a directory. + + Args: + src: The directory to copy. + dst: Where to copy the directory to. + """ + for src_dir, _, files in fileio.walk(src): + dst_dir = os.path.join(dst, os.path.relpath(src_dir, src)) + fileio.makedirs(dst_dir) + + for file in files: + src_file = os.path.join(src_dir, file) + dst_file = os.path.join(dst_dir, file) + fileio.copy(src_file, dst_file) diff --git a/examples/llm_finetuning/pipelines/__init__.py b/examples/llm_finetuning/pipelines/__init__.py new file mode 100644 index 00000000000..2d7c5390a7d --- /dev/null +++ b/examples/llm_finetuning/pipelines/__init__.py @@ -0,0 +1,21 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pipelines.evaluate import llm_lora_evaluation +from pipelines.feature_engineering import llm_lora_feature_engineering +from pipelines.finetuning import llm_lora_finetuning +from pipelines.merge import llm_lora_merging diff --git a/examples/llm_finetuning/pipelines/evaluate.py b/examples/llm_finetuning/pipelines/evaluate.py new file mode 100644 index 00000000000..41feb5bfa72 --- /dev/null +++ b/examples/llm_finetuning/pipelines/evaluate.py @@ -0,0 +1,33 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from steps import evaluate + +from zenml import pipeline +from zenml.config import DockerSettings + + +@pipeline( + settings={ + "docker": DockerSettings( + apt_packages=["git"], requirements="requirements.txt" + ) + } +) +def llm_lora_evaluation() -> None: + """Pipeline to evaluate a LoRA fine-tuned LLM.""" + evaluate() diff --git a/examples/llm_finetuning/pipelines/feature_engineering.py b/examples/llm_finetuning/pipelines/feature_engineering.py new file mode 100644 index 00000000000..6630bd1fb86 --- /dev/null +++ b/examples/llm_finetuning/pipelines/feature_engineering.py @@ -0,0 +1,33 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from steps import feature_engineering + +from zenml import pipeline +from zenml.config import DockerSettings + + +@pipeline( + settings={ + "docker": DockerSettings( + apt_packages=["git"], requirements="requirements.txt" + ) + } +) +def llm_lora_feature_engineering() -> None: + """Feature engineering pipeline.""" + feature_engineering() diff --git a/examples/llm_finetuning/pipelines/finetuning.py b/examples/llm_finetuning/pipelines/finetuning.py new file mode 100644 index 00000000000..faa7d185fda --- /dev/null +++ b/examples/llm_finetuning/pipelines/finetuning.py @@ -0,0 +1,44 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional + +from steps import finetune + +from zenml import get_pipeline_context, pipeline +from zenml.config import DockerSettings + + +@pipeline( + settings={ + "docker": DockerSettings( + apt_packages=["git"], requirements="requirements.txt" + ) + } +) +def llm_lora_finetuning( + dataset_artifact_name: Optional[str] = None, + dataset_artifact_version: Optional[str] = None, +) -> None: + """Pipeline to finetune LLMs using LoRA.""" + dataset_directory = None + if dataset_artifact_name: + dataset_directory = get_pipeline_context().model.get_artifact( + name=dataset_artifact_name, version=dataset_artifact_version + ) + + finetune(dataset_directory=dataset_directory) diff --git a/examples/llm_finetuning/pipelines/merge.py b/examples/llm_finetuning/pipelines/merge.py new file mode 100644 index 00000000000..20c1c1f36f1 --- /dev/null +++ b/examples/llm_finetuning/pipelines/merge.py @@ -0,0 +1,33 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from steps import merge + +from zenml import pipeline +from zenml.config import DockerSettings + + +@pipeline( + settings={ + "docker": DockerSettings( + apt_packages=["git"], requirements="requirements.txt" + ) + } +) +def llm_lora_merging() -> None: + """Pipeline to merge LLMs with adapters.""" + merge() diff --git a/examples/llm_finetuning/requirements.txt b/examples/llm_finetuning/requirements.txt new file mode 100644 index 00000000000..ad19fe96de8 --- /dev/null +++ b/examples/llm_finetuning/requirements.txt @@ -0,0 +1,17 @@ +zenml +torch>=2.2.0 +lightning @ git+https://github.com/Lightning-AI/lightning@ed367ca675861cdf40dbad2e4d66f7eee2ec50af +jsonargparse[signatures] # CLI +bitsandbytes==0.41.0 # quantization +scipy # required by bitsandbytes +sentencepiece # llama-based models +tokenizers # pythia, falcon, redpajama +datasets # eval +requests # scripts/prepare_* +zstandard # scripts/prepare_redpajama.py, scripts/prepare_starcoder.py +pandas # scripts/prepare_csv.py, scripts/prepare_starcoder.py +pyarrow # scripts/prepare_starcoder.py +# eval +git+https://github.com/EleutherAI/lm-evaluation-harness.git@115206dc89dad67b8beaa90051fb52db77f0a529 +# scripts/prepare_slimpajama.py, scripts/prepare_starcoder.py, pretrain/tinyllama.py +lightning[data] @ git+https://github.com/Lightning-AI/lightning@ed367ca675861cdf40dbad2e4d66f7eee2ec50af diff --git a/examples/llm_finetuning/run.py b/examples/llm_finetuning/run.py new file mode 100644 index 00000000000..5bfd379ba1d --- /dev/null +++ b/examples/llm_finetuning/run.py @@ -0,0 +1,132 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +from typing import Optional + +import click +from pipelines import ( + llm_lora_evaluation, + llm_lora_feature_engineering, + llm_lora_finetuning, + llm_lora_merging, +) + +from zenml.logger import get_logger + +logger = get_logger(__name__) + + +@click.command( + help=""" +ZenML LLM Finetuning project CLI v0.1.0. + +Run the ZenML LLM Finetuning project LLM LoRA finetuning pipelines. + +Examples: + + \b + # Run the feature feature engineering pipeline + python run.py --feature-pipeline + + \b + # Run the finetuning pipeline + python run.py --finetuning-pipeline + + \b + # Run the merging pipeline + python run.py --merging-pipeline + + \b + # Run the evaluation pipeline + python run.py --eval-pipeline +""" +) +@click.option( + "--config", + type=str, + default=None, + help="Path to the YAML config file.", +) +@click.option( + "--feature-pipeline", + is_flag=True, + default=False, + help="Whether to run the pipeline that creates the dataset.", +) +@click.option( + "--finetuning-pipeline", + is_flag=True, + default=False, + help="Whether to run the pipeline that finetunes the model.", +) +@click.option( + "--merging-pipeline", + is_flag=True, + default=False, + help="Whether to run the pipeline that merges the model and adapter.", +) +@click.option( + "--eval-pipeline", + is_flag=True, + default=False, + help="Whether to run the pipeline that evaluates the model.", +) +@click.option( + "--no-cache", + is_flag=True, + default=False, + help="Disable caching for the pipeline run.", +) +def main( + config: Optional[str] = None, + feature_pipeline: bool = False, + finetuning_pipeline: bool = False, + merging_pipeline: bool = False, + eval_pipeline: bool = False, + no_cache: bool = False, +): + """Main entry point for the pipeline execution. + + Args: + no_cache: If `True` cache will be disabled. + """ + config_folder = os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "configs", + ) + pipeline_args = {"enable_cache": not no_cache} + if not config: + raise RuntimeError("Config file is required to run a pipeline.") + + pipeline_args["config_path"] = os.path.join(config_folder, config) + + if feature_pipeline: + llm_lora_feature_engineering.with_options(**pipeline_args)() + + if finetuning_pipeline: + llm_lora_finetuning.with_options(**pipeline_args)() + + if merging_pipeline: + llm_lora_merging.with_options(**pipeline_args)() + + if eval_pipeline: + llm_lora_evaluation.with_options(**pipeline_args)() + + +if __name__ == "__main__": + main() diff --git a/examples/llm_finetuning/scripts/convert_hf_checkpoint.py b/examples/llm_finetuning/scripts/convert_hf_checkpoint.py new file mode 100644 index 00000000000..14d0ff6fb73 --- /dev/null +++ b/examples/llm_finetuning/scripts/convert_hf_checkpoint.py @@ -0,0 +1,377 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import gc +import json +import sys +from collections import defaultdict +from dataclasses import asdict +from functools import partial +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Union + +import torch +from lightning.fabric.utilities.load import ( + _NotYetLoadedTensor as NotYetLoadedTensor, +) + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt import Config +from lit_gpt.utils import incremental_save, lazy_load + + +def copy_weights_gpt_neox( + state_dict: Dict[str, torch.Tensor], + hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + saver: Optional[incremental_save] = None, + dtype: Optional[torch.dtype] = None, +) -> None: + weight_map = { + "gpt_neox.embed_in.weight": "transformer.wte.weight", + "gpt_neox.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", + "gpt_neox.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", + "gpt_neox.layers.{}.attention.query_key_value.bias": "transformer.h.{}.attn.attn.bias", + "gpt_neox.layers.{}.attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight", + "gpt_neox.layers.{}.attention.dense.bias": "transformer.h.{}.attn.proj.bias", + "gpt_neox.layers.{}.attention.dense.weight": "transformer.h.{}.attn.proj.weight", + "gpt_neox.layers.{}.attention.rotary_emb.inv_freq": None, + "gpt_neox.layers.{}.attention.bias": None, + "gpt_neox.layers.{}.attention.masked_bias": None, + "gpt_neox.layers.{}.post_attention_layernorm.bias": "transformer.h.{}.norm_2.bias", + "gpt_neox.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight", + "gpt_neox.layers.{}.mlp.dense_h_to_4h.bias": "transformer.h.{}.mlp.fc.bias", + "gpt_neox.layers.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight", + "gpt_neox.layers.{}.mlp.dense_4h_to_h.bias": "transformer.h.{}.mlp.proj.bias", + "gpt_neox.layers.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight", + "gpt_neox.final_layer_norm.bias": "transformer.ln_f.bias", + "gpt_neox.final_layer_norm.weight": "transformer.ln_f.weight", + "embed_out.weight": "lm_head.weight", + } + + for name, param in hf_weights.items(): + if "gpt_neox.layers" in name: + from_name, number = layer_template(name, 2) + to_name = weight_map[from_name] + if to_name is None: + continue + to_name = to_name.format(number) + else: + to_name = weight_map[name] + param = load_param(param, name, dtype) + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + +def copy_weights_falcon( + model_name: str, + state_dict: Dict[str, torch.Tensor], + hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + saver: Optional[incremental_save] = None, + dtype: Optional[torch.dtype] = None, +) -> None: + weight_map = { + "transformer.word_embeddings.weight": "transformer.wte.weight", + "transformer.h.{}.self_attention.query_key_value.weight": "transformer.h.{}.attn.attn.weight", + "transformer.h.{}.self_attention.dense.weight": "transformer.h.{}.attn.proj.weight", + "transformer.h.{}.mlp.dense_h_to_4h.weight": "transformer.h.{}.mlp.fc.weight", + "transformer.h.{}.mlp.dense_4h_to_h.weight": "transformer.h.{}.mlp.proj.weight", + "transformer.ln_f.bias": "transformer.ln_f.bias", + "transformer.ln_f.weight": "transformer.ln_f.weight", + "lm_head.weight": "lm_head.weight", + } + # the original model definition is different for each size + if "7b" in model_name: + weight_map.update( + { + "transformer.h.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", + "transformer.h.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", + } + ) + elif "40b" in model_name or "180B" in model_name: + weight_map.update( + { + "transformer.h.{}.ln_attn.bias": "transformer.h.{}.norm_1.bias", + "transformer.h.{}.ln_attn.weight": "transformer.h.{}.norm_1.weight", + "transformer.h.{}.ln_mlp.bias": "transformer.h.{}.norm_2.bias", + "transformer.h.{}.ln_mlp.weight": "transformer.h.{}.norm_2.weight", + } + ) + else: + raise NotImplementedError + + for name, param in hf_weights.items(): + if "transformer.h" in name: + from_name, number = layer_template(name, 2) + to_name = weight_map[from_name].format(number) + else: + to_name = weight_map[name] + param = load_param(param, name, dtype) + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + +def copy_weights_hf_llama( + config: Config, + qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], + state_dict: Dict[str, torch.Tensor], + hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + saver: Optional[incremental_save] = None, + dtype: Optional[torch.dtype] = None, +) -> None: + weight_map = { + "model.embed_tokens.weight": "transformer.wte.weight", + "model.layers.{}.input_layernorm.weight": "transformer.h.{l}.norm_1.weight", + "model.layers.{}.input_layernorm.bias": "transformer.h.{l}.norm_1.bias", + "model.layers.{}.self_attn.q_proj.weight": None, + "model.layers.{}.self_attn.k_proj.weight": None, + "model.layers.{}.self_attn.v_proj.weight": None, + "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{l}.attn.proj.weight", + "model.layers.{}.self_attn.rotary_emb.inv_freq": None, + "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{l}.norm_2.weight", + "model.layers.{}.post_attention_layernorm.bias": "transformer.h.{l}.norm_2.bias", + "model.norm.weight": "transformer.ln_f.weight", + "model.norm.bias": "transformer.ln_f.bias", + "lm_head.weight": "lm_head.weight", + } + if config._mlp_class == "LLaMAMoE": + weight_map.update( + { + "model.layers.{}.block_sparse_moe.gate.weight": "transformer.h.{l}.mlp.gate.weight", + "model.layers.{}.block_sparse_moe.experts.{}.w1.weight": "transformer.h.{l}.mlp.experts.{e}.fc_1.weight", + "model.layers.{}.block_sparse_moe.experts.{}.w3.weight": "transformer.h.{l}.mlp.experts.{e}.fc_2.weight", + "model.layers.{}.block_sparse_moe.experts.{}.w2.weight": "transformer.h.{l}.mlp.experts.{e}.proj.weight", + } + ) + elif config._mlp_class in ("LLaMAMLP", "GemmaMLP"): + weight_map.update( + { + "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{l}.mlp.fc_1.weight", + "model.layers.{}.mlp.up_proj.weight": "transformer.h.{l}.mlp.fc_2.weight", + "model.layers.{}.mlp.down_proj.weight": "transformer.h.{l}.mlp.proj.weight", + } + ) + else: + raise NotImplementedError + + for name, param in hf_weights.items(): + if "model.layers" in name: + from_name, l = layer_template(name, 2) + e = None + if "block_sparse_moe.experts" in name: + from_name, e = layer_template(from_name, 5) + qkv = qkv_weights.setdefault(l, [None, None, None]) + if "q_proj" in name: + qkv[0] = param + elif "k_proj" in name: + qkv[1] = param + elif "v_proj" in name: + qkv[2] = param + to_name = weight_map[from_name] + if to_name is None: + continue + to_name = to_name.format(l=l, e=e) + else: + to_name = weight_map[name] + param = load_param(param, name, dtype) + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + if "lm_head.weight" not in state_dict: + state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"] + + # convert separate q, k, v matrices into an interleaved qkv + for i, (q, k, v) in list(qkv_weights.items()): + if q is None or k is None or v is None: + # split across different .bin files + continue + q = load_param(q, f"layer {i} q", dtype) + k = load_param(k, f"layer {i} k", dtype) + v = load_param(v, f"layer {i} v", dtype) + q_per_kv = config.n_head // config.n_query_groups + qs = torch.split(q, config.head_size * q_per_kv) + ks = torch.split(k, config.head_size) + vs = torch.split(v, config.head_size) + cycled = [t for group in zip(qs, ks, vs) for t in group] + qkv = torch.cat(cycled) + state_dict[f"transformer.h.{i}.attn.attn.weight"] = qkv + del qkv_weights[i] + + +def copy_weights_phi( + config: Config, + qkv_weights: dict, + state_dict: Dict[str, torch.Tensor], + hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + saver: Optional[incremental_save] = None, + dtype: Optional[torch.dtype] = None, +) -> None: + if any( + layer_name.startswith(("layers.", "transformer.")) + for layer_name in hf_weights + ): + raise ValueError( + "You are using an outdated Phi checkpoint. Please reload it as described in 'tutorials/download_phi.md'" + ) + + weight_map = { + "model.embed_tokens.weight": "transformer.wte.weight", + "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight", + "model.layers.{}.input_layernorm.bias": "transformer.h.{}.norm_1.bias", + "model.layers.{}.self_attn.q_proj.weight": None, + "model.layers.{}.self_attn.q_proj.bias": None, + "model.layers.{}.self_attn.k_proj.weight": None, + "model.layers.{}.self_attn.k_proj.bias": None, + "model.layers.{}.self_attn.v_proj.weight": None, + "model.layers.{}.self_attn.v_proj.bias": None, + "model.layers.{}.self_attn.dense.weight": "transformer.h.{}.attn.proj.weight", + "model.layers.{}.self_attn.dense.bias": "transformer.h.{}.attn.proj.bias", + "model.layers.{}.mlp.fc1.weight": "transformer.h.{}.mlp.fc.weight", + "model.layers.{}.mlp.fc1.bias": "transformer.h.{}.mlp.fc.bias", + "model.layers.{}.mlp.fc2.weight": "transformer.h.{}.mlp.proj.weight", + "model.layers.{}.mlp.fc2.bias": "transformer.h.{}.mlp.proj.bias", + "model.final_layernorm.weight": "transformer.ln_f.weight", + "model.final_layernorm.bias": "transformer.ln_f.bias", + "lm_head.weight": "lm_head.weight", + "lm_head.bias": "lm_head.bias", + } + + for name, param in hf_weights.items(): + if name.startswith("model.layers."): + from_name, l = layer_template(name, 2) + qkv = qkv_weights.setdefault(l, defaultdict(dict)) + if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")): + weight_name, weight_type = from_name.split(".")[-2:] + qkv[weight_type][weight_name] = param + to_name = weight_map[from_name] + if to_name is None: + continue + to_name = to_name.format(l) + else: + to_name = weight_map[name] + param = load_param(param, name, dtype) + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + for i in list(qkv_weights): + for weight_type in list(qkv_weights[i]): + qkv = qkv_weights[i][weight_type] + if len(qkv) != 3: + # split across different .bin files + continue + q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype) + k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype) + v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype) + q_per_kv = config.n_head // config.n_query_groups + qs = torch.split(q, config.head_size * q_per_kv) + ks = torch.split(k, config.head_size) + vs = torch.split(v, config.head_size) + cycled = [t for group in zip(qs, ks, vs) for t in group] + qkv = torch.cat(cycled) + state_dict[f"transformer.h.{i}.attn.attn.{weight_type}"] = qkv + del qkv_weights[i][weight_type] + + +def layer_template(layer_name: str, idx: int) -> Tuple[str, int]: + split = layer_name.split(".") + number = int(split[idx]) + split[idx] = "{}" + from_name = ".".join(split) + return from_name, number + + +def load_param( + param: Union[torch.Tensor, NotYetLoadedTensor], + name: str, + dtype: Optional[torch.dtype], +) -> torch.Tensor: + if hasattr(param, "_load_tensor"): + # support tensors loaded via `lazy_load()` + print(f"Loading {name!r} into RAM") + param = param._load_tensor() + if ( + dtype is not None + and type(dtype) is not NotYetLoadedTensor + and dtype != param.dtype + ): + print(f"Converting {name!r} from {param.dtype} to {dtype}") + param = param.to(dtype) + return param + + +@torch.inference_mode() +def convert_hf_checkpoint( + *, + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + model_name: Optional[str] = None, + dtype: Optional[str] = None, +) -> None: + if model_name is None: + model_name = checkpoint_dir.name + if dtype is not None: + dtype = getattr(torch, dtype) + + config = Config.from_name(model_name) + config_dict = asdict(config) + print(f"Model config {config_dict}") + with open(checkpoint_dir / "lit_config.json", "w") as json_config: + json.dump(config_dict, json_config) + + if "falcon" in model_name: + copy_fn = partial(copy_weights_falcon, model_name) + elif config._mlp_class in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"): + # holder to reconstitute the split q, k, v + qkv_weights = {} + copy_fn = partial(copy_weights_hf_llama, config, qkv_weights) + elif "phi" in model_name: + # holder to reconstitute the split q, k, v + qkv_weights = {} + copy_fn = partial(copy_weights_phi, config, qkv_weights) + else: + copy_fn = copy_weights_gpt_neox + + # initialize a new empty state dict to hold our new weights + sd = {} + + # Load the json file containing weight mapping + pytorch_bin_map_json_path = checkpoint_dir / "pytorch_model.bin.index.json" + if ( + pytorch_bin_map_json_path.is_file() + ): # not all checkpoints have this file + with open(pytorch_bin_map_json_path) as json_map: + bin_index = json.load(json_map) + bin_files = { + checkpoint_dir / bin for bin in bin_index["weight_map"].values() + } + else: + bin_files = set(checkpoint_dir.glob("*.bin")) + # some checkpoints serialize the training arguments + bin_files = {f for f in bin_files if f.name != "training_args.bin"} + if not bin_files: + raise ValueError( + f"Expected {str(checkpoint_dir)!r} to contain .bin files" + ) + + with incremental_save(checkpoint_dir / "lit_model.pth") as saver: + # for checkpoints that split the QKV across several files, we need to keep all the bin files + # open, so we use `ExitStack` to close them all together at the end + for bin_file in sorted(bin_files): + print("Processing", bin_file) + hf_weights = lazy_load(bin_file) + copy_fn(sd, hf_weights, saver=saver, dtype=dtype) + gc.collect() + print("Saving converted checkpoint") + saver.save(sd) + + +if __name__ == "__main__": + from jsonargparse import CLI + + CLI(convert_hf_checkpoint) diff --git a/examples/llm_finetuning/scripts/convert_lit_checkpoint.py b/examples/llm_finetuning/scripts/convert_lit_checkpoint.py new file mode 100644 index 00000000000..1239e7d255d --- /dev/null +++ b/examples/llm_finetuning/scripts/convert_lit_checkpoint.py @@ -0,0 +1,284 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import gc +import sys +from functools import partial +from pathlib import Path +from typing import Dict, Optional, Tuple, Union + +import torch +from lightning.fabric.utilities.load import ( + _NotYetLoadedTensor as NotYetLoadedTensor, +) + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt import Config +from lit_gpt.utils import CLI, incremental_save, lazy_load + +from scripts.convert_hf_checkpoint import layer_template, load_param + + +def copy_weights_falcon( + model_name: str, + state_dict: Dict[str, torch.Tensor], + lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + saver: Optional[incremental_save] = None, +) -> None: + weight_map = { + "transformer.wte.weight": "transformer.word_embeddings.weight", + "transformer.h.{}.attn.attn.weight": "transformer.h.{}.self_attention.query_key_value.weight", + "transformer.h.{}.attn.proj.weight": "transformer.h.{}.self_attention.dense.weight", + "transformer.h.{}.mlp.fc.weight": "transformer.h.{}.mlp.dense_h_to_4h.weight", + "transformer.h.{}.mlp.proj.weight": "transformer.h.{}.mlp.dense_4h_to_h.weight", + "transformer.ln_f.bias": "transformer.ln_f.bias", + "transformer.ln_f.weight": "transformer.ln_f.weight", + "lm_head.weight": "lm_head.weight", + } + # the original model definition is different for each size + if "7b" in model_name: + weight_map.update( + { + "transformer.h.{}.norm_1.bias": "transformer.h.{}.input_layernorm.bias", + "transformer.h.{}.norm_1.weight": "transformer.h.{}.input_layernorm.weight", + } + ) + elif "40b" in model_name or "180B" in model_name: + weight_map.update( + { + "transformer.h.{}.norm_1.bias": "transformer.h.{}.ln_attn.bias", + "transformer.h.{}.norm_1.weight": "transformer.h.{}.ln_attn.weight", + "transformer.h.{}.norm_2.bias": "transformer.h.{}.ln_mlp.bias", + "transformer.h.{}.norm_2.weight": "transformer.h.{}.ln_mlp.weight", + } + ) + else: + raise NotImplementedError + + for name, param in lit_weights.items(): + if "transformer.h" in name: + from_name, number = layer_template(name, 2) + to_name = weight_map[from_name].format(number) + else: + to_name = weight_map[name] + param = load_param(param, name, None) + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + +def copy_weights_gpt_neox( + state_dict: Dict[str, torch.Tensor], + lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + saver: Optional[incremental_save] = None, +) -> None: + weight_map = { + "transformer.wte.weight": "gpt_neox.embed_in.weight", + "transformer.h.{}.norm_1.bias": "gpt_neox.layers.{}.input_layernorm.bias", + "transformer.h.{}.norm_1.weight": "gpt_neox.layers.{}.input_layernorm.weight", + "transformer.h.{}.attn.attn.bias": "gpt_neox.layers.{}.attention.query_key_value.bias", + "transformer.h.{}.attn.attn.weight": "gpt_neox.layers.{}.attention.query_key_value.weight", + "transformer.h.{}.attn.proj.bias": "gpt_neox.layers.{}.attention.dense.bias", + "transformer.h.{}.attn.proj.weight": "gpt_neox.layers.{}.attention.dense.weight", + "transformer.h.{}.norm_2.bias": "gpt_neox.layers.{}.post_attention_layernorm.bias", + "transformer.h.{}.norm_2.weight": "gpt_neox.layers.{}.post_attention_layernorm.weight", + "transformer.h.{}.mlp.fc.bias": "gpt_neox.layers.{}.mlp.dense_h_to_4h.bias", + "transformer.h.{}.mlp.fc.weight": "gpt_neox.layers.{}.mlp.dense_h_to_4h.weight", + "transformer.h.{}.mlp.proj.bias": "gpt_neox.layers.{}.mlp.dense_4h_to_h.bias", + "transformer.h.{}.mlp.proj.weight": "gpt_neox.layers.{}.mlp.dense_4h_to_h.weight", + "transformer.ln_f.bias": "gpt_neox.final_layer_norm.bias", + "transformer.ln_f.weight": "gpt_neox.final_layer_norm.weight", + "lm_head.weight": "embed_out.weight", + } + + for name, param in lit_weights.items(): + if "transformer.h" in name: + from_name, number = layer_template(name, 2) + to_name = weight_map[from_name].format(number) + else: + to_name = weight_map[name] + param = load_param(param, name, None) + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + +def copy_weights_llama( + config: Config, + state_dict: Dict[str, torch.Tensor], + lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + untie_weights: bool = False, + saver: Optional[incremental_save] = None, +) -> None: + weight_map = { + "transformer.wte.weight": "model.embed_tokens.weight", + "transformer.h.{}.norm_1.weight": "model.layers.{l}.input_layernorm.weight", + "transformer.h.{}.norm_1.bias": "model.layers.{l}.input_layernorm.bias", + "transformer.h.{}.attn.proj.weight": "model.layers.{l}.self_attn.o_proj.weight", + "transformer.h.{}.norm_2.weight": "model.layers.{l}.post_attention_layernorm.weight", + "transformer.h.{}.norm_2.bias": "model.layers.{l}.post_attention_layernorm.bias", + "transformer.ln_f.weight": "model.norm.weight", + "transformer.ln_f.bias": "model.norm.bias", + "lm_head.weight": "lm_head.weight", + } + if config._mlp_class == "LLaMAMoE": + weight_map.update( + { + "transformer.h.{}.mlp.gate.weight": "model.layers.{l}.block_sparse_moe.gate.weight", + "transformer.h.{}.mlp.experts.{}.fc_1.weight": "model.layers.{l}.block_sparse_moe.experts.{e}.w1.weight", + "transformer.h.{}.mlp.experts.{}.fc_2.weight": "model.layers.{l}.block_sparse_moe.experts.{e}.w3.weight", + "transformer.h.{}.mlp.experts.{}.proj.weight": "model.layers.{l}.block_sparse_moe.experts.{e}.w2.weight", + } + ) + elif config._mlp_class in ("LLaMAMLP", "GemmaMLP"): + weight_map.update( + { + "transformer.h.{}.mlp.fc_1.weight": "model.layers.{l}.mlp.gate_proj.weight", + "transformer.h.{}.mlp.fc_2.weight": "model.layers.{l}.mlp.up_proj.weight", + "transformer.h.{}.mlp.proj.weight": "model.layers.{l}.mlp.down_proj.weight", + } + ) + else: + raise NotImplementedError + + for name, param in lit_weights.items(): + if name == "lm_head.weight" and untie_weights: + continue + if name.endswith(".attn.attn.weight"): + from_name, l = layer_template(name, 2) + q = "model.layers.{}.self_attn.q_proj.weight".format(l) + k = "model.layers.{}.self_attn.k_proj.weight".format(l) + v = "model.layers.{}.self_attn.v_proj.weight".format(l) + qkv = load_param(param, name, None) + qp, kp, vp = qkv_split(qkv, config) + for to_name, param in zip((q, k, v), (qp, kp, vp)): + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + else: + if "transformer.h" in name: + from_name, l = layer_template(name, 2) + e = None + if "mlp.experts" in name: + from_name, e = layer_template(from_name, 5) + to_name = weight_map[from_name] + to_name = to_name.format(l=l, e=e) + else: + to_name = weight_map[name] + param = load_param(param, name, None) + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + +def copy_weights_phi( + config: Config, + state_dict: Dict[str, torch.Tensor], + lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]], + saver: Optional[incremental_save] = None, +) -> None: + weight_map = { + "transformer.wte.weight": "model.embed_tokens.weight", + "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight", + "transformer.h.{}.norm_1.bias": "model.layers.{}.input_layernorm.bias", + "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.dense.weight", + "transformer.h.{}.attn.proj.bias": "model.layers.{}.self_attn.dense.bias", + "transformer.h.{}.mlp.fc.weight": "model.layers.{}.mlp.fc1.weight", + "transformer.h.{}.mlp.fc.bias": "model.layers.{}.mlp.fc1.bias", + "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.fc2.weight", + "transformer.h.{}.mlp.proj.bias": "model.layers.{}.mlp.fc2.bias", + "transformer.ln_f.weight": "model.final_layernorm.weight", + "transformer.ln_f.bias": "model.final_layernorm.bias", + "lm_head.weight": "lm_head.weight", + "lm_head.bias": "lm_head.bias", + } + + for name, param in lit_weights.items(): + if name.endswith((".attn.attn.weight", ".attn.attn.bias")): + from_name, l = layer_template(name, 2) + weight_type = name.split(".")[-1] # weight or bias + q = f"model.layers.{l}.self_attn.q_proj.{weight_type}" + k = f"model.layers.{l}.self_attn.k_proj.{weight_type}" + v = f"model.layers.{l}.self_attn.v_proj.{weight_type}" + qkv = load_param(param, name, None) + qp, kp, vp = qkv_split(qkv, config) + for to_name, param in zip((q, k, v), (qp, kp, vp)): + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + else: + if "transformer.h" in name: + from_name, l = layer_template(name, 2) + to_name = weight_map[from_name] + to_name = to_name.format(l) + else: + to_name = weight_map[name] + param = load_param(param, name, None) + if saver is not None: + param = saver.store_early(param) + state_dict[to_name] = param + + +def qkv_split( + param: Union[torch.Tensor, NotYetLoadedTensor], config: Config +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q_per_kv = config.n_head // config.n_query_groups + qs = [] + ks = [] + vs = [] + for chunk in torch.chunk(param, config.n_query_groups): + split = torch.split( + chunk, + [config.head_size * q_per_kv, config.head_size, config.head_size], + ) + qs.append(split[0]) + ks.append(split[1]) + vs.append(split[2]) + q = torch.cat(qs) + k = torch.cat(ks) + v = torch.cat(vs) + return q, k, v + + +def check_conversion_supported(lit_weights: Dict[str, torch.Tensor]) -> None: + if any("lora" in wn for wn in lit_weights): + raise ValueError( + "Checkpoints with LoRA weights cannot be converted. Call `scripts/merge_lora.py` first." + ) + if any("adapter" in wn or "gating_factor" in wn for wn in lit_weights): + raise NotImplementedError("Converting adapter models is supported.") + + +@torch.inference_mode() +def convert_lit_checkpoint( + checkpoint_path: Path, output_path: Path, config_path: Path +) -> None: + config = Config.from_json(config_path) + + if "falcon" in config.name: + copy_fn = partial(copy_weights_falcon, config.name) + elif config._mlp_class in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"): + untie_weights = "Gemma" in config.name + copy_fn = partial( + copy_weights_llama, config, untie_weights=untie_weights + ) + elif "phi" in config.name: + copy_fn = partial(copy_weights_phi, config) + else: + copy_fn = copy_weights_gpt_neox + + # initialize a new empty state dict to hold our new weights + sd = {} + with incremental_save(output_path) as saver: + lit_weights = lazy_load(checkpoint_path) + lit_weights = lit_weights.get("model", lit_weights) + check_conversion_supported(lit_weights) + copy_fn(sd, lit_weights, saver=saver) + gc.collect() + saver.save(sd) + + +if __name__ == "__main__": + CLI(convert_lit_checkpoint) diff --git a/examples/llm_finetuning/scripts/convert_pretrained_checkpoint.py b/examples/llm_finetuning/scripts/convert_pretrained_checkpoint.py new file mode 100644 index 00000000000..a6c3093374a --- /dev/null +++ b/examples/llm_finetuning/scripts/convert_pretrained_checkpoint.py @@ -0,0 +1,88 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import json +import shutil +import sys +from dataclasses import asdict +from pathlib import Path + +import torch + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt import Config +from lit_gpt.utils import CLI, incremental_save + + +@torch.inference_mode() +def convert_checkpoint( + checkpoint_file: Path, + tokenizer_dir: Path, + config_name: str, + output_dir: Path, +) -> None: + """Convert a checkpoint after pretraining. + + The pretrained checkpoint contains optimizer states and several other metadata that are not needed after training + is finished. This script will export the state-dict of the model and place it in the chosen output folder together + with the tokenizer and model config, which then can be loaded by other scripts for inference, evaluation, etc. + + Args: + checkpoint_file: Path to a checkpoint file scripts produced by the scripts in ``lit_gpt/pretrain/``. + tokenizer_dir: A path to the folder that holds the tokenizer configuration files that were used to train + the model. All files with a name starting with 'tokenizer' will be copied to the output folder. + config_name: The name of the model loaded with the ``lit_gpt.Config``. The configuration will be saved as a + JSON file to the output folder. + output_dir: The output folder where model state-dict file, the tokenizer config file, and the model config + file will be saved. + """ + + if output_dir.is_dir() and output_dir.glob("*"): + raise FileExistsError( + f"The output folder exists and is not empty: {str(output_dir)}." + " Please delete it first or choose a different name." + ) + if not tokenizer_dir.is_dir(): + raise FileNotFoundError( + f"The tokenizer_dir must be a directory: {str(output_dir)}." + ) + + output_dir.mkdir(parents=True) + output_checkpoint_file = output_dir / "lit_model.pth" + output_config_file = output_dir / "lit_config.json" + + # Save the config to output folder + config = Config.from_name(config_name) + with open(output_config_file, "w") as json_config: + json.dump(asdict(config), json_config) + + # Export the tokenizer configuration to output folder + for tokenizer_file in tokenizer_dir.glob("tokenizer*"): + shutil.copyfile(tokenizer_file, output_dir / tokenizer_file.name) + + # Copy config for tokenization if found + if (tokenizer_dir / "generation_config.json").is_file(): + shutil.copyfile( + tokenizer_dir / "generation_config.json", + output_dir / "generation_config.json", + ) + + # Extract the model state dict and save to output folder + with incremental_save(output_checkpoint_file) as saver: + print("Processing", checkpoint_file) + full_checkpoint = torch.load(str(checkpoint_file), mmap=True) + loaded_state_dict = full_checkpoint["model"] + converted_state_dict = {} + for param_name, param in loaded_state_dict.items(): + saver.store_early(param) + # remove prefix for compiled model (if any) + param_name = param_name.replace("_orig_mod.", "") + converted_state_dict[param_name] = param + print(f"Saving converted checkpoint to {str(output_checkpoint_file)}.") + saver.save(converted_state_dict) + + +if __name__ == "__main__": + CLI(convert_checkpoint) diff --git a/examples/llm_finetuning/scripts/download.py b/examples/llm_finetuning/scripts/download.py new file mode 100644 index 00000000000..e5a7459d2be --- /dev/null +++ b/examples/llm_finetuning/scripts/download.py @@ -0,0 +1,106 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import os +import sys +from pathlib import Path +from typing import Optional + +import torch +from lightning_utilities.core.imports import RequirementCache + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt.utils import CLI + +_SAFETENSORS_AVAILABLE = RequirementCache("safetensors") +_HF_TRANSFER_AVAILABLE = RequirementCache("hf_transfer") + + +def download_from_hub( + repo_id: Optional[str] = None, + access_token: Optional[str] = os.getenv("HF_TOKEN"), + from_safetensors: bool = False, + tokenizer_only: bool = False, + checkpoint_dir: Path = Path("checkpoints"), +) -> None: + if repo_id is None: + from lit_gpt.config import configs + + options = [ + f"{config['hf_config']['org']}/{config['hf_config']['name']}" + for config in configs + ] + print("Please specify --repo_id . Available values:") + print("\n".join(options)) + return + + from huggingface_hub import snapshot_download + + if ( + "meta-llama" in repo_id or "falcon-180" in repo_id + ) and not access_token: + raise ValueError( + f"{repo_id} requires authentication, please set the `HF_TOKEN=your_token` environment" + " variable or pass --access_token=your_token. You can find your token by visiting" + " https://huggingface.co/settings/tokens" + ) + + download_files = ["tokenizer*", "generation_config.json"] + if not tokenizer_only: + if from_safetensors: + if not _SAFETENSORS_AVAILABLE: + raise ModuleNotFoundError(str(_SAFETENSORS_AVAILABLE)) + download_files.append("*.safetensors") + else: + # covers `.bin` files and `.bin.index.json` + download_files.append("*.bin*") + elif from_safetensors: + raise ValueError( + "`--from_safetensors=True` won't have an effect with `--tokenizer_only=True`" + ) + + import huggingface_hub._snapshot_download as download + import huggingface_hub.constants as constants + + previous = constants.HF_HUB_ENABLE_HF_TRANSFER + if _HF_TRANSFER_AVAILABLE and not previous: + print("Setting HF_HUB_ENABLE_HF_TRANSFER=1") + constants.HF_HUB_ENABLE_HF_TRANSFER = True + download.HF_HUB_ENABLE_HF_TRANSFER = True + + directory = checkpoint_dir / repo_id + snapshot_download( + repo_id, + local_dir=directory, + local_dir_use_symlinks=False, + resume_download=True, + allow_patterns=download_files, + token=access_token, + ) + + constants.HF_HUB_ENABLE_HF_TRANSFER = previous + download.HF_HUB_ENABLE_HF_TRANSFER = previous + + # convert safetensors to PyTorch binaries + if from_safetensors: + from safetensors import SafetensorError + from safetensors.torch import load_file as safetensors_load + + print("Converting .safetensor files to PyTorch binaries (.bin)") + for safetensor_path in directory.glob("*.safetensors"): + bin_path = safetensor_path.with_suffix(".bin") + try: + result = safetensors_load(safetensor_path) + except SafetensorError as e: + raise RuntimeError( + f"{safetensor_path} is likely corrupted. Please try to re-download it." + ) from e + print(f"{safetensor_path} --> {bin_path}") + torch.save(result, bin_path) + os.remove(safetensor_path) + + +if __name__ == "__main__": + CLI(download_from_hub) diff --git a/examples/llm_finetuning/scripts/merge_lora.py b/examples/llm_finetuning/scripts/merge_lora.py new file mode 100644 index 00000000000..89818a999fa --- /dev/null +++ b/examples/llm_finetuning/scripts/merge_lora.py @@ -0,0 +1,94 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +"""This script merges the LoRA weights with the base model""" + +import sys +from pathlib import Path +from typing import Optional + +import lightning as L +import torch + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt.lora import GPT, Config, lora_filter, merge_lora_weights +from lit_gpt.utils import ( + CLI, + check_valid_checkpoint_dir, + get_default_supported_precision, + lazy_load, +) + + +def merge_lora( + lora_path: Path = Path("out/lora/alpaca/lit_model_lora_finetuned.pth"), + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + out_dir: Path = Path("out/lora/checkpoint"), + precision: Optional[str] = None, + lora_r: int = 8, + lora_alpha: int = 16, + lora_dropout: float = 0.05, + lora_query: bool = True, + lora_key: bool = False, + lora_value: bool = True, + lora_projection: bool = False, + lora_mlp: bool = False, + lora_head: bool = False, +) -> None: + """Generates a response based on a given instruction and an optional input. + This script will only work with checkpoints from the instruction-tuned GPT-LoRA model. + See `finetune/lora.py`. + + Args: + lora_path: Path to the checkpoint with trained adapter weights, which are the output of + `finetune/lora.py`. + checkpoint_dir: The path to the checkpoint folder with pretrained GPT weights. + out_dir: The path to the merged model that is created by this script. + precision: Indicates the Fabric precision setting to use. + """ + check_valid_checkpoint_dir(checkpoint_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + precision = precision or get_default_supported_precision(training=False) + fabric = L.Fabric(devices=1, precision=precision) + + config = Config.from_json( + checkpoint_dir / "lit_config.json", + r=lora_r, + alpha=lora_alpha, + dropout=lora_dropout, + to_query=lora_query, + to_key=lora_key, + to_value=lora_value, + to_projection=lora_projection, + to_mlp=lora_mlp, + to_head=lora_head, + ) + + with fabric.init_module(empty_init=True): + model = GPT(config) + checkpoint_path = checkpoint_dir / "lit_model.pth" + checkpoint = lazy_load(checkpoint_path) + lora_checkpoint = lazy_load(lora_path) + checkpoint.update(lora_checkpoint.get("model", lora_checkpoint)) + model.load_state_dict(checkpoint) + + merge_lora_weights(model) + + save_path = out_dir / "lit_model.pth" + fabric.print(f"Saving weights to {str(save_path)!r}") + # remove lora parameters and the lora linear substring + state_dict = { + k.replace("linear.", ""): v + for k, v in model.state_dict().items() + if not lora_filter(k, v) + } + torch.save(state_dict, save_path) + + +if __name__ == "__main__": + CLI(merge_lora) diff --git a/examples/llm_finetuning/scripts/prepare_alpaca.py b/examples/llm_finetuning/scripts/prepare_alpaca.py new file mode 100644 index 00000000000..cde6fca1b67 --- /dev/null +++ b/examples/llm_finetuning/scripts/prepare_alpaca.py @@ -0,0 +1,169 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +"""Implementation derived from https://github.com/tloen/alpaca-lora""" + +import json +import sys +from pathlib import Path +from typing import Optional + +import torch +from lightning_utilities.core.imports import RequirementCache +from torch.utils.data import random_split +from tqdm import tqdm + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt.tokenizer import Tokenizer +from lit_gpt.utils import CLI + + +def prepare( + destination_path: Path = Path("data/alpaca"), + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + test_split_fraction: float = 0.03865, # to get exactly 2000 test samples, + seed: int = 42, + mask_inputs: bool = False, # as in alpaca-lora + data_file_name: str = "alpaca_data_cleaned_archive.json", + data_file_url: str = "https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json", + ignore_index: int = -1, + max_seq_length: Optional[int] = None, +) -> None: + """Prepare the Alpaca dataset for instruction tuning. + + The output is a training and test dataset saved as `train.pt` and `test.pt`, + which stores the preprocessed and tokenized prompts and labels. + """ + if max_seq_length is None: + with open( + checkpoint_dir / "lit_config.json", "r", encoding="utf-8" + ) as file: + config = json.load(file) + max_seq_length = config["block_size"] + + destination_path.mkdir(parents=True, exist_ok=True) + data_file_path = destination_path / data_file_name + print("Loading data file...") + download_if_missing(data_file_path, data_file_url) + with open(data_file_path, "r", encoding="utf-8") as file: + data = json.load(file) + + print("Loading tokenizer...") + tokenizer = Tokenizer(checkpoint_dir) + + # Partition the dataset into train and test + train_set, test_set = random_split( + data, + [1.0 - test_split_fraction, test_split_fraction], + generator=torch.Generator().manual_seed(seed), + ) + train_set, test_set = list(train_set), list(test_set) + + print(f"train has {len(train_set):,} samples") + print(f"test has {len(test_set):,} samples") + + print("Processing train split ...") + train_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(train_set) + ] + torch.save(train_set, destination_path / "train.pt") + + print("Processing test split ...") + test_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(test_set) + ] + torch.save(test_set, destination_path / "test.pt") + + +def download_if_missing(file_path: Path, file_url: str) -> None: + """Downloads the raw json data file and saves it in the given destination.""" + if file_path.exists() and file_path.stat().st_size > 0: + return + requests_available = RequirementCache("requests") + if not requests_available: + raise ModuleNotFoundError(str(requests_available)) + import requests + + with open(file_path, "w", encoding="utf-8") as f: + f.write(requests.get(file_url).text) + + +def prepare_sample( + example: dict, + tokenizer: Tokenizer, + max_length: int, + mask_inputs: bool, + ignore_index: int, +) -> dict: + """Processes a single sample. + + Each sample in the dataset consists of: + - instruction: A string describing the task + - input: A string holding a special input value for the instruction. + This only applies to some samples, and in others this is empty. + - output: The response string + + This function processes this data to produce a prompt text and a label for + supervised training. The prompt text is formed as a single message including both + the instruction and the input. The label/target is the same message but with the + response attached. + + Finally, both the prompt and the label get tokenized. If desired, all tokens + in the label that correspond to the original input prompt get masked out (default). + """ + full_prompt = generate_prompt(example) + full_prompt_and_response = full_prompt + example["output"] + encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length) + encoded_full_prompt_and_response = tokenizer.encode( + full_prompt_and_response, eos=True, max_length=max_length + ) + + # The labels are the full prompt with response, but with the prompt masked out + labels = encoded_full_prompt_and_response.clone() + if mask_inputs: + labels[: len(encoded_full_prompt)] = ignore_index + + return { + **example, + "input_ids": encoded_full_prompt_and_response, + "labels": labels, + } + + +def generate_prompt(example: dict) -> str: + """Generates a standardized message to prompt the model with an instruction, optional input and a + 'response' field.""" + + if example["input"]: + return ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:" + ) + return ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['instruction']}\n\n### Response:" + ) + + +if __name__ == "__main__": + CLI(prepare) diff --git a/examples/llm_finetuning/scripts/prepare_csv.py b/examples/llm_finetuning/scripts/prepare_csv.py new file mode 100644 index 00000000000..bbd27074d52 --- /dev/null +++ b/examples/llm_finetuning/scripts/prepare_csv.py @@ -0,0 +1,157 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import json +import logging +import sys +from pathlib import Path +from typing import Optional, Tuple + +import torch +from torch.utils.data import random_split +from tqdm import tqdm + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +logger = logging.getLogger(__name__) +sys.path.append(str(wd)) + +from lit_gpt.tokenizer import Tokenizer +from lit_gpt.utils import CLI + + +def prepare( + csv_path: Path, + destination_path: Path = Path("data/csv"), + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + test_split_fraction: float = 0.1, + seed: int = 42, + mask_inputs: bool = False, + ignore_index: int = -1, + max_seq_length: Optional[int] = None, + columns: Tuple[str, ...] = ("instruction", "input", "output"), +) -> None: + """Prepare a CSV dataset for instruction tuning. + + The output is a training and test dataset saved as `train.pt` and `test.pt`, + which stores the preprocessed and tokenized prompts and labels. + """ + if max_seq_length is None: + with open(checkpoint_dir / "lit_config.json", "r") as file: + config = json.load(file) + max_seq_length = config["block_size"] + + destination_path.mkdir(parents=True, exist_ok=True) + logger.info("Loading data file ...") + import pandas as pd + + df = pd.read_csv(csv_path, dtype=str).fillna("") + if not (df.columns.values == columns).all(): + raise ValueError( + f"CSV columns must be {columns}, found {df.columns.values}" + ) + data = json.loads(df.to_json(orient="records", indent=4)) + + print("Loading tokenizer...") + tokenizer = Tokenizer(checkpoint_dir) + + # Partition the dataset into train and test + train_set, test_set = random_split( + data, + [1.0 - test_split_fraction, test_split_fraction], + generator=torch.Generator().manual_seed(seed), + ) + train_set, test_set = list(train_set), list(test_set) + + print(f"train has {len(train_set):,} samples") + print(f"test has {len(test_set):,} samples") + + print("Processing train split ...") + train_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(train_set) + ] + torch.save(train_set, destination_path / "train.pt") + + print("Processing test split ...") + test_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(test_set) + ] + torch.save(test_set, destination_path / "test.pt") + + +def prepare_sample( + example: dict, + tokenizer: Tokenizer, + max_length: int, + mask_inputs: bool, + ignore_index: int, +) -> dict: + """Processes a single sample. + + Each sample in the dataset consists of: + - instruction: A string describing the task + - input: A string holding a special input value for the instruction. + This only applies to some samples, and in others this is empty. + - output: The response string + + This function processes this data to produce a prompt text and a label for + supervised training. The prompt text is formed as a single message including both + the instruction and the input. The label/target is the same message but with the + response attached. + + Finally, both the prompt and the label get tokenized. If desired, all tokens + in the label that correspond to the original input prompt get masked out (default). + """ + full_prompt = generate_prompt(example) + full_prompt_and_response = full_prompt + example["output"] + encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length) + encoded_full_prompt_and_response = tokenizer.encode( + full_prompt_and_response, eos=True, max_length=max_length + ) + + # The labels are the full prompt with response, but with the prompt masked out + labels = encoded_full_prompt_and_response.clone() + if mask_inputs: + labels[: len(encoded_full_prompt)] = ignore_index + + return { + **example, + "input_ids": encoded_full_prompt_and_response, + "labels": labels, + } + + +def generate_prompt(example: dict) -> str: + """Generates a standardized message to prompt the model with an instruction, optional input and a + 'response' field.""" + + if example["input"]: + return ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:" + ) + return ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['instruction']}\n\n### Response:" + ) + + +if __name__ == "__main__": + CLI(prepare) diff --git a/examples/llm_finetuning/scripts/prepare_dolly.py b/examples/llm_finetuning/scripts/prepare_dolly.py new file mode 100644 index 00000000000..8bb434398fa --- /dev/null +++ b/examples/llm_finetuning/scripts/prepare_dolly.py @@ -0,0 +1,163 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +"""Implementation derived from https://github.com/tloen/alpaca-lora""" + +import json +import sys +from pathlib import Path +from typing import Optional + +import torch +from torch.utils.data import random_split +from tqdm import tqdm + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt.tokenizer import Tokenizer +from lit_gpt.utils import CLI + +from scripts.prepare_alpaca import download_if_missing + + +def prepare( + destination_path: Path = Path("data/dolly"), + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + test_split_fraction: float = 0.1, + seed: int = 42, + mask_inputs: bool = False, + data_file_name: str = "dolly_data_cleaned.json", + data_file_url: str = "https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl", + ignore_index: int = -1, + max_seq_length: Optional[int] = None, +) -> None: + """Prepare the Dolly 15k dataset for instruction tuning. + + The output is a training and test dataset saved as `train.pt` and `test.pt`, + which stores the preprocessed and tokenized prompts and labels. + """ + + if max_seq_length is None: + with open( + checkpoint_dir / "lit_config.json", "r", encoding="utf-8" + ) as file: + config = json.load(file) + max_seq_length = config["block_size"] + + destination_path.mkdir(parents=True, exist_ok=True) + data_file_path = destination_path / data_file_name + print("Loading data file...") + download_if_missing(data_file_path, data_file_url) + + with open(data_file_path, "r", encoding="utf-8") as file: + data = file.readlines() + data = [json.loads(line) for line in data] + for item in data: + item["input"] = item.pop("context") + item["output"] = item.pop("response") + + print("Loading tokenizer...") + tokenizer = Tokenizer(checkpoint_dir) + + # Partition the dataset into train and test + train_set, test_set = random_split( + data, + [1.0 - test_split_fraction, test_split_fraction], + generator=torch.Generator().manual_seed(seed), + ) + train_set, test_set = list(train_set), list(test_set) + + print(f"train has {len(train_set):,} samples") + print(f"test has {len(test_set):,} samples") + + print("Processing train split ...") + train_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(train_set) + ] + torch.save(train_set, destination_path / "train.pt") + + print("Processing test split ...") + test_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(test_set) + ] + torch.save(test_set, destination_path / "test.pt") + + +def prepare_sample( + example: dict, + tokenizer: Tokenizer, + max_length: int, + mask_inputs: bool, + ignore_index: int, +) -> dict: + """Processes a single sample. + + Each sample in the dataset consists of: + - instruction: A string describing the task + - input: A string holding a special input value for the instruction. + This only applies to some samples, and in others this is empty. + - output: The response string + + This function processes this data to produce a prompt text and a label for + supervised training. The prompt text is formed as a single message including both + the instruction and the input. The label/target is the same message but with the + response attached. + + Finally, both the prompt and the label get tokenized. If desired, all tokens + in the label that correspond to the original input prompt get masked out (default). + """ + full_prompt = generate_prompt(example) + full_prompt_and_response = full_prompt + example["output"] + encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length) + encoded_full_prompt_and_response = tokenizer.encode( + full_prompt_and_response, eos=True, max_length=max_length + ) + + # The labels are the full prompt with response, but with the prompt masked out + labels = encoded_full_prompt_and_response.clone() + if mask_inputs: + labels[: len(encoded_full_prompt)] = ignore_index + + return { + **example, + "input_ids": encoded_full_prompt_and_response, + "labels": labels, + } + + +def generate_prompt(example: dict) -> str: + """Generates a standardized message to prompt the model with an instruction, optional input and a + 'response' field.""" + + if example["input"]: + return ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:" + ) + return ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['instruction']}\n\n### Response:" + ) + + +if __name__ == "__main__": + CLI(prepare) diff --git a/examples/llm_finetuning/scripts/prepare_flan.py b/examples/llm_finetuning/scripts/prepare_flan.py new file mode 100644 index 00000000000..a34b547213b --- /dev/null +++ b/examples/llm_finetuning/scripts/prepare_flan.py @@ -0,0 +1,249 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +"""Implementation derived from https://github.com/tloen/alpaca-lora""" +import json +import sys +from pathlib import Path +from typing import Optional + +import torch +from tqdm import tqdm + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt.tokenizer import Tokenizer +from lit_gpt.utils import CLI + +from scripts.prepare_alpaca import download_if_missing + + +def load_jsonl(filename): + data = [] + with open(filename, "r", encoding="utf-8") as f: + for line in f: + data.append(json.loads(line)) + return data + + +def prepare( + destination_path: Path = Path("data/flan"), + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + mask_inputs: bool = False, # as in alpaca-lora + subsets: Optional[str] = None, + ignore_index: int = -1, + max_seq_length: Optional[int] = None, +) -> None: + """Prepare the FLAN-collection datasets for instruction tuning. + + The output is a training and test dataset saved as `train.pt` and `test.pt`, + which stores the preprocessed and tokenized prompts and labels. + + Since the original test set does not have responses, the validation set + is used as the test set. + """ + + supported_subsets = { + "aeslc_10templates", + "ag_news_subset_10templates", + "anli_r1_10templates", + "anli_r2_10templates", + "anli_r3_10templates", + "arc_challenge_10templates", + "arc_easy_10templates", + "bool_q_10templates", + "cb_10templates", + "cnn_dailymail_10templates", + "cola_10templates", + "common_gen_10templates", + "copa_10templates", + "coqa_10templates", + "cosmos_qa_10templates", + "dart_10templates", + "definite_pronoun_resolution_10templates", + "drop_10templates", + "e2e_nlg_10templates", + "fix_punct_10templates", + "gigaword_10templates", + "glue_mrpc_10templates", + "glue_qqp_10templates", + "hellaswag_10templates", + "imdb_reviews_10templates", + "math_dataset_10templates", + "mnli_matched_10templates", + "mnli_mismatched_10templates", + "multi_news_10templates", + "multirc_10templates", + "natural_questions_10templates", + "openbookqa_10templates", + "opinion_abstracts_idebate_10templates", + "opinion_abstracts_rotten_tomatoes_10templates", + "para_crawl_enes_10templates", + "paws_wiki_10templates", + "piqa_10templates", + "qnli_10templates", + "quac_10templates", + "record_10templates", + "rte_10templates", + "samsum_10templates", + "sentiment140_10templates", + "snli_10templates", + "squad_v1_10templates", + "squad_v2_10templates", + "sst2_10templates", + "story_cloze_10templates", + "stsb_10templates", + "trec_10templates", + "trivia_qa_10templates", + "true_case_10templates", + "web_nlg_en_10templates", + "wic_10templates", + "wiki_lingua_english_en_10templates", + "wmt14_enfr_10templates", + "wmt16_translate_csen_10templates", + "wmt16_translate_deen_10templates", + "wmt16_translate_fien_10templates", + "wmt16_translate_roen_10templates", + "wmt16_translate_ruen_10templates", + "wmt16_translate_tren_10templates", + "wnli_10templates", + "word_segment_10templates", + "wsc_10templates", + "yelp_polarity_reviews_10templates", + } + + if subsets is not None: + subsets = subsets.split(",") + for sub in subsets: + if sub not in supported_subsets: + raise ValueError(f"{sub} not in {supported_subsets}") + else: + subsets = list(supported_subsets) + + if max_seq_length is None: + with open( + checkpoint_dir / "lit_config.json", "r", encoding="utf-8" + ) as file: + config = json.load(file) + max_seq_length = config["block_size"] + + destination_path.mkdir(parents=True, exist_ok=True) + print("Loading data file...") + + base_url = "https://huggingface.co/datasets/Muennighoff/flan/resolve/main/" + + train_set, test_set = [], [] + for sub in subsets: + train_sub = sub + "_train" + data_file_name = train_sub + ".jsonl" + data_file_path = destination_path / data_file_name + data_file_url = base_url + "train/" + data_file_name + + print(f"Loading training data file {sub}...") + download_if_missing(data_file_path, data_file_url) + sub_train_set = load_jsonl(data_file_path) + train_set.extend(sub_train_set) + + test_sub = sub + "_test" + data_file_name = test_sub + ".jsonl" + data_file_path = destination_path / data_file_name + data_file_url = base_url + "test/" + data_file_name + + print(f"Loading test data file {sub}...") + download_if_missing(data_file_path, data_file_url) + sub_test_set = load_jsonl(data_file_path) + test_set.extend(sub_test_set) + + print("Loading tokenizer...") + tokenizer = Tokenizer(checkpoint_dir) + + train_set, test_set = list(train_set), list(test_set) + + print(f"train has {len(train_set):,} samples") + print(f"test has {len(test_set):,} samples") + + print("Processing train split ...") + train_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(train_set) + ] + torch.save(train_set, destination_path / "train.pt") + + print("Processing test split ...") + test_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(test_set) + ] + torch.save(test_set, destination_path / "test.pt") + + +def prepare_sample( + example: dict, + tokenizer: Tokenizer, + max_length: int, + mask_inputs: bool, + ignore_index: int, +): + """Processes a single sample. + + Each sample in the dataset consists of: + - instruction: A string describing the task + - input: A string holding a special input value for the instruction. + This only applies to some samples, and in others this is empty. + - output: The response string + + This function processes this data to produce a prompt text and a label for + supervised training. The prompt text is formed as a single message including both + the instruction and the input. The label/target is the same message but with the + response attached. + + Finally, both the prompt and the label get tokenized. If desired, all tokens + in the label that correspond to the original input prompt get masked out (default). + """ + full_prompt = generate_prompt(example) + full_prompt_and_response = full_prompt + example["targets"] + encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length) + encoded_full_prompt_and_response = tokenizer.encode( + full_prompt_and_response, eos=True, max_length=max_length + ) + + # The labels are the full prompt with response, but with the prompt masked out + labels = encoded_full_prompt_and_response.clone() + if mask_inputs: + labels[: len(encoded_full_prompt)] = ignore_index + + return { + **example, + "input_ids": encoded_full_prompt_and_response, + "labels": labels, + } + + +def generate_prompt(example): + """Generates a standardized message to prompt the model with an instruction, optional input and a + 'response' field.""" + + return ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['inputs']}\n\n### Response:" + ) + + +if __name__ == "__main__": + CLI(prepare) diff --git a/examples/llm_finetuning/scripts/prepare_lima.py b/examples/llm_finetuning/scripts/prepare_lima.py new file mode 100644 index 00000000000..e27928ce9e2 --- /dev/null +++ b/examples/llm_finetuning/scripts/prepare_lima.py @@ -0,0 +1,198 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +"""Implementation derived from https://github.com/tloen/alpaca-lora""" + +import json +import os +import sys +from pathlib import Path +from typing import List, Optional + +import torch +from torch.utils.data import random_split +from tqdm import tqdm + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt.tokenizer import Tokenizer +from lit_gpt.utils import CLI + + +def prepare( + destination_path: Path = Path("data/lima"), + test_split_fraction: float = 0.1, + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + mask_inputs: bool = False, # as in alpaca-lora + seed: int = 42, + include_multiturn_conversations: bool = False, + data_repo_id: str = "GAIR/lima", + ignore_index: int = -1, + access_token: Optional[str] = os.getenv("HF_TOKEN"), + max_seq_length: Optional[int] = None, +) -> None: + """Prepare the LIMA dataset for instruction tuning. + + The output is a training and test dataset saved as `train.pt` and `test.pt`, + which stores the preprocessed and tokenized prompts and labels. + """ + + if access_token is None: + raise ValueError( + "LIMA requires authentication, please set the `HF_TOKEN=your_token` environment" + " variable or pass --access_token=your_token. You can find your token by visiting" + " https://huggingface.co/settings/tokens" + ) + + if max_seq_length is None: + with open( + checkpoint_dir / "lit_config.json", "r", encoding="utf-8" + ) as file: + config = json.load(file) + max_seq_length = config["block_size"] + + destination_path.mkdir(parents=True, exist_ok=True) + print("Loading data file...") + + from datasets import load_dataset + + dataset = load_dataset(data_repo_id, token=access_token) + train_data = format_dataset( + dataset["train"], include_multiturn_conversations + ) + + # test set is present but doesn't have any solutions, so we cannot use it here + # but have to create our own + # for consistency with prepare_alpaca.py and prepare_dolly.py + # test_set = format_dataset(dataset["test"], include_multiturn_conversations) + + print("Loading tokenizer...") + tokenizer = Tokenizer(checkpoint_dir) + + # Partition the dataset into train and test + train_set, test_set = random_split( + train_data, + [1.0 - test_split_fraction, test_split_fraction], + generator=torch.Generator().manual_seed(seed), + ) + train_set, test_set = list(train_set), list(test_set) + + print(f"train has {len(train_set):,} samples") + print(f"test has {len(test_set):,} samples") + + print("Processing train split ...") + train_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(train_set) + ] + torch.save(train_set, destination_path / "train.pt") + + print("Processing test split ...") + test_set = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(test_set) + ] + torch.save(test_set, destination_path / "test.pt") + + +def format_dataset( + dataset_partition: dict, include_multi_turn_conversations: bool +) -> List[dict]: + formatted_ds = [] + + for entry in dataset_partition: + convo = entry["conversations"] + if include_multi_turn_conversations: + for i in range(0, len(convo) - 1, 2): + formatted_ds.append( + { + "instruction": convo[i], + "input": "", + "output": convo[i + 1], + } + ) + + else: + formatted_ds.append( + {"instruction": convo[0], "input": "", "output": convo[1]} + ) + + return formatted_ds + + +def prepare_sample( + example: dict, + tokenizer: Tokenizer, + max_length: int, + mask_inputs: bool, + ignore_index: int, +) -> dict: + """Processes a single sample. + + Each sample in the dataset consists of: + - instruction: A string describing the task + - input: A string holding a special input value for the instruction. + This only applies to some samples, and in others this is empty. + - output: The response string + + This function processes this data to produce a prompt text and a label for + supervised training. The prompt text is formed as a single message including both + the instruction and the input. The label/target is the same message but with the + response attached. + + Finally, both the prompt and the label get tokenized. If desired, all tokens + in the label that correspond to the original input prompt get masked out (default). + """ + full_prompt = generate_prompt(example) + full_prompt_and_response = full_prompt + example["output"] + encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length) + encoded_full_prompt_and_response = tokenizer.encode( + full_prompt_and_response, eos=True, max_length=max_length + ) + + # The labels are the full prompt with response, but with the prompt masked out + labels = encoded_full_prompt_and_response.clone() + if mask_inputs: + labels[: len(encoded_full_prompt)] = ignore_index + + return { + **example, + "input_ids": encoded_full_prompt_and_response, + "labels": labels, + } + + +def generate_prompt(example: dict) -> str: + """Generates a standardized message to prompt the model with an instruction, optional input and a + 'response' field.""" + + if example["input"]: + return ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['instruction']}\n\n### Input:\n{example['input']}\n\n### Response:" + ) + return ( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['instruction']}\n\n### Response:" + ) + + +if __name__ == "__main__": + CLI(prepare) diff --git a/examples/llm_finetuning/scripts/prepare_longform.py b/examples/llm_finetuning/scripts/prepare_longform.py new file mode 100644 index 00000000000..6327bad8654 --- /dev/null +++ b/examples/llm_finetuning/scripts/prepare_longform.py @@ -0,0 +1,153 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +"""Implementation derived from https://github.com/tloen/alpaca-lora""" + +import json +import sys +from pathlib import Path +from typing import Optional + +import torch +from tqdm import tqdm + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt.tokenizer import Tokenizer +from lit_gpt.utils import CLI + +from scripts.prepare_alpaca import download_if_missing + + +def prepare( + destination_path: Path = Path("data/longform"), + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + mask_inputs: bool = False, # as in alpaca-lora + ignore_index: int = -1, + max_seq_length: Optional[int] = None, +) -> None: + """Prepare the Alpaca dataset for instruction tuning. + + The output is a training and test dataset saved as `train.pt` and `test.pt`, + which stores the preprocessed and tokenized prompts and labels. + """ + if max_seq_length is None: + with open( + checkpoint_dir / "lit_config.json", "r", encoding="utf-8" + ) as file: + config = json.load(file) + max_seq_length = config["block_size"] + + destination_path.mkdir(parents=True, exist_ok=True) + + train_file_name = "train.json" + # val_file_name = "val.json" + test_file_name = "test.json" + + train_file_url = "https://raw.githubusercontent.com/akoksal/LongForm/main/dataset/train.json" + # val_file_url = "https://raw.githubusercontent.com/akoksal/LongForm/main/dataset/val.json" + test_file_url = "https://raw.githubusercontent.com/akoksal/LongForm/main/dataset/test.json" + + train_file_path = destination_path / train_file_name + print("Loading train data file...") + download_if_missing(train_file_path, train_file_url) + with open(train_file_path, "r", encoding="utf-8") as file: + train_data = json.load(file) + + test_file_path = destination_path / test_file_name + print("Loading test data file...") + download_if_missing(test_file_path, test_file_url) + with open(test_file_path, "r", encoding="utf-8") as file: + test_data = json.load(file) + + print("Loading tokenizer...") + tokenizer = Tokenizer(checkpoint_dir) + + print(f"train has {len(train_data):,} samples") + print(f"test has {len(test_data):,} samples") + + print("Processing train set ...") + train_data = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(train_data) + ] + torch.save(train_data, destination_path / "train.pt") + + print("Processing test set ...") + test_data = [ + prepare_sample( + example=sample, + tokenizer=tokenizer, + max_length=max_seq_length, + mask_inputs=mask_inputs, + ignore_index=ignore_index, + ) + for sample in tqdm(test_data) + ] + torch.save(test_data, destination_path / "test.pt") + + +def prepare_sample( + example: dict, + tokenizer: Tokenizer, + max_length: int, + mask_inputs: bool, + ignore_index: int, +) -> dict: + """Processes a single sample. + + Each sample in the dataset consists of: + - instruction: A string describing the task + - input: A string holding a special input value for the instruction. + This only applies to some samples, and in others this is empty. + - output: The response string + + This function processes this data to produce a prompt text and a label for + supervised training. The prompt text is formed as a single message including both + the instruction and the input. The label/target is the same message but with the + response attached. + + Finally, both the prompt and the label get tokenized. If desired, all tokens + in the label that correspond to the original input prompt get masked out (default). + """ + full_prompt = generate_prompt(example) + full_prompt_and_response = full_prompt + example["output"] + encoded_full_prompt = tokenizer.encode(full_prompt, max_length=max_length) + encoded_full_prompt_and_response = tokenizer.encode( + full_prompt_and_response, eos=True, max_length=max_length + ) + + # The labels are the full prompt with response, but with the prompt masked out + labels = encoded_full_prompt_and_response.clone() + if mask_inputs: + labels[: len(encoded_full_prompt)] = ignore_index + + return { + **example, + "input_ids": encoded_full_prompt_and_response, + "labels": labels, + } + + +def generate_prompt(example: dict) -> str: + """Generates a standardized message to prompt the model with an instruction and a + 'response' field.""" + + return ( + "Below is an instruction that describes a task, paired with an input that provides further context. " + "Write a response that appropriately completes the request.\n\n" + f"### Instruction:\n{example['input']}\n\n### Response:" + ) + + +if __name__ == "__main__": + CLI(prepare) diff --git a/examples/llm_finetuning/scripts/prepare_openwebtext.py b/examples/llm_finetuning/scripts/prepare_openwebtext.py new file mode 100644 index 00000000000..fbb4a8d9d96 --- /dev/null +++ b/examples/llm_finetuning/scripts/prepare_openwebtext.py @@ -0,0 +1,100 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +# saves the openwebtext dataset to a binary file for training. following was helpful: +# https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py +import os +import sys +from pathlib import Path +from typing import Union + +import numpy as np +from tqdm import tqdm + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt import Tokenizer +from lit_gpt.utils import CLI + + +def prepare( + destination_path: Path = Path("data/openwebtext"), + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + seed: int = 42, + test_size: Union[float, int, None] = 0.0005, +) -> None: + from datasets import load_dataset # huggingface datasets + + destination_path.mkdir(parents=True, exist_ok=True) + + tokenizer = Tokenizer(checkpoint_dir) + + # number of workers in .map() call + # good number to use is ~order number of cpu cores // 2 + num_proc = os.cpu_count() // 2 + + # number of workers in load_dataset() call + # best number might be different from num_proc above as it also depends on HW speed. + # it is better than 1 usually though + num_proc_load_dataset = num_proc + + # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769) + dataset = load_dataset("openwebtext", num_proc=num_proc_load_dataset) + + # owt by default only contains the 'train' split, so create a test split + split_dataset = dataset["train"].train_test_split( + test_size=test_size, seed=seed, shuffle=True + ) + split_dataset["val"] = split_dataset.pop( + "test" + ) # rename the test split to val + + def process(example): + ids = tokenizer.encode(example["text"]).tolist() + ids.append(tokenizer.eos_id) + + # ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens + # ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe + # note: I think eot should be prepended not appended... hmm. it's called "eot" though... + return {"ids": ids, "len": len(ids)} + + # tokenize the dataset + tokenized = split_dataset.map( + process, + remove_columns=["text"], + desc="tokenizing the splits", + num_proc=num_proc, + ) + + # concatenate all the ids in each dataset into one large file we can use for training + for split, dset in tokenized.items(): + arr_len = np.sum(dset["len"], dtype=np.uint64) + filename = destination_path / f"{split}.bin" + dtype = ( + np.uint16 + ) # (can do since enc.max_token_value == 50256 is < 2**16) + arr = np.memmap( + str(filename), dtype=dtype, mode="w+", shape=(arr_len,) + ) + total_batches = 1024 + + idx = 0 + for batch_idx in tqdm( + range(total_batches), desc=f"writing {filename}" + ): + # Batch together samples for faster write + batch = dset.shard( + num_shards=total_batches, index=batch_idx, contiguous=True + ).with_format("numpy") + arr_batch = np.concatenate(batch["ids"]) + # Write into mmap + arr[idx : idx + len(arr_batch)] = arr_batch + idx += len(arr_batch) + arr.flush() + + +if __name__ == "__main__": + CLI(prepare) diff --git a/examples/llm_finetuning/scripts/prepare_redpajama.py b/examples/llm_finetuning/scripts/prepare_redpajama.py new file mode 100644 index 00000000000..02044307797 --- /dev/null +++ b/examples/llm_finetuning/scripts/prepare_redpajama.py @@ -0,0 +1,185 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import glob +import json +import os +import sys +from pathlib import Path + +import numpy as np +from tqdm import tqdm + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +import lit_gpt.packed_dataset as packed_dataset +from lit_gpt import Config, Tokenizer +from lit_gpt.utils import CLI + +filenames_sample = [ + "arxiv_sample.jsonl", + "book_sample.jsonl", + "c4_sample.jsonl", + "cc_2019-30_sample.jsonl", + "cc_2020-05_sample.jsonl", + "cc_2021-04_sample.jsonl", + "cc_2022-05_sample.jsonl", + "cc_2023-06_sample.jsonl", + "github_sample.jsonl", + "stackexchange_sample.jsonl", + "wikipedia_sample.jsonl", +] + +filename_sets = { + "arxiv": "arxiv/arxiv*", + "book": "book/book*", + "c4": "c4/c4-train*", + "common_crawl": "common_crawl/*", + "github": "github/filtered*", + "stackexchange": "stackexchange/stackexchange*", + "wikipedia": "wikipedia/wiki*", +} + + +def prepare_sample( + source_path: Path, + checkpoint_dir: Path, + destination_path: Path, + chunk_size: int, + match: str = "", +) -> None: + """Prepare the "Red Pajama" dataset using the original tokenizer.""" + destination_path.mkdir(parents=True, exist_ok=True) + + tokenizer = Tokenizer(checkpoint_dir) + + for name in filenames_sample: + if match and match not in name: + continue + + filepath = source_path / name + + if not filepath.is_file(): + raise RuntimeError( + f"Input file not found at {filepath}. \nMake sure you download the data, e.g. wget -i" + " https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through" + " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T" + " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n" + ) + + prefix, _ = os.path.splitext(name) + + builder = packed_dataset.PackedDatasetBuilder( + outdir=destination_path, + prefix=prefix, + chunk_size=chunk_size, + sep_token=tokenizer.eos_id, + dtype="auto", + vocab_size=tokenizer.vocab_size, + ) + + print(f"Processing {name}") + + with open(filepath, encoding="utf-8") as f: + for row in tqdm(f): + text = json.loads(row)["text"] + text_ids = tokenizer.encode(text) + builder.add_array(np.array(text_ids, dtype=builder.dtype)) + + builder.write_reminder() + + +def prepare_full( + source_path: Path, + checkpoint_dir: Path, + destination_path: Path, + chunk_size: int, + match: str = "", +) -> None: + """Prepare the "Red Pajama" dataset using the original tokenizer.""" + import zstandard as zstd + + destination_path.mkdir(parents=True, exist_ok=True) + + tokenizer = Tokenizer(checkpoint_dir) + + for set_name, pattern in filename_sets.items(): + if match and match not in set_name: + continue + + is_cc = set_name == "common_crawl" + + filenames = glob.glob( + os.path.join(source_path, pattern), recursive=True + ) + + if not filenames: + raise RuntimeError( + f"No files matching {pattern} found at {source_path}. \nMake sure you download the data, e.g. wget -i" + " https://data.together.xyz/redpajama-data-1T/v1.0.0/urls.txt or through" + " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T" + " \nhttps://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T-Sample \n" + ) + + builder = packed_dataset.PackedDatasetBuilder( + outdir=destination_path, + prefix=set_name, + chunk_size=chunk_size, + sep_token=tokenizer.eos_id, + dtype="auto", + vocab_size=tokenizer.vocab_size, + ) + + for name in filenames: + filepath = source_path / name + + print(f"Processing {name}") + + if is_cc: + with zstd.open( + open(filepath, "rb"), "rt", encoding="utf-8" + ) as f: + for row in tqdm(f): + text = json.loads(row)["text"] + text_ids = tokenizer.encode(text) + builder.add_array( + np.array(text_ids, dtype=builder.dtype) + ) + else: + with open(filepath, encoding="utf-8") as f: + for row in tqdm(f): + text = json.loads(row)["text"] + text_ids = tokenizer.encode(text) + builder.add_array( + np.array(text_ids, dtype=builder.dtype) + ) + + builder.write_reminder() + + +def prepare( + source_path: Path = Path("data/RedPajama-Data-1T-Sample"), + checkpoint_dir: Path = Path( + "checkpoints/stabilityai/stablelm-base-alpha-3b" + ), + destination_path: Path = Path("data/redpajama_sample"), + sample: bool = True, + match: str = "", +) -> None: + """Prepare the "Red Pajama" dataset. We assume tokenizer has been trained.""" + config = Config.from_checkpoint(checkpoint_dir) + + prepare_fn = prepare_sample if sample else prepare_full + prepare_fn( + source_path=source_path, + checkpoint_dir=checkpoint_dir, + destination_path=destination_path, + chunk_size=(config.block_size + 1) + * 1024, # block size + 1 for causal, 1024 blocks + match=match, + ) + + +if __name__ == "__main__": + CLI(prepare) diff --git a/examples/llm_finetuning/scripts/prepare_slimpajama.py b/examples/llm_finetuning/scripts/prepare_slimpajama.py new file mode 100644 index 00000000000..0a80191f299 --- /dev/null +++ b/examples/llm_finetuning/scripts/prepare_slimpajama.py @@ -0,0 +1,68 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import json +import os +import sys +import time +from pathlib import Path + +import zstandard as zstd +from lightning.data.streaming import DataChunkRecipe, DataProcessor + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt import Tokenizer +from lit_gpt.utils import CLI + + +class SlimPajamaDataRecipe(DataChunkRecipe): + def __init__(self, tokenizer: Tokenizer, chunk_size: int): + super().__init__(chunk_size) + self.tokenizer = tokenizer + + def prepare_structure(self, input_dir): + files = Path(input_dir).rglob("*.zst") + return [str(file) for file in files] + + def prepare_item(self, filepath): + with zstd.open(open(filepath, "rb"), "rt", encoding="utf-8") as f: + for row in f: + text = json.loads(row)["text"] + if ( + json.loads(row)["meta"]["redpajama_set_name"] + == "RedPajamaGithub" + ): + continue # exclude the GitHub data since it overlaps with starcoder + text_ids = self.tokenizer.encode(text, bos=False, eos=True) + yield text_ids + + +def prepare( + input_dir: Path = Path("data/SlimPajama-627B/train"), + output_dir: Path = Path("data/slimpajama/train"), + tokenizer_path: Path = Path("checkpoints/Llama-2-7b-hf/"), + chunk_size: int = (2049 * 16384), + fast_dev_run: bool = False, +) -> None: + tokenizer = Tokenizer(tokenizer_path) + data_recipe = SlimPajamaDataRecipe( + tokenizer=tokenizer, chunk_size=chunk_size + ) + data_processor = DataProcessor( + input_dir=str(input_dir), + output_dir=str(output_dir), + fast_dev_run=fast_dev_run, + num_workers=os.cpu_count(), + num_downloaders=1, + ) + + start_time = time.time() + data_processor.run(data_recipe) + elapsed_time = time.time() - start_time + print(f"Time taken: {elapsed_time:.2f} seconds") + + +if __name__ == "__main__": + CLI(prepare) diff --git a/examples/llm_finetuning/scripts/prepare_starcoder.py b/examples/llm_finetuning/scripts/prepare_starcoder.py new file mode 100644 index 00000000000..1f67c93e1fe --- /dev/null +++ b/examples/llm_finetuning/scripts/prepare_starcoder.py @@ -0,0 +1,78 @@ +# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file. + +import os +import sys +import time +import traceback +from pathlib import Path + +import pyarrow.parquet as pq +from lightning.data.streaming import DataChunkRecipe, DataProcessor + +# support running without installing as a package +wd = Path(__file__).parent.parent.resolve() +sys.path.append(str(wd)) + +from lit_gpt import Tokenizer +from lit_gpt.utils import CLI + + +class StarcoderDataRecipe(DataChunkRecipe): + def __init__(self, tokenizer: Tokenizer, chunk_size: int): + super().__init__(chunk_size) + self.tokenizer = tokenizer + + def prepare_structure(self, input_dir): + files = Path(input_dir).rglob("*.parquet") + return [str(file) for file in files] + + def prepare_item(self, item_metadata): + filepath = item_metadata + start = time.time() + + try: + parquet_file = pq.ParquetFile(filepath) + # reduce RAM usage + for batch in parquet_file.iter_batches( + batch_size=8192, columns=["content"] + ): + for text in batch.to_pandas()["content"]: + yield self.tokenizer.encode(text, bos=False, eos=True) + + except Exception: + print(traceback.format_exc()) + print(f"Error reading {filepath}") + return + + parquet_file.close() + end = time.time() + print(f"Took {end - start:.2f} seconds total", filepath) + + +def prepare( + input_dir: Path = Path("data/starcoderdata"), + output_dir: Path = Path("data/starcoder"), + tokenizer_path: Path = Path("checkpoints/Llama-2-7b-hf/"), + chunk_size: int = (2049 * 8192), + fast_dev_run: bool = False, +) -> None: + tokenizer = Tokenizer(tokenizer_path) + data_recipe = StarcoderDataRecipe( + tokenizer=tokenizer, chunk_size=chunk_size + ) + data_processor = DataProcessor( + input_dir=str(input_dir), + output_dir=str(output_dir), + fast_dev_run=fast_dev_run, + num_workers=os.cpu_count(), + num_downloaders=1, + ) + + start_time = time.time() + data_processor.run(data_recipe) + elapsed_time = time.time() - start_time + print(f"Time taken: {elapsed_time:.2f} seconds") + + +if __name__ == "__main__": + CLI(prepare) diff --git a/examples/llm_finetuning/steps/__init__.py b/examples/llm_finetuning/steps/__init__.py new file mode 100644 index 00000000000..c9630597e75 --- /dev/null +++ b/examples/llm_finetuning/steps/__init__.py @@ -0,0 +1,21 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from steps.evaluate import evaluate +from steps.feature_engineering import feature_engineering +from steps.finetune import finetune +from steps.merge import merge diff --git a/examples/llm_finetuning/steps/evaluate.py b/examples/llm_finetuning/steps/evaluate.py new file mode 100644 index 00000000000..f9570dee734 --- /dev/null +++ b/examples/llm_finetuning/steps/evaluate.py @@ -0,0 +1,143 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +import shutil +from pathlib import Path +from typing import Any, Dict, List, Literal, Optional + +import torch +from evaluate.lm_eval_harness import run_eval_harness +from huggingface_hub import snapshot_download +from pydantic import BaseModel +from scripts.download import download_from_hub +from scripts.merge_lora import merge_lora +from typing_extensions import Annotated + +from steps.params import LoraParameters +from steps.utils import ( + convert_to_lit_checkpoint_if_necessary, + get_huggingface_access_token, +) +from zenml import step +from zenml.logger import get_logger + +logger = get_logger(__file__) + + +class EvaluationParameters(BaseModel): + """Parameters for the evaluation step. + + If `adapter_repo` is set, it will be merged with the model. Otherwise + the model itself will be evaluated. + """ + + model_repo: str + from_safetensors: bool = False + adapter_repo: Optional[str] = None + + precision: Optional[str] = None + quantize: Optional[ + Literal[ + "bnb.nf4", + "bnb.nf4-dq", + "bnb.fp4", + "bnb.fp4-dq", + "bnb.int8-training", + ] + ] = None + + lora: LoraParameters = LoraParameters() + + eval_tasks: List[str] = [ + "arc_challenge", + "piqa", + "hellaswag", + "hendrycksTest-*", + ] + num_fewshot: int = 0 + limit: Optional[int] = None + bootstrap_iters: int = 100000 + no_cache: bool = True + + +@step +def evaluate( + config: EvaluationParameters, +) -> Annotated[Dict[str, Any], "evaluation_results"]: + """Evaluate model. + + Args: + config: Configuration for this step. + """ + torch.set_float32_matmul_precision("high") + + access_token = get_huggingface_access_token() + + checkpoint_root_dir = Path("checkpoints") + checkpoint_dir = checkpoint_root_dir / config.model_repo + + if checkpoint_dir.exists(): + logger.info( + "Checkpoint directory already exists, skipping download..." + ) + else: + download_from_hub( + repo_id=config.model_repo, + from_safetensors=config.from_safetensors, + checkpoint_dir=checkpoint_root_dir, + access_token=access_token, + ) + + convert_to_lit_checkpoint_if_necessary(checkpoint_dir=checkpoint_dir) + + if config.adapter_repo: + adapter_dir = Path("adapters") / config.adapter_repo + merged_dir = Path("output/merged") + + snapshot_download( + config.adapter_repo, + local_dir=adapter_dir, + local_dir_use_symlinks=False, + resume_download=True, + token=access_token, + ) + + lora_path = adapter_dir / "lit_model_lora_finetuned.pth" + merge_lora( + lora_path=lora_path, + checkpoint_dir=checkpoint_dir, + out_dir=merged_dir, + precision=config.precision, + **config.lora.dict(), + ) + + for path in Path(checkpoint_dir).glob("*.json"): + destination = Path(merged_dir) / path.name + shutil.copy(src=path, dst=destination) + + checkpoint_dir = merged_dir + + output_path = Path("output.json") + run_eval_harness( + checkpoint_dir=checkpoint_dir, + save_filepath=output_path, + **config.dict(exclude={"model_repo", "adapter_repo", "lora"}), + ) + + with open(output_path, "r") as f: + return json.load(f) diff --git a/examples/llm_finetuning/steps/feature_engineering.py b/examples/llm_finetuning/steps/feature_engineering.py new file mode 100644 index 00000000000..c47eb8a28e3 --- /dev/null +++ b/examples/llm_finetuning/steps/feature_engineering.py @@ -0,0 +1,89 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import importlib +import json +from dataclasses import asdict +from pathlib import Path +from typing import Any, Dict + +from lit_gpt import Config +from materializers.directory_materializer import DirectoryMaterializer +from pydantic import BaseModel +from scripts.download import download_from_hub +from typing_extensions import Annotated + +from steps.utils import get_huggingface_access_token +from zenml import log_artifact_metadata, step + + +class FeatureEngineeringParameters(BaseModel): + """Parameters for the feature engineering step.""" + + model_repo: str + dataset_name: str + + prepare_kwargs: Dict[str, Any] = {} + + +@step(output_materializers=DirectoryMaterializer) +def feature_engineering( + config: FeatureEngineeringParameters, +) -> Annotated[Path, "dataset"]: + """Prepare the dataset. + + Args: + config: Configuration for this step. + """ + access_token = get_huggingface_access_token() + + checkpoint_root_dir = Path("checkpoints") + download_from_hub( + repo_id=config.model_repo, + tokenizer_only=True, + checkpoint_dir=checkpoint_root_dir, + access_token=access_token, + ) + + checkpoint_dir = checkpoint_root_dir / config.model_repo + + model_name = checkpoint_dir.name + lit_config = Config.from_name(model_name) + lit_config_dict = asdict(lit_config) + with open(checkpoint_dir / "lit_config.json", "w") as json_config: + json.dump(lit_config_dict, json_config) + + log_artifact_metadata( + metadata={ + "model_name": model_name, + "model_config": lit_config_dict, + "dataset_name": config.dataset_name, + } + ) + destination_dir = Path("data") / config.dataset_name + + helper_module = importlib.import_module( + f"scripts.prepare_{config.dataset_name}" + ) + prepare_function = getattr(helper_module, "prepare") + + prepare_function( + checkpoint_dir=checkpoint_dir, + destination_path=destination_dir, + **config.prepare_kwargs, + ) + return destination_dir diff --git a/examples/llm_finetuning/steps/finetune.py b/examples/llm_finetuning/steps/finetune.py new file mode 100644 index 00000000000..fa3a9305e8b --- /dev/null +++ b/examples/llm_finetuning/steps/finetune.py @@ -0,0 +1,249 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import shutil +from pathlib import Path +from typing import Literal, Optional + +import torch +from finetune.lora import setup +from huggingface_hub import upload_folder +from lit_gpt.args import EvalArgs, IOArgs, TrainArgs +from materializers.directory_materializer import DirectoryMaterializer +from pydantic import BaseModel +from scripts.convert_lit_checkpoint import convert_lit_checkpoint +from scripts.download import download_from_hub +from scripts.merge_lora import merge_lora +from scripts.prepare_alpaca import prepare +from typing_extensions import Annotated + +from steps.params import LoraParameters +from steps.utils import ( + convert_to_lit_checkpoint_if_necessary, + get_huggingface_access_token, +) +from zenml import get_step_context, log_model_metadata, step +from zenml.logger import get_logger +from zenml.materializers import BuiltInMaterializer + +logger = get_logger(__file__) + + +class DataParameters(BaseModel): + """Data preprocessing parameters.""" + + seed: int = 42 + test_split_fraction: float = 0.03865 + mask_inputs: bool = False + ignore_index: int = -1 + max_seq_length: Optional[int] = None + + +class TrainingParameters(BaseModel): + """Training parameters.""" + + save_interval: int = 1000 + log_interval: int = 1 + global_batch_size: int = 64 + micro_batch_size: int = 4 + lr_warmup_steps: int = 100 + epochs: Optional[int] = None + epoch_size: Optional[int] = None + max_tokens: Optional[int] = None + max_seq_length: Optional[int] = None + + learning_rate: float = 1e-3 + weight_decay: float = 0.02 + beta1: float = 0.9 + beta2: float = 0.95 + max_norm: Optional[float] = None + min_lr: float = 6e-5 + + +class EvalParameters(BaseModel): + """Mid-training evaluation parameters.""" + + interval: int = 100 + max_new_tokens: int = 100 + max_iters: int = 100 + + +class FinetuningParameters(BaseModel): + """Parameters for the finetuning step.""" + + base_model_repo: str + from_safetensors: bool = False + + adapter_output_repo: Optional[str] = None + merged_output_repo: Optional[str] = None + convert_to_hf_checkpoint: bool = False + + precision: Optional[str] = None + quantize: Optional[ + Literal[ + "bnb.nf4", + "bnb.nf4-dq", + "bnb.fp4", + "bnb.fp4-dq", + "bnb.int8-training", + ] + ] = None + + data: DataParameters = DataParameters() + training: TrainingParameters = TrainingParameters() + eval: EvalParameters = EvalParameters() + lora: LoraParameters = LoraParameters() + + +@step(output_materializers=[DirectoryMaterializer, BuiltInMaterializer]) +def finetune( + config: FinetuningParameters, dataset_directory: Optional[Path] = None +) -> Annotated[Optional[Path], "adapter"]: + """Finetune model using LoRA. + + Args: + config: Configuration for this step. + """ + torch.set_float32_matmul_precision("high") + + access_token = get_huggingface_access_token() + + checkpoint_root_dir = Path("checkpoints") + checkpoint_dir = checkpoint_root_dir / config.base_model_repo + + if checkpoint_dir.exists(): + logger.info( + "Checkpoint directory already exists, skipping download..." + ) + else: + download_from_hub( + repo_id=config.base_model_repo, + from_safetensors=config.from_safetensors, + checkpoint_dir=checkpoint_root_dir, + access_token=access_token, + ) + + convert_to_lit_checkpoint_if_necessary(checkpoint_dir=checkpoint_dir) + + if dataset_directory: + try: + dataset_name = ( + get_step_context() + .inputs["dataset_directory"] + .run_metadata["dataset_name"] + .value + ) + except KeyError: + dataset_name = "unknown_dataset" + else: + dataset_directory = Path("data/alpaca") + dataset_name = dataset_directory.name + prepare( + destination_path=dataset_directory, + checkpoint_dir=checkpoint_dir, + test_split_fraction=config.data.test_split_fraction, + seed=config.data.seed, + mask_inputs=config.data.mask_inputs, + ignore_index=config.data.ignore_index, + max_seq_length=config.data.max_seq_length, + ) + + model_name = checkpoint_dir.name + + log_model_metadata( + metadata={"model_name": model_name, "dataset_name": dataset_name} + ) + adapter_output_dir = Path("output/lora") / dataset_name / model_name + + io_args = IOArgs( + train_data_dir=dataset_directory, + val_data_dir=dataset_directory, + checkpoint_dir=checkpoint_dir, + out_dir=adapter_output_dir, + ) + train_args = TrainArgs(**config.training.dict()) + eval_args = EvalArgs(**config.eval.dict()) + setup( + devices=1, + io=io_args, + train=train_args, + eval=eval_args, + precision=config.precision, + quantize=config.quantize, + **config.lora.dict(), + ) + + if config.merged_output_repo: + lora_path = adapter_output_dir / "lit_model_lora_finetuned.pth" + + merge_output_dir = ( + Path("output/lora_merged") / dataset_name / model_name + ) + merge_lora( + lora_path=lora_path, + checkpoint_dir=checkpoint_dir, + out_dir=merge_output_dir, + precision=config.precision, + **config.lora.dict(), + ) + + for path in Path(checkpoint_dir).glob("*.json"): + destination = Path(merge_output_dir) / path.name + shutil.copy(src=path, dst=destination) + + if config.convert_to_hf_checkpoint: + upload_dir = ( + Path("output/lora_merged_hf") / dataset_name / model_name + ) + upload_dir.mkdir(parents=True, exist_ok=True) + convert_lit_checkpoint( + checkpoint_path=config.merged_output_repo / "lit_model.pth", + config_path=config.merged_output_repo / "lit_config.json", + output_path=upload_dir / "pytorch_model", + ) + else: + upload_dir = merge_output_dir + + commit = upload_folder( + repo_id=config.merged_output_repo, + folder_path=upload_dir, + token=access_token, + ) + log_model_metadata( + metadata={ + "merged_model_huggingface_commit_hash": commit.oid, + "merged_model_huggingface_commit_url": commit.commit_url, + } + ) + + if config.adapter_output_repo: + commit = upload_folder( + repo_id=config.adapter_output_repo, + folder_path=adapter_output_dir, + token=access_token, + ) + log_model_metadata( + metadata={ + "adapter_huggingface_commit_hash": commit.oid, + "adapter_huggingface_commit_url": commit.commit_url, + } + ) + return None + else: + # If the adapter should not be uploaded to the HF Hub, we store it + # in the artifact store + return adapter_output_dir diff --git a/examples/llm_finetuning/steps/merge.py b/examples/llm_finetuning/steps/merge.py new file mode 100644 index 00000000000..bc8fa90f716 --- /dev/null +++ b/examples/llm_finetuning/steps/merge.py @@ -0,0 +1,124 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import shutil +from pathlib import Path +from typing import Optional + +from huggingface_hub import snapshot_download, upload_folder +from pydantic import BaseModel +from scripts.convert_lit_checkpoint import convert_lit_checkpoint +from scripts.download import download_from_hub +from scripts.merge_lora import merge_lora + +from steps.params import LoraParameters +from steps.utils import ( + convert_to_lit_checkpoint_if_necessary, + get_huggingface_access_token, +) +from zenml import log_model_metadata, step +from zenml.logger import get_logger + +logger = get_logger(__file__) + + +class MergeParameters(BaseModel): + """Parameters for the merging step.""" + + base_model_repo: str + from_safetensors: bool = False + + adapter_repo: str + output_repo: str + convert_to_hf_checkpoint: bool = False + + precision: Optional[str] = None + lora: LoraParameters = LoraParameters() + + +@step +def merge(config: MergeParameters) -> None: + """Merge base model and LoRA adapter. + + Args: + config: Configuration for this step. + """ + access_token = get_huggingface_access_token() + + checkpoint_root_dir = Path("checkpoints") + base_model_dir = checkpoint_root_dir / config.base_model_repo + adapter_dir = Path("adapters") / config.adapter_repo + + if base_model_dir.exists(): + logger.info( + "Checkpoint directory already exists, skipping download..." + ) + else: + download_from_hub( + repo_id=config.base_model_repo, + from_safetensors=config.from_safetensors, + checkpoint_dir=checkpoint_root_dir, + access_token=access_token, + ) + + convert_to_lit_checkpoint_if_necessary(checkpoint_dir=base_model_dir) + + snapshot_download( + config.adapter_repo, + local_dir=adapter_dir, + local_dir_use_symlinks=False, + resume_download=True, + token=access_token, + ) + + lora_path = adapter_dir / "lit_model_lora_finetuned.pth" + merged_dir = Path("output/merged") + + merge_lora( + lora_path=lora_path, + checkpoint_dir=base_model_dir, + out_dir=merged_dir, + precision=config.precision, + **config.lora.dict(), + ) + + for path in Path(base_model_dir).glob("*.json"): + destination = Path(merged_dir) / path.name + shutil.copy(src=path, dst=destination) + + if config.convert_to_hf_checkpoint: + model_name = base_model_dir.name + + output_dir = Path("output/lora_merged_hf") / model_name + output_dir.mkdir(parents=True, exist_ok=True) + convert_lit_checkpoint( + checkpoint_path=merged_dir / "lit_model.pth", + config_path=merged_dir / "lit_config.json", + output_path=output_dir / "pytorch_model", + ) + else: + output_dir = merged_dir + + commit = upload_folder( + repo_id=config.output_repo, folder_path=output_dir, token=access_token + ) + log_model_metadata( + metadata={ + "merged_model_huggingface_commit_hash": commit.oid, + "merged_model_huggingface_commit_url": commit.commit_url, + } + ) diff --git a/examples/llm_finetuning/steps/params.py b/examples/llm_finetuning/steps/params.py new file mode 100644 index 00000000000..52e450de206 --- /dev/null +++ b/examples/llm_finetuning/steps/params.py @@ -0,0 +1,32 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pydantic import BaseModel + + +class LoraParameters(BaseModel): + """Lora specific parameters.""" + + lora_r: int = 8 + lora_alpha: int = 16 + lora_dropout: float = 0.05 + lora_query: bool = True + lora_key: bool = False + lora_value: bool = True + lora_projection: bool = False + lora_mlp: bool = False + lora_head: bool = False diff --git a/examples/llm_finetuning/steps/utils.py b/examples/llm_finetuning/steps/utils.py new file mode 100644 index 00000000000..c81238fef5c --- /dev/null +++ b/examples/llm_finetuning/steps/utils.py @@ -0,0 +1,54 @@ +# Apache Software License 2.0 +# +# Copyright (c) ZenML GmbH 2024. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +from pathlib import Path +from typing import Optional + +from scripts.convert_hf_checkpoint import convert_hf_checkpoint + +from zenml.client import Client + + +def get_huggingface_access_token() -> Optional[str]: + """Get access token for huggingface. + + Returns: + The access token if one was found. + """ + try: + return ( + Client() + .get_secret("huggingface_credentials") + .secret_values["token"] + ) + except KeyError: + return os.getenv("HF_TOKEN") + + +def convert_to_lit_checkpoint_if_necessary(checkpoint_dir: Path) -> None: + """Convert an HF checkpoint to a lit checkpoint if necessary. + + Args: + checkpoint_dir: The directory of the HF checkpoint. + """ + lit_model_path = checkpoint_dir / "lit_model.pth" + + if lit_model_path.is_file(): + return + + convert_hf_checkpoint(checkpoint_dir=checkpoint_dir) diff --git a/pyproject.toml b/pyproject.toml index f86386c9498..dadc1fe0620 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "zenml" -version = "0.55.5" +version = "0.56.2" packages = [{ include = "zenml", from = "src" }] description = "ZenML: Write production-ready ML code." authors = ["ZenML GmbH "] @@ -60,8 +60,8 @@ python = ">=3.8,<3.12" python-dateutil = "^2.8.1" pyyaml = ">=6.0.1" rich = { extras = ["jupyter"], version = ">=12.0.0" } -sqlalchemy_utils = "0.41.1" -sqlmodel = ">=0.0.9, <=0.0.16" +sqlalchemy_utils = "0.38.3" +sqlmodel = "0.0.8" importlib_metadata = { version = "<=7.0.0", python = "<3.10" } # Optional dependencies for the ZenServer @@ -69,6 +69,7 @@ fastapi = { version = ">=0.75,<0.100", optional = true } uvicorn = { extras = ["standard"], version = ">=0.17.5", optional = true } python-multipart = { version = "~0.0.5", optional = true } pyjwt = { extras = ["crypto"], version = "2.7.*", optional = true } +fastapi-utils = { version = "~0.2.1", optional = true } orjson = { version = "~3.8.3", optional = true } Jinja2 = { version = "*", optional = true } ipinfo = { version = ">=4.4.3", optional = true } @@ -303,6 +304,12 @@ exclude = [ "venv", '__init__.py', 'src/zenml/cli/version.py', + # LitGPT files from the LLM Finetuning example + 'examples/llm_finetuning/evaluate', + 'examples/llm_finetuning/finetune', + 'examples/llm_finetuning/generate', + 'examples/llm_finetuning/lit_gpt', + 'examples/llm_finetuning/scripts', ] src = ["src", "test"] @@ -439,6 +446,7 @@ module = [ "bentoml.*", "multipart.*", "jose.*", + "fastapi_utils.*", "sqlalchemy_utils.*", "sky.*", "copier.*", diff --git a/scripts/check-security.sh b/scripts/check-security.sh index de3fbef2060..8894b8d1eb2 100755 --- a/scripts/check-security.sh +++ b/scripts/check-security.sh @@ -8,4 +8,6 @@ SRC=${1:-"src/zenml tests examples"} export ZENML_DEBUG=1 export ZENML_ANALYTICS_OPT_IN=false -bandit -r $SRC -ll +bandit -r $SRC -ll \ + --exclude examples/llm_finetuning/scripts/prepare_alpaca.py + diff --git a/scripts/install-dashboard.sh b/scripts/install-dashboard.sh index 7c4d492a694..445097ff0d9 100755 --- a/scripts/install-dashboard.sh +++ b/scripts/install-dashboard.sh @@ -25,6 +25,17 @@ verifySupported() { fi } +# checkGitIgnore checks if the dashboard directories are ignored by Git +checkGitIgnore() { + if [ -f ".gitignore" ]; then + if grep -q -E "(^|\/)dashboard($|\/)" ".gitignore" || grep -q -E "(^|\/)src\/zenml\/zen_server\/dashboard($|\/)" ".gitignore"; then + echo "Error: The '/dashboard' or 'src/zenml/zen_server/dashboard' directory is ignored by Git." + echo "Please remove the corresponding entries from the .gitignore file to proceed with the installation." + exit 1 + fi + fi +} + # checkTagProvided checks whether TAG has provided as an environment variable # so we can skip checkLatestVersion checkTagProvided() { @@ -143,10 +154,11 @@ done set +u verifySupported +checkGitIgnore checkTagProvided || checkLatestVersion if [[ ! -z "$TAG" ]]; then downloadFile verifyFile installFile fi -cleanup \ No newline at end of file +cleanup diff --git a/scripts/test-migrations-mariadb.sh b/scripts/test-migrations-mariadb.sh index 12167e2894e..30494823381 100755 --- a/scripts/test-migrations-mariadb.sh +++ b/scripts/test-migrations-mariadb.sh @@ -7,22 +7,22 @@ function run_tests_for_version() { set -e # Exit immediately if a command exits with a non-zero status local VERSION=$1 + export ZENML_ANALYTICS_OPT_IN=false + export ZENML_DEBUG=true + echo "===== Testing version $VERSION =====" mkdir test_starter - zenml init --template starter --path test_starter --template-with-defaults --test + zenml init --template starter --path test_starter --template-with-defaults <<< $'my@mail.com\n' cd test_starter - export ZENML_ANALYTICS_OPT_IN=false - export ZENML_DEBUG=true - echo "===== Installing sklearn integration =====" zenml integration export-requirements sklearn --output-file sklearn-requirements.txt uv pip install -r sklearn-requirements.txt rm sklearn-requirements.txt echo "===== Running starter template pipeline =====" - python3 run.py + python3 run.py --feature-pipeline --training-pipeline --no-cache # Add additional CLI tests here zenml version diff --git a/scripts/test-migrations-mysql.sh b/scripts/test-migrations-mysql.sh index 4a52ecfa927..804e7cda48d 100755 --- a/scripts/test-migrations-mysql.sh +++ b/scripts/test-migrations-mysql.sh @@ -17,7 +17,11 @@ function run_tests_for_version() { local VERSION=$1 # versions pre-templates and pre-init test flag # (zenml init --test allows for a non-interactive init) - local PRE_TEMPLATE_VERSIONS=("0.40.0" "0.40.3" "0.41.0" "0.43.0" "0.44.1" "0.44.3" "0.45.2" "0.45.3" "0.45.4" "0.45.5" "0.45.6" "0.46.0" "0.47.0") + local PRE_TEMPLATE_VERSIONS=("0.40.0" "0.40.3" "0.41.0" "0.43.0") + local PRE_ARGS_VERSIONS=("0.40.0" "0.40.3" "0.41.0" "0.43.0" "0.44.1" "0.44.3" "0.45.2" "0.45.3" "0.45.4" "0.45.5" "0.45.6" "0.46.0" "0.47.0" "0.50.0" "0.51.0" "0.52.0") + + export ZENML_ANALYTICS_OPT_IN=false + export ZENML_DEBUG=true echo "===== Testing version $VERSION =====" @@ -26,7 +30,7 @@ function run_tests_for_version() { copier copy -l --trust -r release/0.43.0 https://github.com/zenml-io/template-starter.git test_starter else mkdir test_starter - zenml init --template starter --path test_starter --template-with-defaults --test + zenml init --template starter --path test_starter --template-with-defaults <<< $'my@mail.com\n' fi cd test_starter @@ -40,7 +44,11 @@ function run_tests_for_version() { rm sklearn-requirements.txt echo "===== Running starter template pipeline =====" - python3 run.py + if printf '%s\n' "${PRE_ARGS_VERSIONS[@]}" | grep -q "^$VERSION$"; then + python3 run.py --no-cache + else + python3 run.py --feature-pipeline --training-pipeline --no-cache + fi # Add additional CLI tests here zenml version @@ -88,10 +96,10 @@ do # Get the major and minor version of Python PYTHON_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")') - # Check if the Python version is 3.9 and VERSION is > 0.47.0 + # Check if the Python version is 3.9 and VERSION is > 0.44.0 if [[ "$PYTHON_VERSION" == "3.9" ]]; then case "$VERSION" in - "0.47.0"|"0.50.0"|"0.51.0"|"0.52.0") + "0.44.1"|"0.44.3"|"0.45.2"|"0.45.3"|"0.45.4"|"0.45.5"|"0.45.6"|"0.46.0"|"0.47.0"|"0.50.0"|"0.51.0"|"0.52.0") uv pip install importlib_metadata ;; esac diff --git a/src/zenml/VERSION b/src/zenml/VERSION index 9aaab801597..cc169d8ce70 100644 --- a/src/zenml/VERSION +++ b/src/zenml/VERSION @@ -1 +1 @@ -0.55.5 +0.56.2 \ No newline at end of file diff --git a/src/zenml/cli/base.py b/src/zenml/cli/base.py index 9ae9d77abfd..e8a8b1655bd 100644 --- a/src/zenml/cli/base.py +++ b/src/zenml/cli/base.py @@ -83,6 +83,10 @@ def copier_github_url(self) -> str: github_url="zenml-io/template-nlp", github_tag="2024.01.12", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml ), + llm_finetuning=ZenMLProjectTemplateLocation( + github_url="zenml-io/template-llm-finetuning", + github_tag="2024.03.18", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml + ), ) @@ -98,9 +102,9 @@ def copier_github_url(self) -> str: type=str, required=False, help="Name or URL of the ZenML project template to use to initialize the " - "repository, Can be a string like `e2e_batch`, `nlp`, `starter` etc. or a " - "copier URL like gh:owner/repo_name. If not specified, no template is " - "used.", + "repository, Can be a string like `e2e_batch`, `nlp`, `llm_finetuning`, " + "`starter` etc. or a copier URL like gh:owner/repo_name. If not specified, " + "no template is used.", ) @click.option( "--template-tag", diff --git a/src/zenml/cli/served_model.py b/src/zenml/cli/served_model.py index 77540039744..3d708651121 100644 --- a/src/zenml/cli/served_model.py +++ b/src/zenml/cli/served_model.py @@ -29,6 +29,7 @@ ) from zenml.console import console from zenml.enums import StackComponentType +from zenml.model_deployers import BaseModelDeployer if TYPE_CHECKING: from zenml.model_deployers import BaseModelDeployer @@ -71,14 +72,6 @@ def models(ctx: click.Context) -> None: help="Get a list of all served models within the model-deployer stack " "component.", ) - @click.option( - "--pipeline", - "-p", - type=click.STRING, - default=None, - help="Show only served models that were deployed by the indicated " - "pipeline.", - ) @click.option( "--step", "-s", @@ -88,13 +81,21 @@ def models(ctx: click.Context) -> None: "pipeline step.", ) @click.option( - "--run-name", + "--pipeline-run-id", "-r", type=click.STRING, default=None, help="Show only served models that were deployed by the indicated " "pipeline run.", ) + @click.option( + "--pipeline-name", + "-p", + type=click.STRING, + default=None, + help="Show only served models that were deployed by the indicated " + "pipeline.", + ) @click.option( "--model", "-m", @@ -102,6 +103,20 @@ def models(ctx: click.Context) -> None: default=None, help="Show only served model versions for the given model name.", ) + @click.option( + "--model-version", + "-v", + type=click.STRING, + default=None, + help="Show only served model versions for the given model version.", + ) + @click.option( + "--flavor", + "-f", + type=click.STRING, + default=None, + help="Show only served model versions for the given model flavor.", + ) @click.option( "--running", is_flag=True, @@ -110,31 +125,38 @@ def models(ctx: click.Context) -> None: @click.pass_obj def list_models( model_deployer: "BaseModelDeployer", - pipeline: Optional[str], step: Optional[str], - run_name: Optional[str], + pipeline_name: Optional[str], + pipeline_run_id: Optional[str], model: Optional[str], + model_version: Optional[str], + flavor: Optional[str], running: bool, ) -> None: """List of all served models within the model-deployer stack component. Args: model_deployer: The model-deployer stack component. - pipeline: Show only served models that were deployed by the - indicated pipeline. step: Show only served models that were deployed by the indicated pipeline step. - run_name: Show only served models that were deployed by the + pipeline_run_id: Show only served models that were deployed by the indicated pipeline run. + pipeline_name: Show only served models that were deployed by the + indicated pipeline. model: Show only served model versions for the given model name. running: Show only model servers that are currently running. + model_version: Show only served model versions for the given model + version. + flavor: Show only served model versions for the given model flavor. """ services = model_deployer.find_model_server( running=running, - pipeline_name=pipeline, - run_name=run_name, + pipeline_name=pipeline_name, + pipeline_run_id=pipeline_run_id if pipeline_run_id else None, pipeline_step_name=step, model_name=model, + model_version=model_version, + flavor=flavor, ) if services: pretty_print_model_deployer( @@ -386,14 +408,16 @@ def get_model_service_logs( ) return - for line in model_deployer.get_model_server_logs( + model_logs = model_deployer.get_model_server_logs( served_models[0].uuid, follow=follow, tail=tail - ): - # don't pretty-print log lines that are already pretty-printed - if raw or line.startswith("\x1b["): - console.print(line, markup=False) - else: - try: - console.print(line) - except MarkupError: + ) + if model_logs: + for line in model_logs: + # don't pretty-print log lines that are already pretty-printed + if raw or line.startswith("\x1b["): console.print(line, markup=False) + else: + try: + console.print(line) + except MarkupError: + console.print(line, markup=False) diff --git a/src/zenml/cli/utils.py b/src/zenml/cli/utils.py index c6879f42e04..bcefd8182d4 100644 --- a/src/zenml/cli/utils.py +++ b/src/zenml/cli/utils.py @@ -1128,6 +1128,10 @@ def get_service_state_emoji(state: "ServiceState") -> str: return ":pause_button:" if state == ServiceState.ERROR: return ":heavy_exclamation_mark:" + if state == ServiceState.PENDING_STARTUP: + return ":hourglass:" + if state == ServiceState.SCALED_TO_ZERO: + return ":chart_decreasing:" return ":hourglass_not_done:" @@ -1142,15 +1146,18 @@ def pretty_print_model_deployer( """ model_service_dicts = [] for model_service in model_services: - served_model_info = model_deployer.get_model_server_info(model_service) dict_uuid = str(model_service.uuid) dict_pl_name = model_service.config.pipeline_name dict_pl_stp_name = model_service.config.pipeline_step_name - dict_model_name = served_model_info.get("MODEL_NAME", "") + dict_model_name = model_service.config.model_name + type = model_service.SERVICE_TYPE.type + flavor = model_service.SERVICE_TYPE.flavor model_service_dicts.append( { "STATUS": get_service_state_emoji(model_service.status.state), "UUID": dict_uuid, + "TYPE": type, + "FLAVOR": flavor, "PIPELINE_NAME": dict_pl_name, "PIPELINE_STEP_NAME": dict_pl_stp_name, "MODEL_NAME": dict_model_name, @@ -1277,9 +1284,10 @@ def print_served_model_configuration( **served_model_info, "UUID": str(model_service.uuid), "STATUS": get_service_state_emoji(model_service.status.state), + "TYPE": model_service.SERVICE_TYPE.type, + "FLAVOR": model_service.SERVICE_TYPE.flavor, "STATUS_MESSAGE": model_service.status.last_error, "PIPELINE_NAME": model_service.config.pipeline_name, - "RUN_NAME": model_service.config.run_name, "PIPELINE_STEP_NAME": model_service.config.pipeline_step_name, } diff --git a/src/zenml/client.py b/src/zenml/client.py index 66f144b4ebe..a32fea00152 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -150,6 +150,10 @@ ServiceConnectorResponse, ServiceConnectorTypeModel, ServiceConnectorUpdate, + ServiceFilter, + ServiceRequest, + ServiceResponse, + ServiceUpdate, StackFilter, StackRequest, StackResponse, @@ -175,7 +179,11 @@ WorkspaceResponse, WorkspaceUpdate, ) +from zenml.services.service import ServiceConfig +from zenml.services.service_status import ServiceState +from zenml.services.service_type import ServiceType from zenml.utils import io_utils, source_utils +from zenml.utils.dict_utils import dict_to_bytes from zenml.utils.filesync_model import FileSyncModel from zenml.utils.pagination_utils import depaginate from zenml.utils.uuid_utils import is_valid_uuid @@ -1478,6 +1486,227 @@ def _validate_stack_configuration(self, stack: StackRequest) -> None: "an Orchestrator." ) + # ----------------------------- Services ----------------------------------- + + def create_service( + self, + config: ServiceConfig, + service_type: ServiceType, + model_version_id: Optional[UUID] = None, + ) -> ServiceResponse: + """Registers a service. + + Args: + config: The configuration of the service. + service_type: The type of the service. + model_version_id: The ID of the model version to associate with the + service. + + Returns: + The registered service. + """ + service_request = ServiceRequest( + name=config.service_name, + service_type=service_type, + config=config.dict(), + workspace=self.active_workspace.id, + user=self.active_user.id, + model_version_id=model_version_id, + ) + # Register the service + return self.zen_store.create_service(service_request) + + def get_service( + self, + name_id_or_prefix: Union[str, UUID], + allow_name_prefix_match: bool = True, + hydrate: bool = True, + type: Optional[str] = None, + ) -> ServiceResponse: + """Gets a service. + + Args: + name_id_or_prefix: The name or ID of the service. + allow_name_prefix_match: If True, allow matching by name prefix. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + type: The type of the service. + + Returns: + The Service + """ + + def type_scoped_list_method( + hydrate: bool = True, + **kwargs: Any, + ) -> Page[ServiceResponse]: + """Call `zen_store.list_services` with type scoping. + + Args: + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + **kwargs: Keyword arguments to pass to `ServiceFilterModel`. + + Returns: + The type-scoped list of services. + """ + service_filter_model = ServiceFilter(**kwargs) + if type: + service_filter_model.set_type(type=type) + service_filter_model.set_scope_workspace(self.active_workspace.id) + return self.zen_store.list_services( + filter_model=service_filter_model, + hydrate=hydrate, + ) + + return self._get_entity_by_id_or_name_or_prefix( + get_method=self.zen_store.get_service, + list_method=type_scoped_list_method, + name_id_or_prefix=name_id_or_prefix, + allow_name_prefix_match=allow_name_prefix_match, + hydrate=hydrate, + ) + + def list_services( + self, + sort_by: str = "created", + page: int = PAGINATION_STARTING_PAGE, + size: int = PAGE_SIZE_DEFAULT, + logical_operator: LogicalOperators = LogicalOperators.AND, + id: Optional[Union[UUID, str]] = None, + created: Optional[datetime] = None, + updated: Optional[datetime] = None, + type: Optional[str] = None, + flavor: Optional[str] = None, + workspace_id: Optional[Union[str, UUID]] = None, + user_id: Optional[Union[str, UUID]] = None, + hydrate: bool = False, + running: Optional[bool] = None, + service_name: Optional[str] = None, + pipeline_name: Optional[str] = None, + pipeline_run_id: Optional[str] = None, + pipeline_step_name: Optional[str] = None, + model_version_id: Optional[Union[str, UUID]] = None, + config: Optional[Dict[str, Any]] = None, + ) -> Page[ServiceResponse]: + """List all services. + + Args: + sort_by: The column to sort by + page: The page of items + size: The maximum size of all pages + logical_operator: Which logical operator to use [and, or] + id: Use the id of services to filter by. + created: Use to filter by time of creation + updated: Use the last updated date for filtering + type: Use the service type for filtering + flavor: Use the service flavor for filtering + workspace_id: The id of the workspace to filter by. + user_id: The id of the user to filter by. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + running: Use the running status for filtering + pipeline_name: Use the pipeline name for filtering + service_name: Use the service name or model name + for filtering + pipeline_step_name: Use the pipeline step name for filtering + model_version_id: Use the model version id for filtering + config: Use the config for filtering + pipeline_run_id: Use the pipeline run id for filtering + + Returns: + The Service response page. + """ + service_filter_model = ServiceFilter( + sort_by=sort_by, + page=page, + size=size, + logical_operator=logical_operator, + id=id, + created=created, + updated=updated, + type=type, + flavor=flavor, + workspace_id=workspace_id, + user_id=user_id, + running=running, + name=service_name, + pipeline_name=pipeline_name, + pipeline_step_name=pipeline_step_name, + model_version_id=model_version_id, + pipeline_run_id=pipeline_run_id, + config=dict_to_bytes(config) if config else None, + ) + service_filter_model.set_scope_workspace(self.active_workspace.id) + return self.zen_store.list_services( + filter_model=service_filter_model, hydrate=hydrate + ) + + def update_service( + self, + id: UUID, + name: Optional[str] = None, + service_source: Optional[str] = None, + admin_state: Optional[ServiceState] = None, + status: Optional[Dict[str, Any]] = None, + endpoint: Optional[Dict[str, Any]] = None, + labels: Optional[Dict[str, str]] = None, + prediction_url: Optional[str] = None, + health_check_url: Optional[str] = None, + model_version_id: Optional[UUID] = None, + ) -> ServiceResponse: + """Update a service. + + Args: + id: The ID of the service to update. + name: The new name of the service. + admin_state: The new admin state of the service. + status: The new status of the service. + endpoint: The new endpoint of the service. + service_source: The new service source of the service. + labels: The new labels of the service. + prediction_url: The new prediction url of the service. + health_check_url: The new health check url of the service. + model_version_id: The new model version id of the service. + + Returns: + The updated service. + """ + service_update = ServiceUpdate() + if name: + service_update.name = name + if service_source: + service_update.service_source = service_source + if admin_state: + service_update.admin_state = admin_state + if status: + service_update.status = status + if endpoint: + service_update.endpoint = endpoint + if labels: + service_update.labels = labels + if prediction_url: + service_update.prediction_url = prediction_url + if health_check_url: + service_update.health_check_url = health_check_url + if model_version_id: + service_update.model_version_id = model_version_id + return self.zen_store.update_service( + service_id=id, update=service_update + ) + + def delete_service(self, name_id_or_prefix: UUID) -> None: + """Delete a service. + + Args: + name_id_or_prefix: The name or ID of the service to delete. + """ + service = self.get_service( + name_id_or_prefix, + allow_name_prefix_match=False, + ) + self.zen_store.delete_service(service_id=service.id) + # -------------------------------- Components ------------------------------ def get_stack_component( diff --git a/src/zenml/config/server_config.py b/src/zenml/config/server_config.py index eee3844d3af..2fef2890e04 100644 --- a/src/zenml/config/server_config.py +++ b/src/zenml/config/server_config.py @@ -26,6 +26,8 @@ DEFAULT_ZENML_JWT_TOKEN_LEEWAY, DEFAULT_ZENML_SERVER_DEVICE_AUTH_POLLING, DEFAULT_ZENML_SERVER_DEVICE_AUTH_TIMEOUT, + DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_DAY, + DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE, DEFAULT_ZENML_SERVER_MAX_DEVICE_AUTH_ATTEMPTS, DEFAULT_ZENML_SERVER_PIPELINE_RUN_AUTH_WINDOW, ENV_ZENML_SERVER_PREFIX, @@ -85,13 +87,13 @@ class ServerConfiguration(BaseModel): construct the OAuth 2.0 device authorization endpoint. If not set, a partial URL is returned to the client which is used to construct the full URL based on the server's root URL path. - device_expiration: The time in minutes that an OAuth 2.0 device is + device_expiration_minutes: The time in minutes that an OAuth 2.0 device is allowed to be used to authenticate with the ZenML server. If not set or if `jwt_token_expire_minutes` is not set, the devices are allowed to be used indefinitely. This controls the expiration time of the JWT tokens issued to clients after they have authenticated with the ZenML server using an OAuth 2.0 device. - trusted_device_expiration: The time in minutes that a trusted OAuth 2.0 + trusted_device_expiration_minutes: The time in minutes that a trusted OAuth 2.0 device is allowed to be used to authenticate with the ZenML server. If not set or if `jwt_token_expire_minutes` is not set, the devices are allowed to be used indefinitely. This controls the expiration @@ -114,11 +116,18 @@ class ServerConfiguration(BaseModel): the RBAC interface defined by `zenml.zen_server.rbac_interface.RBACInterface`. If not specified, RBAC will not be enabled for this server. + feature_gate_implementation_source: Source pointing to a class + implementing the feature gate interface defined by + `zenml.zen_server.feature_gate.feature_gate_interface.FeatureGateInterface`. + If not specified, feature usage will not be gated/tracked for this + server. workload_manager_implementation_source: Source pointing to a class implementing the workload management interface. pipeline_run_auth_window: The default time window in minutes for which a pipeline run action is allowed to authenticate with the ZenML server. + login_rate_limit_minute: The number of login attempts allowed per minute. + login_rate_limit_day: The number of login attempts allowed per day. """ deployment_type: ServerDeploymentType = ServerDeploymentType.OTHER @@ -152,11 +161,16 @@ class ServerConfiguration(BaseModel): external_server_id: Optional[UUID] = None rbac_implementation_source: Optional[str] = None + feature_gate_implementation_source: Optional[str] = None workload_manager_implementation_source: Optional[str] = None pipeline_run_auth_window: int = ( DEFAULT_ZENML_SERVER_PIPELINE_RUN_AUTH_WINDOW ) + rate_limit_enabled: bool = False + login_rate_limit_minute: int = DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE + login_rate_limit_day: int = DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_DAY + _deployment_id: Optional[UUID] = None @root_validator(pre=True) @@ -236,6 +250,15 @@ def rbac_enabled(self) -> bool: """ return self.rbac_implementation_source is not None + @property + def feature_gate_enabled(self) -> bool: + """Whether feature gating is enabled on the server or not. + + Returns: + Whether feature gating is enabled on the server or not. + """ + return self.feature_gate_implementation_source is not None + @property def workload_manager_enabled(self) -> bool: """Whether workload management is enabled on the server or not. diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 4abf73cd3d7..6da982c828d 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -13,10 +13,61 @@ # permissions and limitations under the License. """ZenML constants.""" +import json +import logging import os +from typing import Any, List, Optional, Type, TypeVar from zenml.enums import AuthScheme +T = TypeVar("T") + + +def handle_json_env_var( + var: str, + expected_type: Type[T], + default: Optional[List[str]] = None, +) -> Any: + """Converts a json env var into a Python object. + + Args: + var: The environment variable to convert. + default: The default value to return if the env var is not set. + expected_type: The type of the expected Python object. + + Returns: + The converted list value. + + Raises: + TypeError: In case the value of the environment variable is not of a + valid type. + + """ + # this needs to be here to avoid mutable defaults + if default is None: + default = [] + + value = os.getenv(var) + if value: + try: + loaded_value = json.loads(value) + # check if loaded value is of correct type + if expected_type is None or isinstance( + loaded_value, expected_type + ): + return loaded_value + else: + raise TypeError # if not correct type, raise TypeError + except (TypeError, json.JSONDecodeError): + # Use raw logging to avoid cyclic dependency + logging.warning( + f"Environment Variable {var} could not be loaded, into type " + f"{expected_type}, defaulting to: {default}." + ) + return default + else: + return default + def handle_bool_env_var(var: str, default: bool = False) -> bool: """Converts normal env var to boolean. @@ -100,6 +151,9 @@ def handle_int_env_var(var: str, default: int = 0) -> int: ENV_ZENML_SERVER_PREFIX = "ZENML_SERVER_" ENV_ZENML_SERVER_DEPLOYMENT_TYPE = f"{ENV_ZENML_SERVER_PREFIX}DEPLOYMENT_TYPE" ENV_ZENML_SERVER_AUTH_SCHEME = f"{ENV_ZENML_SERVER_PREFIX}AUTH_SCHEME" +ENV_ZENML_SERVER_REPORTABLE_RESOURCES = ( + f"{ENV_ZENML_SERVER_PREFIX}REPORTABLE_RESOURCES" +) # Logging variables IS_DEBUG_ENV: bool = handle_bool_env_var(ENV_ZENML_DEBUG, default=False) @@ -178,6 +232,18 @@ def handle_int_env_var(var: str, default: int = 0) -> int: DEFAULT_HTTP_TIMEOUT = 30 ZENML_API_KEY_PREFIX = "ZENKEY_" DEFAULT_ZENML_SERVER_PIPELINE_RUN_AUTH_WINDOW = 60 * 48 # 48 hours +DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_MINUTE = 5 +DEFAULT_ZENML_SERVER_LOGIN_RATE_LIMIT_DAY = 1000 + +# Configurations to decide which resources report their usage and check for +# entitlement in the case of a cloud deployment. Expected Format is this: +# ENV_ZENML_REPORTABLE_RESOURCES='["Foo", "bar"]' +REPORTABLE_RESOURCES: List[str] = handle_json_env_var( + ENV_ZENML_SERVER_REPORTABLE_RESOURCES, + expected_type=list, + default=["pipeline_run", "model"], +) +REQUIRES_CUSTOM_RESOURCE_REPORTING = ["pipeline"] # API Endpoint paths: ACTIVATE = "/activate" @@ -208,10 +274,6 @@ def handle_int_env_var(var: str, default: int = 0) -> int: LOGIN = "/login" LOGOUT = "/logout" LOGS = "/logs" -MODEL_VERSION_ARTIFACTS = "/model_version_artifacts" -MODEL_VERSION_PIPELINE_RUNS = "/model_version_pipeline_runs" -MODEL_VERSIONS = "/model_versions" -MODELS = "/models" PIPELINE_BUILDS = "/pipeline_builds" PIPELINE_CONFIGURATION = "/pipeline-configuration" PIPELINE_DEPLOYMENTS = "/pipeline_deployments" @@ -230,6 +292,12 @@ def handle_int_env_var(var: str, default: int = 0) -> int: SERVICE_CONNECTOR_RESOURCES = "/resources" SERVICE_CONNECTOR_TYPES = "/service_connector_types" SERVICE_CONNECTOR_VERIFY = "/verify" +SERVICE_CONNECTOR_RESOURCES = "/resources" +MODELS = "/models" +MODEL_VERSIONS = "/model_versions" +MODEL_VERSION_ARTIFACTS = "/model_version_artifacts" +MODEL_VERSION_PIPELINE_RUNS = "/model_version_pipeline_runs" +SERVICES = "/services" SERVICE_CONNECTORS = "/service_connectors" STACKS = "/stacks" STACK_COMPONENTS = "/components" diff --git a/src/zenml/container_registries/base_container_registry.py b/src/zenml/container_registries/base_container_registry.py index 4617b7db588..d8f641cf4b4 100644 --- a/src/zenml/container_registries/base_container_registry.py +++ b/src/zenml/container_registries/base_container_registry.py @@ -142,7 +142,9 @@ def docker_client(self) -> "DockerClient": ) self._docker_client = client else: - self._docker_client = DockerClient.from_env() + self._docker_client = ( + docker_utils._try_get_docker_client_from_env() + ) credentials = self.credentials if credentials: diff --git a/src/zenml/enums.py b/src/zenml/enums.py index e92da7e8871..67f6ace00f6 100644 --- a/src/zenml/enums.py +++ b/src/zenml/enums.py @@ -54,6 +54,13 @@ class VisualizationType(StrEnum): MARKDOWN = "markdown" +class ZenMLServiceType(StrEnum): + """All possible types a service can have.""" + + ZEN_SERVER = "zen_server" + MODEL_SERVING = "model-serving" + + class ExecutionStatus(StrEnum): """Enum that represents the current status of a step or pipeline run.""" diff --git a/src/zenml/event_sources/webhooks/base_webhook_event_source.py b/src/zenml/event_sources/webhooks/base_webhook_event_source.py index 0fac4f73592..035c4aabad9 100644 --- a/src/zenml/event_sources/webhooks/base_webhook_event_source.py +++ b/src/zenml/event_sources/webhooks/base_webhook_event_source.py @@ -154,10 +154,12 @@ def _validate_webhook_event_signature( Raises: AuthorizationException: If the signature validation fails. """ - signature_header = headers.get("x-hub-signature-256") + signature_header = headers.get("x-hub-signature-256") or headers.get( + "x-hub-signature" + ) if not signature_header: raise AuthorizationException( - "x-hub-signature-256 header is missing!" + "x-hub-signature-256 or x-hub-signature header is missing!" ) if not self.is_valid_signature( diff --git a/src/zenml/exceptions.py b/src/zenml/exceptions.py index f7e67339033..5ef6b0315af 100644 --- a/src/zenml/exceptions.py +++ b/src/zenml/exceptions.py @@ -253,6 +253,10 @@ class InputResolutionError(ZenMLBaseException): """Raised when step input resolving failed.""" +class SubscriptionUpgradeRequiredError(ZenMLBaseException): + """Raised when user tries to perform an action outside their current subscription tier.""" + + class HydrationError(ZenMLBaseException): """Raised when the model hydration failed.""" diff --git a/src/zenml/image_builders/local_image_builder.py b/src/zenml/image_builders/local_image_builder.py index 16a1fd29c1f..5a918e934f9 100644 --- a/src/zenml/image_builders/local_image_builder.py +++ b/src/zenml/image_builders/local_image_builder.py @@ -17,8 +17,6 @@ import tempfile from typing import TYPE_CHECKING, Any, Dict, Optional, Type, cast -from docker.client import DockerClient - from zenml.image_builders import ( BaseImageBuilder, BaseImageBuilderConfig, @@ -106,7 +104,7 @@ def build( # authenticated to access additional registries docker_client = container_registry.docker_client else: - docker_client = DockerClient.from_env() + docker_client = docker_utils._try_get_docker_client_from_env() with tempfile.TemporaryFile(mode="w+b") as f: build_context.write_archive(f) diff --git a/src/zenml/integrations/__init__.py b/src/zenml/integrations/__init__.py index 3b2e37ca377..4d1b4033eb2 100644 --- a/src/zenml/integrations/__init__.py +++ b/src/zenml/integrations/__init__.py @@ -23,6 +23,7 @@ from zenml.integrations.aws import AWSIntegration # noqa from zenml.integrations.azure import AzureIntegration # noqa from zenml.integrations.bentoml import BentoMLIntegration # noqa +from zenml.integrations.bitbucket import BitbucketIntegration # noqa from zenml.integrations.deepchecks import DeepchecksIntegration # noqa from zenml.integrations.discord import DiscordIntegration # noqa from zenml.integrations.evidently import EvidentlyIntegration # noqa diff --git a/src/zenml/integrations/airflow/__init__.py b/src/zenml/integrations/airflow/__init__.py index ddf8a79a1fd..7446a195a61 100644 --- a/src/zenml/integrations/airflow/__init__.py +++ b/src/zenml/integrations/airflow/__init__.py @@ -17,7 +17,7 @@ orchestrator. You can enable it by registering the Airflow orchestrator with the CLI tool, then bootstrap using the ``zenml orchestrator up`` command. """ -from typing import List, Type +from typing import List, Optional, Type from zenml.integrations.constants import AIRFLOW from zenml.integrations.integration import Integration @@ -32,14 +32,7 @@ class AirflowIntegration(Integration): NAME = AIRFLOW # remove pendulum version requirement once Airflow supports # pendulum>-3.0.0 - REQUIREMENTS = [ - "apache-airflow~=2.4.0", - "pendulum<3.0.0", - # We need to add this as an extra dependency to manually downgrade - # SQLModel. Otherwise, the initial installation of ZenML installs - # a higher version SQLModel and a version mismatch is created. - "sqlmodel>=0.0.9,<=0.0.16", - ] + REQUIREMENTS = ["apache-airflow~=2.4.0", "pendulum<3.0.0"] @classmethod def flavors(cls) -> List[Type[Flavor]]: diff --git a/src/zenml/integrations/bentoml/constants.py b/src/zenml/integrations/bentoml/constants.py index 318913cd19e..19395866834 100644 --- a/src/zenml/integrations/bentoml/constants.py +++ b/src/zenml/integrations/bentoml/constants.py @@ -15,5 +15,5 @@ DEFAULT_BENTO_FILENAME = "zenml_exported.bento" BENTOML_DEFAULT_PORT = 3000 -BENTOML_HEALTHCHECK_URL_PATH = "healthz" +BENTOML_HEALTHCHECK_URL_PATH = "readyz" BENTOML_PREDICTION_URL_PATH = "" diff --git a/src/zenml/integrations/bentoml/model_deployers/bentoml_model_deployer.py b/src/zenml/integrations/bentoml/model_deployers/bentoml_model_deployer.py index 746d13a8f67..6f782f6ce4d 100644 --- a/src/zenml/integrations/bentoml/model_deployers/bentoml_model_deployer.py +++ b/src/zenml/integrations/bentoml/model_deployers/bentoml_model_deployer.py @@ -15,13 +15,11 @@ import os import shutil -from pathlib import Path -from typing import ClassVar, Dict, List, Optional, Type, cast +from typing import ClassVar, Dict, Optional, Type, cast from uuid import UUID from zenml.config.global_config import GlobalConfiguration from zenml.constants import DEFAULT_SERVICE_START_STOP_TIMEOUT -from zenml.integrations.bentoml.constants import BENTOML_DEFAULT_PORT from zenml.integrations.bentoml.flavors.bentoml_model_deployer_flavor import ( BentoMLModelDeployerConfig, BentoMLModelDeployerFlavor, @@ -32,8 +30,6 @@ ) from zenml.logger import get_logger from zenml.model_deployers import BaseModelDeployer, BaseModelDeployerFlavor -from zenml.services import ServiceRegistry -from zenml.services.local.local_service import SERVICE_DAEMON_CONFIG_FILE_NAME from zenml.services.service import BaseService, ServiceConfig from zenml.utils.io_utils import create_dir_recursive_if_not_exists @@ -126,7 +122,8 @@ def get_model_server_info( # type: ignore[override] ) return { - "PREDICTION_URL": service_instance.prediction_url, + "HEALTH_CHECK_URL": service_instance.get_healthcheck_url(), + "PREDICTION_URL": service_instance.get_prediction_url(), "BENTO_TAG": service_instance.config.bento, "MODEL_NAME": service_instance.config.model_name, "MODEL_URI": service_instance.config.model_uri, @@ -136,10 +133,10 @@ def get_model_server_info( # type: ignore[override] "PREDICTION_APIS_URLS": predictions_apis_urls, } - def deploy_model( + def perform_deploy_model( self, + id: UUID, config: ServiceConfig, - replace: bool = False, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT, ) -> BaseService: """Create a new BentoML deployment service or update an existing one. @@ -171,10 +168,8 @@ def deploy_model( and the others are deleted. Args: + id: the UUID of the BentoML model deployer. config: the configuration of the model to be deployed with BentoML. - replace: set this flag to True to find and update an equivalent - BentoML deployment server with the new model instead of - creating and starting a new deployment server. timeout: the timeout in seconds to wait for the BentoML server to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the BentoML @@ -185,49 +180,11 @@ def deploy_model( interact with the BentoML model http server. """ config = cast(BentoMLDeploymentConfig, config) - service = None - - # if replace is True, remove all existing services - if replace is True: - existing_services = self.find_model_server( - pipeline_name=config.pipeline_name, - pipeline_step_name=config.pipeline_step_name, - model_name=config.model_name, - ) - - for existing_service in existing_services: - if service is None: - # keep the most recently created service - service = cast(BentoMLDeploymentService, existing_service) - try: - # delete the older services and don't wait for them to - # be deprovisioned - self._clean_up_existing_service( - existing_service=cast( - BentoMLDeploymentService, existing_service - ), - timeout=timeout, - force=True, - ) - except RuntimeError: - # ignore errors encountered while stopping old services - pass - if service: - logger.info( - f"Updating an existing BentoML deployment service: {service}" - ) - - # set the root runtime path with the stack component's UUID - config.root_runtime_path = self.local_path - service.stop(timeout=timeout, force=True) - service.update(config) - service.start(timeout=timeout) - else: - # create a new BentoMLDeploymentService instance - service = self._create_new_service(timeout, config) - logger.info(f"Created a new BentoML deployment service: {service}") - - return cast(BaseService, service) + service = self._create_new_service( + id=id, timeout=timeout, config=config + ) + logger.info(f"Created a new BentoML deployment service: {service}") + return service def _clean_up_existing_service( self, @@ -246,12 +203,13 @@ def _clean_up_existing_service( # of workers etc.the step implementation will create a new config using # all values from the user and add values like pipeline name, model_uri def _create_new_service( - self, timeout: int, config: BentoMLDeploymentConfig + self, id: UUID, timeout: int, config: BentoMLDeploymentConfig ) -> BentoMLDeploymentService: """Creates a new BentoMLDeploymentService. Args: - timeout: the timeout in seconds to wait for the BentoML http server + id: the ID of the BentoML deployment service to be created or updated. + timeout: the timeout in seconds to wait for the BentoML server to be provisioned and successfully started or updated. config: the configuration of the model to be deployed with BentoML. @@ -262,197 +220,61 @@ def _create_new_service( # set the root runtime path with the stack component's UUID config.root_runtime_path = self.local_path # create a new service for the new model - service = BentoMLDeploymentService(config) + service = BentoMLDeploymentService(uuid=id, config=config) service.start(timeout=timeout) return service - def find_model_server( + def perform_stop_model( self, - running: bool = False, - service_uuid: Optional[UUID] = None, - pipeline_name: Optional[str] = None, - run_name: Optional[str] = None, - pipeline_step_name: Optional[str] = None, - model_name: Optional[str] = None, - model_uri: Optional[str] = None, - model_type: Optional[str] = None, - ) -> List[BaseService]: - """Finds one or more model servers that match the given criteria. - - Args: - running: If true, only running services will be returned. - service_uuid: The UUID of the service that was originally used - to deploy the model. - pipeline_name: Name of the pipeline that the deployed model was part - of. - run_name: ID of the pipeline run which the deployed model - was part of. - pipeline_step_name: The name of the pipeline model deployment step - that deployed the model. - model_name: Name of the deployed model. - model_uri: URI of the deployed model. - model_type: Type/format of the deployed model. Not used in this - BentoML case. - - Returns: - One or more Service objects representing model servers that match - the input search criteria. - - Raises: - TypeError: if any of the input arguments are of an invalid type. - """ - services = [] - config = BentoMLDeploymentConfig( - model_name=model_name or "", - bento="", - port=BENTOML_DEFAULT_PORT, - model_uri=model_uri or "", - working_dir="", - pipeline_name=pipeline_name or "", - pipeline_run_id=run_name or "", - run_name=run_name or "", - pipeline_step_name=pipeline_step_name or "", - ) - - # find all services that match the input criteria - for root, _, files in os.walk(self.local_path): - if service_uuid and Path(root).name != str(service_uuid): - continue - for file in files: - if file == SERVICE_DAEMON_CONFIG_FILE_NAME: - service_config_path = os.path.join(root, file) - logger.debug( - "Loading service daemon configuration from %s", - service_config_path, - ) - existing_service_config = None - with open(service_config_path, "r") as f: - existing_service_config = f.read() - existing_service = ( - ServiceRegistry().load_service_from_json( - existing_service_config - ) - ) - if not isinstance( - existing_service, BentoMLDeploymentService - ): - raise TypeError( - f"Expected service type BentoMLDeploymentService but got " - f"{type(existing_service)} instead" - ) - existing_service.update_status() - if self._matches_search_criteria(existing_service, config): - if not running or existing_service.is_running: - services.append( - cast(BaseService, existing_service) - ) - - return services - - def _matches_search_criteria( - self, - existing_service: BentoMLDeploymentService, - config: BentoMLDeploymentConfig, - ) -> bool: - """Returns true if a service matches the input criteria. - - If any of the values in the input criteria are None, they are ignored. - This allows listing services just by common pipeline names or step - names, etc. - - Args: - existing_service: The materialized Service instance derived from - the config of the older (existing) service - config: The BentoMlDeploymentConfig object passed to the - deploy_model function holding parameters of the new service - to be created. - - Returns: - True if the service matches the input criteria. - """ - existing_service_config = existing_service.config - - # check if the existing service matches the input criteria - if ( - ( - not config.pipeline_name - or existing_service_config.pipeline_name - == config.pipeline_name - ) - and ( - not config.model_name - or existing_service_config.model_name == config.model_name - ) - and ( - not config.pipeline_step_name - or existing_service_config.pipeline_step_name - == config.pipeline_step_name - ) - and ( - not config.run_name - or existing_service_config.run_name == config.run_name - ) - ): - return True - - return False - - def stop_model_server( - self, - uuid: UUID, + service: BaseService, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT, force: bool = False, - ) -> None: + ) -> BaseService: """Method to stop a model server. Args: - uuid: UUID of the model server to stop. + service: The service to stop. timeout: Timeout in seconds to wait for the service to stop. force: If True, force the service to stop. - """ - # get list of all services - existing_services = self.find_model_server(service_uuid=uuid) - # if the service exists, stop it - if existing_services: - existing_services[0].stop(timeout=timeout, force=force) + Returns: + The stopped service. + """ + service.stop(timeout=timeout, force=force) + return service - def start_model_server( - self, uuid: UUID, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT - ) -> None: + def perform_start_model( + self, + service: BaseService, + timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT, + ) -> BaseService: """Method to start a model server. Args: - uuid: UUID of the model server to start. + service: The service to start. timeout: Timeout in seconds to wait for the service to start. - """ - # get list of all services - existing_services = self.find_model_server(service_uuid=uuid) - # if the service exists, start it - if existing_services: - existing_services[0].start(timeout=timeout) + Returns: + The started service. + """ + service.start(timeout=timeout) + return service - def delete_model_server( + def perform_delete_model( self, - uuid: UUID, + service: BaseService, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT, force: bool = False, ) -> None: """Method to delete all configuration of a model server. Args: - uuid: UUID of the model server to delete. + service: The service to delete. timeout: Timeout in seconds to wait for the service to stop. force: If True, force the service to stop. """ - # get list of all services - existing_services = self.find_model_server(service_uuid=uuid) - - # if the service exists, clean it up - if existing_services: - service = cast(BentoMLDeploymentService, existing_services[0]) - self._clean_up_existing_service( - existing_service=service, timeout=timeout, force=force - ) + service = cast(BentoMLDeploymentService, service) + self._clean_up_existing_service( + existing_service=service, timeout=timeout, force=force + ) diff --git a/src/zenml/integrations/bentoml/services/bentoml_deployment.py b/src/zenml/integrations/bentoml/services/bentoml_deployment.py index 138d3039c9b..2a826fb5077 100644 --- a/src/zenml/integrations/bentoml/services/bentoml_deployment.py +++ b/src/zenml/integrations/bentoml/services/bentoml_deployment.py @@ -94,8 +94,8 @@ class SSLBentoMLParametersConfig(BaseModel): ssl_certfile: Optional[str] = None ssl_keyfile: Optional[str] = None ssl_keyfile_password: Optional[str] = None - ssl_version: Optional[str] = None - ssl_cert_reqs: Optional[str] = None + ssl_version: Optional[int] = None + ssl_cert_reqs: Optional[int] = None ssl_ca_certs: Optional[str] = None ssl_ciphers: Optional[str] = None @@ -121,9 +121,9 @@ class BentoMLDeploymentConfig(LocalDaemonServiceConfig): bento: str bento_uri: Optional[str] = None apis: List[str] = [] - workers: Optional[int] = 1 - port: Optional[int] = None - backlog: Optional[int] = 2048 + workers: int = 1 + port: int + backlog: int = 2048 production: bool = False working_dir: str host: Optional[str] = None @@ -147,6 +147,7 @@ class BentoMLDeploymentService(LocalDaemonService, BaseDeploymentService): type="model-serving", flavor="bentoml", description="BentoML prediction service", + logo_url="https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/bentoml.png", ) config: BentoMLDeploymentConfig @@ -203,9 +204,9 @@ def run(self) -> None: serve_http_production( self.config.bento, working_dir=self.config.working_dir, - port=self.endpoint.status.port, + port=self.config.port, api_workers=self.config.workers, - host=self.endpoint.status.hostname, + host=self.config.host or DEFAULT_LOCAL_SERVICE_IP_ADDRESS, backlog=self.config.backlog, ssl_certfile=ssl_params.ssl_certfile, ssl_keyfile=ssl_params.ssl_keyfile, diff --git a/src/zenml/integrations/bentoml/steps/bentoml_deployer.py b/src/zenml/integrations/bentoml/steps/bentoml_deployer.py index 225126233ed..4bb11e1957f 100644 --- a/src/zenml/integrations/bentoml/steps/bentoml_deployer.py +++ b/src/zenml/integrations/bentoml/steps/bentoml_deployer.py @@ -87,16 +87,8 @@ def bentoml_model_deployer_step( # get pipeline name, step name and run id step_context = get_step_context() pipeline_name = step_context.pipeline.name - run_name = step_context.pipeline_run.name step_name = step_context.step_run.name - # fetch existing services with same pipeline name, step name and model name - existing_services = model_deployer.find_model_server( - pipeline_name=pipeline_name, - pipeline_step_name=step_name, - model_name=model_name, - ) - # Return the apis endpoint of the defined service to use in the predict. # This is a workaround to get the endpoints of the service defined as functions # from the user code in the BentoML service. @@ -123,7 +115,6 @@ def service_apis(bento_tag: str) -> List[str]: working_dir=working_dir or source_utils.get_source_root(), port=port, pipeline_name=pipeline_name, - run_name=run_name, pipeline_step_name=step_name, ssl_parameters=SSLBentoMLParametersConfig( ssl_certfile=ssl_certfile, @@ -136,8 +127,13 @@ def service_apis(bento_tag: str) -> List[str]: ), ) + # fetch existing services with same pipeline name, step name and model name + existing_services = model_deployer.find_model_server( + config=predictor_cfg.dict(), + service_type=BentoMLDeploymentService.SERVICE_TYPE, + ) + # Creating a new service with inactive state and status by default - service = BentoMLDeploymentService(predictor_cfg) if existing_services: service = cast(BentoMLDeploymentService, existing_services[0]) @@ -159,6 +155,7 @@ def service_apis(bento_tag: str) -> List[str]: replace=True, config=predictor_cfg, timeout=timeout, + service_type=BentoMLDeploymentService.SERVICE_TYPE, ), ) diff --git a/src/zenml/integrations/bitbucket/__init__.py b/src/zenml/integrations/bitbucket/__init__.py new file mode 100644 index 00000000000..770f355a2ee --- /dev/null +++ b/src/zenml/integrations/bitbucket/__init__.py @@ -0,0 +1,42 @@ +# Copyright (c) ZenML GmbH 2022. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Initialization of the bitbucket ZenML integration.""" +from typing import List, Type + +from zenml.integrations.constants import BITBUCKET +from zenml.integrations.integration import Integration +from zenml.plugins.base_plugin_flavor import BasePluginFlavor + +BITBUCKET_EVENT_FLAVOR = "bitbucket" + + +class BitbucketIntegration(Integration): + """Definition of bitbucket integration for ZenML.""" + + NAME = BITBUCKET + REQUIREMENTS: List[str] = [] + + @classmethod + def plugin_flavors(cls) -> List[Type[BasePluginFlavor]]: + """Declare the event flavors for the bitbucket integration. + + Returns: + List of stack component flavors for this integration. + """ + from zenml.integrations.bitbucket.plugins import BitbucketWebhookEventSourceFlavor + + return [BitbucketWebhookEventSourceFlavor] + + +BitbucketIntegration.check_installation() diff --git a/src/zenml/integrations/bitbucket/plugins/__init__.py b/src/zenml/integrations/bitbucket/plugins/__init__.py new file mode 100644 index 00000000000..c5eb3accaed --- /dev/null +++ b/src/zenml/integrations/bitbucket/plugins/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Bitbucket event flavors.""" + +from zenml.integrations.bitbucket.plugins.bitbucket_webhook_event_source_flavor import BitbucketWebhookEventSourceFlavor + +__all__ = [ + "BitbucketWebhookEventSourceFlavor" +] \ No newline at end of file diff --git a/src/zenml/integrations/bitbucket/plugins/bitbucket_webhook_event_source_flavor.py b/src/zenml/integrations/bitbucket/plugins/bitbucket_webhook_event_source_flavor.py new file mode 100644 index 00000000000..a389b6677ee --- /dev/null +++ b/src/zenml/integrations/bitbucket/plugins/bitbucket_webhook_event_source_flavor.py @@ -0,0 +1,43 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Bitbucket webhook event source flavor.""" + +from typing import ClassVar, Type + +from zenml.event_sources.webhooks.base_webhook_event_source import ( + BaseWebhookEventSourceFlavor, +) +from zenml.integrations.bitbucket import BITBUCKET_EVENT_FLAVOR +from zenml.integrations.bitbucket.plugins.event_sources.bitbucket_webhook_event_source import ( + BitbucketWebhookEventFilterConfiguration, + BitbucketWebhookEventSourceConfiguration, + BitbucketWebhookEventSourceHandler, +) + + +class BitbucketWebhookEventSourceFlavor(BaseWebhookEventSourceFlavor): + """Enables users to configure Bitbucket event sources.""" + + FLAVOR: ClassVar[str] = BITBUCKET_EVENT_FLAVOR + PLUGIN_CLASS: ClassVar[Type[BitbucketWebhookEventSourceHandler]] = ( + BitbucketWebhookEventSourceHandler + ) + + # EventPlugin specific + EVENT_SOURCE_CONFIG_CLASS: ClassVar[ + Type[BitbucketWebhookEventSourceConfiguration] + ] = BitbucketWebhookEventSourceConfiguration + EVENT_FILTER_CONFIG_CLASS: ClassVar[ + Type[BitbucketWebhookEventFilterConfiguration] + ] = BitbucketWebhookEventFilterConfiguration diff --git a/src/zenml/integrations/bitbucket/plugins/event_sources/__init__.py b/src/zenml/integrations/bitbucket/plugins/event_sources/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/zenml/integrations/bitbucket/plugins/event_sources/bitbucket_webhook_event_source.py b/src/zenml/integrations/bitbucket/plugins/event_sources/bitbucket_webhook_event_source.py new file mode 100644 index 00000000000..c9a6c247958 --- /dev/null +++ b/src/zenml/integrations/bitbucket/plugins/event_sources/bitbucket_webhook_event_source.py @@ -0,0 +1,490 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Implementation of the Bitbucket webhook event source.""" + +from typing import Any, Dict, List, Optional, Type, Union +from uuid import UUID + +from pydantic import BaseModel, Extra, Field + +from zenml.enums import SecretScope +from zenml.event_sources.base_event import ( + BaseEvent, +) +from zenml.event_sources.base_event_source import EventSourceConfig +from zenml.event_sources.webhooks.base_webhook_event_source import ( + BaseWebhookEventSourceFlavor, + BaseWebhookEventSourceHandler, + WebhookEventFilterConfig, + WebhookEventSourceConfig, +) +from zenml.exceptions import AuthorizationException +from zenml.logger import get_logger +from zenml.models import ( + EventSourceRequest, + EventSourceResponse, + EventSourceUpdate, + SecretRequest, + SecretUpdate, +) +from zenml.utils.enum_utils import StrEnum +from zenml.utils.string_utils import random_str + +logger = get_logger(__name__) + +# -------------------- Utils ----------------------------------- + + +class BitbucketEventType(StrEnum): + """Collection of all possible Bitbucket Events.""" + + PUSH_EVENT = "push_event" + TAG_EVENT = "tag_event" + + +# -------------------- Bitbucket Event Models ---------------------------------- + + +class User(BaseModel): + """Bitbucket User.""" + + name: Optional[str] + email: Optional[str] + username: Optional[str] + + +class Commit(BaseModel): + """Bitbucket Commit.""" + + hash: str + message: str + links: Dict[str, Any] + author: User + + +class Repository(BaseModel): + """Bitbucket Repository.""" + + uuid: str + name: str + full_name: str + links: Dict[str, Any] + + +class PushChange(BaseModel): + """Bitbucket Push Change.""" + + new: Optional[Dict[str, Any]] + old: Optional[Dict[str, Any]] + commits: List[Commit] + + +class Push(BaseModel): + """Bitbucket Push.""" + + changes: List[PushChange] + + +class BitbucketEvent(BaseEvent): + """Bitbucket Event.""" + + actor: User + repository: Repository + push: Push + + class Config: + """Pydantic configuration class.""" + + extra = Extra.allow + + @property + def branch(self) -> Optional[str]: + """The branch the event happened on. + + Returns: + The branch name. + """ + if self.push.changes[0].new: + branch = self.push.changes[0].new.get("name", None) + if self.push.changes[0].new.get("name", None): + return str(branch) + return None + + @property + def event_type(self) -> Union[BitbucketEventType, str]: + """The type of Bitbucket event. + + Args: + The type of the event based on Bitbucket specific fields. + + Returns: + The type of the event. + """ + is_push_event = all( + [change.new is not None for change in self.push.changes] + ) + is_tag_event = all( + [ + change.new.get("type") == "tag" + for change in self.push.changes + if change.new + ] + ) + + if is_push_event: + return BitbucketEventType.PUSH_EVENT + elif is_tag_event: + return BitbucketEventType.TAG_EVENT + else: + return "unknown" + + +# -------------------- Configuration Models ---------------------------------- + + +class BitbucketWebhookEventFilterConfiguration(WebhookEventFilterConfig): + """Configuration for Bitbucket event filters.""" + + repo: Optional[str] + branch: Optional[str] + event_type: Optional[BitbucketEventType] + + def event_matches_filter(self, event: BaseEvent) -> bool: + """Checks the filter against the inbound event. + + Args: + event: The incoming event + + Returns: + Whether the event matches the filter + """ + if not isinstance(event, BitbucketEvent): + return False + if self.event_type and event.event_type != self.event_type: + # Mismatch for the action + return False + if self.repo and event.repository.full_name != self.repo: + # Mismatch for the repository + return False + if self.branch and event.branch != self.branch: + # Mismatch for the branch + return False + return True + + +class BitbucketWebhookEventSourceConfiguration(WebhookEventSourceConfig): + """Configuration for Bitbucket source filters.""" + + webhook_secret: Optional[str] = Field( + default=None, + title="The webhook secret for the event source.", + ) + webhook_secret_id: Optional[UUID] = Field( + default=None, + description="The ID of the secret containing the webhook secret.", + ) + rotate_secret: Optional[bool] = Field( + default=None, description="Set to rotate the webhook secret." + ) + + +# -------------------- Bitbucket Webhook Plugin ----------------------------------- + + +class BitbucketWebhookEventSourceHandler(BaseWebhookEventSourceHandler): + """Handler for all Bitbucket events.""" + + @property + def config_class(self) -> Type[BitbucketWebhookEventSourceConfiguration]: + """Returns the webhook event source configuration class. + + Returns: + The configuration. + """ + return BitbucketWebhookEventSourceConfiguration + + @property + def filter_class(self) -> Type[BitbucketWebhookEventFilterConfiguration]: + """Returns the webhook event filter configuration class. + + Returns: + The event filter configuration class. + """ + return BitbucketWebhookEventFilterConfiguration + + @property + def flavor_class(self) -> Type[BaseWebhookEventSourceFlavor]: + """Returns the flavor class of the plugin. + + Returns: + The flavor class of the plugin. + """ + from zenml.integrations.bitbucket.plugins.bitbucket_webhook_event_source_flavor import ( + BitbucketWebhookEventSourceFlavor, + ) + + return BitbucketWebhookEventSourceFlavor + + def _interpret_event(self, event: Dict[str, Any]) -> BitbucketEvent: + """Converts the generic event body into a event-source specific pydantic model. + + Args: + event: The generic event body + + Returns: + An instance of the event source specific pydantic model. + + Raises: + ValueError: If the event body can not be parsed into the pydantic model. + """ + try: + Bitbucket_event = BitbucketEvent(**event) + except ValueError: + raise ValueError("Event did not match the pydantic model.") + else: + return Bitbucket_event + + def _get_webhook_secret( + self, event_source: EventSourceResponse + ) -> Optional[str]: + """Get the webhook secret for the event source. + + Args: + event_source: The event source to retrieve the secret for. + + Returns: + The webhook secret associated with the event source, or None if a + secret is not applicable. + + Raises: + AuthorizationException: If the secret value could not be retrieved. + """ + # Temporary solution to get the secret value for the Event Source + config = self.validate_event_source_configuration( + event_source.configuration + ) + assert isinstance(config, BitbucketWebhookEventSourceConfiguration) + webhook_secret_id = config.webhook_secret_id + if webhook_secret_id is None: + raise AuthorizationException( + f"Webhook secret ID is missing from the event source " + f"configuration for event source '{event_source.id}'." + ) + try: + return self.zen_store.get_secret( + secret_id=webhook_secret_id + ).secret_values["webhook_secret"] + except KeyError: + logger.exception( + f"Could not retrieve secret value for webhook secret id " + f"'{webhook_secret_id}'" + ) + raise AuthorizationException( + "Could not retrieve webhook signature." + ) + + def _validate_event_source_request( + self, event_source: EventSourceRequest, config: EventSourceConfig + ) -> None: + """Validate an event source request before it is created in the database. + + The `webhook_secret`, `webhook_secret_id`, and `rotate_secret` + fields are not allowed in the request. + + Args: + event_source: Event source request. + config: Event source configuration instantiated from the request. + + Raises: + ValueError: If any of the disallowed fields are present in the + request. + """ + assert isinstance(config, BitbucketWebhookEventSourceConfiguration) + for field in ["webhook_secret", "webhook_secret_id", "rotate_secret"]: + if getattr(config, field) is not None: + raise ValueError( + f"The `{field}` field is not allowed in the event source " + "request." + ) + + def _process_event_source_request( + self, event_source: EventSourceResponse, config: EventSourceConfig + ) -> None: + """Process an event source request after it is created in the database. + + Generates a webhook secret and stores it in a secret in the database, + then attaches the secret ID to the event source configuration. + + Args: + event_source: Newly created event source + config: Event source configuration instantiated from the response. + """ + assert isinstance(config, BitbucketWebhookEventSourceConfiguration) + assert ( + event_source.user is not None + ), "User is not set for event source" + + secret_key_value = random_str(12) + webhook_secret = SecretRequest( + name=f"event_source-{str(event_source.id)}-{random_str(4)}".lower(), + values={"webhook_secret": secret_key_value}, + workspace=event_source.workspace.id, + user=event_source.user.id, + scope=SecretScope.WORKSPACE, + ) + secret = self.zen_store.create_secret(webhook_secret) + + # Store the secret ID in the event source configuration in the database + event_source_update = EventSourceUpdate.from_response(event_source) + assert event_source_update.configuration is not None + event_source_update.configuration["webhook_secret_id"] = str(secret.id) + + self.zen_store.update_event_source( + event_source_id=event_source.id, + event_source_update=event_source_update, + ) + + # Set the webhook secret in the configuration returned to the user + config.webhook_secret = secret_key_value + # Remove hidden field from the response + config.rotate_secret = None + config.webhook_secret_id = None + + def _validate_event_source_update( + self, + event_source: EventSourceResponse, + config: EventSourceConfig, + event_source_update: EventSourceUpdate, + config_update: EventSourceConfig, + ) -> None: + """Validate an event source update before it is reflected in the database. + + Ensure the webhook secret ID is preserved in the updated event source + configuration. + + Args: + event_source: Original event source before the update. + config: Event source configuration instantiated from the original + event source. + event_source_update: Event source update request. + config_update: Event source configuration instantiated from the + updated event source. + """ + assert isinstance(config, BitbucketWebhookEventSourceConfiguration) + assert isinstance( + config_update, BitbucketWebhookEventSourceConfiguration + ) + + config_update.webhook_secret_id = config.webhook_secret_id + + def _process_event_source_update( + self, + event_source: EventSourceResponse, + config: EventSourceConfig, + previous_event_source: EventSourceResponse, + previous_config: EventSourceConfig, + ) -> None: + """Process an event source after it is updated in the database. + + If the `rotate_secret` field is set to `True`, the webhook secret is + rotated and the new secret ID is attached to the event source + configuration. + + Args: + event_source: Event source after the update. + config: Event source configuration instantiated from the updated + event source. + previous_event_source: Original event source before the update. + previous_config: Event source configuration instantiated from the + original event source. + """ + assert isinstance(config, BitbucketWebhookEventSourceConfiguration) + assert isinstance( + previous_config, BitbucketWebhookEventSourceConfiguration + ) + assert config.webhook_secret_id is not None + + if config.rotate_secret: + # In case the secret is being rotated + secret_key_value = random_str(12) + webhook_secret = SecretUpdate( # type: ignore[call-arg] + values={"webhook_secret": secret_key_value} + ) + self.zen_store.update_secret( + secret_id=config.webhook_secret_id, + secret_update=webhook_secret, + ) + + # Remove the `rotate_secret` field from the configuration stored + # in the database + event_source_update = EventSourceUpdate.from_response(event_source) + assert event_source_update.configuration is not None + event_source_update.configuration.pop("rotate_secret") + self.zen_store.update_event_source( + event_source_id=event_source.id, + event_source_update=event_source_update, + ) + + # Set the new secret in the configuration returned to the user + config.webhook_secret = secret_key_value + + # Remove hidden fields from the response + config.rotate_secret = None + config.webhook_secret_id = None + + def _process_event_source_delete( + self, + event_source: EventSourceResponse, + config: EventSourceConfig, + force: Optional[bool] = False, + ) -> None: + """Process an event source before it is deleted from the database. + + Deletes the associated secret from the database. + + Args: + event_source: Event source before the deletion. + config: Validated instantiated event source configuration before + the deletion. + force: Whether to force deprovision the event source. + """ + assert isinstance(config, BitbucketWebhookEventSourceConfiguration) + if config.webhook_secret_id is not None: + try: + self.zen_store.delete_secret( + secret_id=config.webhook_secret_id + ) + except KeyError: + pass + + # Remove hidden fields from the response + config.rotate_secret = None + config.webhook_secret_id = None + + def _process_event_source_response( + self, event_source: EventSourceResponse, config: EventSourceConfig + ) -> None: + """Process an event source response before it is returned to the user. + + Removes hidden fields from the configuration. + + Args: + event_source: Event source response. + config: Event source configuration instantiated from the response. + """ + assert isinstance(config, BitbucketWebhookEventSourceConfiguration) + # Remove hidden fields from the response + config.rotate_secret = None + config.webhook_secret_id = None + config.webhook_secret = None diff --git a/src/zenml/integrations/constants.py b/src/zenml/integrations/constants.py index cb800d9b251..0a486ae63be 100644 --- a/src/zenml/integrations/constants.py +++ b/src/zenml/integrations/constants.py @@ -18,6 +18,7 @@ AZURE = "azure" AZUREML = "azureml" BENTOML = "bentoml" +BITBUCKET = "bitbucket" DASH = "dash" DEEPCHECKS = "deepchecks" DISCORD = "discord" diff --git a/src/zenml/integrations/evidently/__init__.py b/src/zenml/integrations/evidently/__init__.py index 00e0e42b6d1..6912a9ef516 100644 --- a/src/zenml/integrations/evidently/__init__.py +++ b/src/zenml/integrations/evidently/__init__.py @@ -54,13 +54,7 @@ class EvidentlyIntegration(Integration): """[Evidently](https://github.com/evidentlyai/evidently) integration for ZenML.""" NAME = EVIDENTLY - REQUIREMENTS = [ - "evidently>0.2.6,<0.4.5", # supports pyyaml 6 - # We need to add this as an extra dependency to manually downgrade - # SQLModel. Otherwise, the initial installation of ZenML installs - # a higher version SQLModel and a version mismatch is created. - "sqlmodel>=0.0.9,<=0.0.16" - ] + REQUIREMENTS = ["evidently>0.2.6,<0.4.5"] # supports pyyaml 6 @classmethod def flavors(cls) -> List[Type[Flavor]]: diff --git a/src/zenml/integrations/github/plugins/github_webhook_event_source_flavor.py b/src/zenml/integrations/github/plugins/github_webhook_event_source_flavor.py index 568291bc626..5b321911ad5 100644 --- a/src/zenml/integrations/github/plugins/github_webhook_event_source_flavor.py +++ b/src/zenml/integrations/github/plugins/github_webhook_event_source_flavor.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing # permissions and limitations under the License. -"""Example file of what an event Plugin could look like.""" +"""Github webhook event source flavor.""" from typing import ClassVar, Type diff --git a/src/zenml/integrations/great_expectations/__init__.py b/src/zenml/integrations/great_expectations/__init__.py index 4a4e630d9fa..500b197a93f 100644 --- a/src/zenml/integrations/great_expectations/__init__.py +++ b/src/zenml/integrations/great_expectations/__init__.py @@ -35,10 +35,6 @@ class GreatExpectationsIntegration(Integration): "great-expectations>=0.15.0,<=0.15.47", # typing_extensions 4.6.0 and above doesn't work with GE "typing_extensions<4.6.0", - # We need to add this as an extra dependency to manually downgrade - # SQLModel. Otherwise, the initial installation of ZenML installs - # a higher version SQLModel and a version mismatch is created. - "sqlmodel>=0.0.9,<=0.0.16", ] @staticmethod diff --git a/src/zenml/integrations/huggingface/__init__.py b/src/zenml/integrations/huggingface/__init__.py index c1a92f48e41..5f11ebc1cb3 100644 --- a/src/zenml/integrations/huggingface/__init__.py +++ b/src/zenml/integrations/huggingface/__init__.py @@ -30,6 +30,11 @@ class HuggingfaceIntegration(Integration): "transformers<=4.31", "datasets", "huggingface_hub>0.19.0", + # temporary fix for CI issue similar to: + # - https://github.com/huggingface/datasets/issues/6737 + # - https://github.com/huggingface/datasets/issues/6697 + # TODO try relaxing it back going forward + "fsspec<=2023.12.0", ] @classmethod diff --git a/src/zenml/integrations/huggingface/flavors/huggingface_model_deployer_flavor.py b/src/zenml/integrations/huggingface/flavors/huggingface_model_deployer_flavor.py index d9150fe9986..f9f98b65686 100644 --- a/src/zenml/integrations/huggingface/flavors/huggingface_model_deployer_flavor.py +++ b/src/zenml/integrations/huggingface/flavors/huggingface_model_deployer_flavor.py @@ -33,7 +33,6 @@ class HuggingFaceBaseConfig(BaseModel): """Hugging Face Inference Endpoint configuration.""" - endpoint_name: str = "zenml-" repository: Optional[str] = None framework: Optional[str] = None accelerator: Optional[str] = None @@ -41,15 +40,15 @@ class HuggingFaceBaseConfig(BaseModel): instance_type: Optional[str] = None region: Optional[str] = None vendor: Optional[str] = None - token: Optional[str] = None account_id: Optional[str] = None min_replica: int = 0 max_replica: int = 1 revision: Optional[str] = None task: Optional[str] = None custom_image: Optional[Dict[str, Any]] = None - namespace: Optional[str] = None endpoint_type: str = "public" + secret_name: Optional[str] = None + namespace: Optional[str] = None class HuggingFaceModelDeployerConfig( @@ -62,7 +61,7 @@ class HuggingFaceModelDeployerConfig( namespace: Hugging Face namespace used to list endpoints """ - token: str = SecretField() + token: Optional[str] = SecretField() # The namespace to list endpoints for. Set to `"*"` to list all endpoints # from all namespaces (i.e. personal namespace and all orgs the user belongs to). diff --git a/src/zenml/integrations/huggingface/model_deployers/huggingface_model_deployer.py b/src/zenml/integrations/huggingface/model_deployers/huggingface_model_deployer.py index 2ab93405864..eb551d5051a 100644 --- a/src/zenml/integrations/huggingface/model_deployers/huggingface_model_deployer.py +++ b/src/zenml/integrations/huggingface/model_deployers/huggingface_model_deployer.py @@ -13,12 +13,11 @@ # permissions and limitations under the License. """Implementation of the Hugging Face Model Deployer.""" -from typing import Any, ClassVar, Dict, List, Optional, Type, cast +from typing import ClassVar, Dict, Optional, Tuple, Type, cast from uuid import UUID -from huggingface_hub import list_inference_endpoints - -from zenml.artifacts.utils import log_artifact_metadata, save_artifact +from zenml.analytics.enums import AnalyticsEvent +from zenml.analytics.utils import track_handler from zenml.client import Client from zenml.integrations.huggingface import HUGGINGFACE_SERVICE_ARTIFACT from zenml.integrations.huggingface.flavors.huggingface_model_deployer_flavor import ( @@ -35,13 +34,12 @@ DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, BaseModelDeployerFlavor, ) -from zenml.services import BaseService, ServiceConfig, ServiceRegistry +from zenml.services import BaseService, ServiceConfig +from zenml.stack.stack import Stack +from zenml.stack.stack_validator import StackValidator logger = get_logger(__name__) -ZENM_ENDPOINT_PREFIX: str = "zenml-" -UUID_SLICE_LENGTH: int = 8 - class HuggingFaceModelDeployer(BaseModelDeployer): """Hugging Face endpoint model deployer.""" @@ -61,45 +59,42 @@ def config(self) -> HuggingFaceModelDeployerConfig: return cast(HuggingFaceModelDeployerConfig, self._config) @property - def deployed_endpoints(self) -> Any: - """Get list of deployed endpoint from Hugging Face. + def validator(self) -> Optional[StackValidator]: + """Validates the stack. Returns: - List of deployed endpoints. + A validator that checks that the stack contains a remote artifact + store. """ - return list_inference_endpoints( - token=self.config.token, - namespace=self.config.namespace, - ) - - def modify_endpoint_name( - self, endpoint_name: str, artifact_version: str - ) -> str: - """Modify endpoint name by adding suffix and prefix. - - It adds a prefix "zenml-" if not present and a suffix - of first 8 characters of uuid. - Args: - endpoint_name : Name of the endpoint - artifact_version: Name of the artifact version - - Returns: - Modified endpoint name with added prefix and suffix - """ - # Add prefix if it does not start with ZENM_ENDPOINT_PREFIX - if not endpoint_name.startswith(ZENM_ENDPOINT_PREFIX): - endpoint_name = ZENM_ENDPOINT_PREFIX + endpoint_name + def _validate_if_secret_or_token_is_present( + stack: "Stack", + ) -> Tuple[bool, str]: + """Check if secret or token is present in the stack. + + Args: + stack: The stack to validate. + + Returns: + A tuple with a boolean indicating whether the stack is valid + and a message describing the validation result. + """ + return bool(self.config.token or self.config.secret_name), ( + "The Hugging Face model deployer requires either a secret name" + " or a token to be present in the stack." + ) - endpoint_name += artifact_version - return endpoint_name + return StackValidator( + custom_validation_function=_validate_if_secret_or_token_is_present, + ) def _create_new_service( - self, timeout: int, config: HuggingFaceServiceConfig + self, id: UUID, timeout: int, config: HuggingFaceServiceConfig ) -> HuggingFaceDeploymentService: """Creates a new Hugging FaceDeploymentService. Args: + id: the UUID of the model to be deployed with Hugging Face model deployer. timeout: the timeout in seconds to wait for the Hugging Face inference endpoint to be provisioned and successfully started or updated. config: the configuration of the model to be deployed with Hugging Face model deployer. @@ -109,36 +104,12 @@ def _create_new_service( with the Hugging Face inference endpoint. """ # create a new service for the new model - service = HuggingFaceDeploymentService(config) - - # Use first 8 characters of UUID as artifact version - artifact_version = str(service.dict()["uuid"])[:UUID_SLICE_LENGTH] - # Add same 8 characters as suffix to endpoint name - service.config.endpoint_name = self.modify_endpoint_name( - service.config.endpoint_name, artifact_version - ) + service = HuggingFaceDeploymentService(uuid=id, config=config) logger.info( f"Creating an artifact {HUGGINGFACE_SERVICE_ARTIFACT} with service instance attached as metadata." " If there's an active pipeline and/or model this artifact will be associated with it." ) - - save_artifact( - service, - HUGGINGFACE_SERVICE_ARTIFACT, - version=artifact_version, - is_deployment_artifact=True, - ) - - # Convert UUID object to be json serializable - service_metadata = service.dict() - service_metadata["uuid"] = str(service_metadata["uuid"]) - log_artifact_metadata( - artifact_name=HUGGINGFACE_SERVICE_ARTIFACT, - artifact_version=artifact_version, - metadata={HUGGINGFACE_SERVICE_ARTIFACT: service_metadata}, - ) - service.start(timeout=timeout) return service @@ -159,10 +130,10 @@ def _clean_up_existing_service( # stop the older service existing_service.stop(timeout=timeout, force=force) - def deploy_model( + def perform_deploy_model( self, + id: UUID, config: ServiceConfig, - replace: bool = True, timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, ) -> BaseService: """Create a new Hugging Face deployment service or update an existing one. @@ -170,11 +141,8 @@ def deploy_model( This should serve the supplied model and deployment configuration. Args: + id: the UUID of the model to be deployed with Hugging Face. config: the configuration of the model to be deployed with Hugging Face. - Core - replace: set this flag to True to find and update an equivalent - Hugging Face deployment server with the new model instead of - starting a new deployment server. timeout: the timeout in seconds to wait for the Hugging Face endpoint to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the Hugging Face @@ -184,263 +152,82 @@ def deploy_model( The ZenML Hugging Face deployment service object that can be used to interact with the remote Hugging Face inference endpoint server. """ - config = cast(HuggingFaceServiceConfig, config) - service = None - - # if replace is True, remove all existing services - if replace: - existing_services = self.find_model_server( - pipeline_name=config.pipeline_name, - pipeline_step_name=config.pipeline_step_name, - ) - - for existing_service in existing_services: - if service is None: - # keep the most recently created service - service = cast( - HuggingFaceDeploymentService, existing_service - ) - try: - # delete the older services and don't wait for them to - # be deprovisioned - self._clean_up_existing_service( - existing_service=cast( - HuggingFaceDeploymentService, existing_service - ), - timeout=timeout, - force=True, - ) - except RuntimeError: - # ignore errors encountered while stopping old services - pass - - if service: - # update an equivalent service in place - logger.info( - f"Updating an existing Hugging Face deployment service: {service}" - ) - - service_metadata = service.dict() - artifact_version = str(service_metadata["uuid"])[ - :UUID_SLICE_LENGTH - ] - config.endpoint_name = self.modify_endpoint_name( - config.endpoint_name, artifact_version - ) - - service.stop(timeout=timeout, force=True) - service.update(config) - service.start(timeout=timeout) - else: + with track_handler(AnalyticsEvent.MODEL_DEPLOYED) as analytics_handler: + config = cast(HuggingFaceServiceConfig, config) # create a new HuggingFaceDeploymentService instance - service = self._create_new_service(timeout, config) + service = self._create_new_service( + id=id, timeout=timeout, config=config + ) logger.info( f"Creating a new Hugging Face inference endpoint service: {service}" ) + # Add telemetry with metadata that gets the stack metadata and + # differentiates between pure model and custom code deployments + stack = Client().active_stack + stack_metadata = { + component_type.value: component.flavor + for component_type, component in stack.components.items() + } + analytics_handler.metadata = { + "store_type": Client().zen_store.type.value, + **stack_metadata, + } - return cast(BaseService, service) - - def find_model_server( - self, - running: bool = False, - service_uuid: Optional[UUID] = None, - pipeline_name: Optional[str] = None, - run_name: Optional[str] = None, - pipeline_step_name: Optional[str] = None, - model_name: Optional[str] = None, - model_uri: Optional[str] = None, - model_type: Optional[str] = None, - ) -> List[BaseService]: - """Find one or more Hugging Face model services that match the given criteria. - - Args: - running: if true, only running services will be returned. - service_uuid: the UUID of the Hugging Face service that was - originally used to create the Hugging Face deployment resource. - pipeline_name: name of the pipeline that the deployed model was part - of. - run_name: Name of the pipeline run which the deployed model was - part of. - pipeline_step_name: the name of the pipeline model deployment step - that deployed the model. - model_name: the name of the deployed model. - model_uri: URI of the deployed model. - model_type: the Hugging Face server implementation used to serve - the model - - Raises: - TypeError: If service type does not match HuggingFaceDeploymentService - - Returns: - One or more Hugging Face service objects representing Hugging Face - model servers that match the input search criteria. - """ - # Use a Hugging Face deployment service configuration to compute the labels - config = HuggingFaceServiceConfig( - pipeline_name=pipeline_name or "", - run_name=run_name or "", - pipeline_run_id=run_name or "", - pipeline_step_name=pipeline_step_name or "", - model_name=model_name or "", - model_uri=model_uri or "", - implementation=model_type or "", - ) - - services: List[BaseService] = [] - - # Find all services that match input criteria - for endpoint in self.deployed_endpoints: - if endpoint.name.startswith("zenml-"): - artifact_version = endpoint.name[-8:] - # If service_uuid is supplied, fetch service for that uuid - if ( - service_uuid is not None - and str(service_uuid)[:8] != artifact_version - ): - continue - - # Fetch the saved metadata artifact from zenml server to recreate service - client = Client() - try: - service_artifact = client.get_artifact_version( - HUGGINGFACE_SERVICE_ARTIFACT, artifact_version - ) - hf_deployment_service_dict = service_artifact.run_metadata[ - HUGGINGFACE_SERVICE_ARTIFACT - ].value - - existing_service = ( - ServiceRegistry().load_service_from_dict( - hf_deployment_service_dict # type: ignore - ) - ) - - if not isinstance( - existing_service, HuggingFaceDeploymentService - ): - raise TypeError( - f"Expected service type HuggingFaceDeploymentService but got " - f"{type(existing_service)} instead" - ) - - existing_service.update_status() - if self._matches_search_criteria(existing_service, config): - if not running or existing_service.is_running: - services.append( - cast(BaseService, existing_service) - ) - - # if endpoint is provisioned externally - # we do not have saved artifact for it. - except KeyError: - logger.error( - f"No key found for endpoint {endpoint.name} provisioned externally" - ) - - return services - - def _matches_search_criteria( - self, - existing_service: HuggingFaceDeploymentService, - config: HuggingFaceServiceConfig, - ) -> bool: - """Returns true if a service matches the input criteria. - - If any of the values in the input criteria are None, they are ignored. - This allows listing services just by common pipeline names or step - names, etc. - - Args: - existing_service: The materialized Service instance derived from - the config of the older (existing) service - config: The HuggingFaceServiceConfig object passed to the - deploy_model function holding parameters of the new service - to be created. - - Returns: - True if the service matches the input criteria. - """ - existing_service_config = existing_service.config - - # check if the existing service matches the input criteria - if ( - ( - not config.pipeline_name - or existing_service_config.pipeline_name - == config.pipeline_name - ) - and ( - not config.pipeline_step_name - or existing_service_config.pipeline_step_name - == config.pipeline_step_name - ) - and ( - not config.run_name - or existing_service_config.run_name == config.run_name - ) - ): - return True - - return False + return service - def stop_model_server( + def perform_stop_model( self, - uuid: UUID, + service: BaseService, timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, force: bool = False, - ) -> None: + ) -> BaseService: """Method to stop a model server. Args: - uuid: UUID of the model server to stop. + service: The service to stop. timeout: Timeout in seconds to wait for the service to stop. force: If True, force the service to stop. - """ - # get list of all services - existing_services = self.find_model_server(service_uuid=uuid) - # if the service exists, stop it - if existing_services: - existing_services[0].stop(timeout=timeout, force=force) + Returns: + The stopped service. + """ + service.stop(timeout=timeout, force=force) + return service - def start_model_server( - self, uuid: UUID, timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT - ) -> None: + def perform_start_model( + self, + service: BaseService, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + ) -> BaseService: """Method to start a model server. Args: - uuid: UUID of the model server to start. + service: The service to start. timeout: Timeout in seconds to wait for the service to start. - """ - # get list of all services - existing_services = self.find_model_server(service_uuid=uuid) - # if the service exists, start it - if existing_services: - existing_services[0].start(timeout=timeout) + Returns: + The started service. + """ + service.start(timeout=timeout) + return service - def delete_model_server( + def perform_delete_model( self, - uuid: UUID, + service: BaseService, timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, force: bool = False, ) -> None: """Method to delete all configuration of a model server. Args: - uuid: UUID of the model server to delete. + service: The service to delete. timeout: Timeout in seconds to wait for the service to stop. force: If True, force the service to stop. """ - # get list of all services - existing_services = self.find_model_server(service_uuid=uuid) - - # if the service exists, clean it up - if existing_services: - service = cast(HuggingFaceDeploymentService, existing_services[0]) - self._clean_up_existing_service( - existing_service=service, timeout=timeout, force=force - ) + service = cast(HuggingFaceDeploymentService, service) + self._clean_up_existing_service( + existing_service=service, timeout=timeout, force=force + ) @staticmethod def get_model_server_info( # type: ignore[override] @@ -455,5 +242,6 @@ def get_model_server_info( # type: ignore[override] Model server information. """ return { - "PREDICTION_URL": service_instance.prediction_url, + "PREDICTION_URL": service_instance.get_prediction_url(), + "HEALTH_CHECK_URL": service_instance.get_healthcheck_url(), } diff --git a/src/zenml/integrations/huggingface/services/huggingface_deployment.py b/src/zenml/integrations/huggingface/services/huggingface_deployment.py index 26af08f7548..ed12e9954d1 100644 --- a/src/zenml/integrations/huggingface/services/huggingface_deployment.py +++ b/src/zenml/integrations/huggingface/services/huggingface_deployment.py @@ -26,6 +26,7 @@ from huggingface_hub.utils import HfHubHTTPError from pydantic import Field +from zenml.client import Client from zenml.integrations.huggingface.flavors.huggingface_model_deployer_flavor import ( HuggingFaceBaseConfig, ) @@ -36,16 +37,11 @@ logger = get_logger(__name__) POLLING_TIMEOUT = 1200 +UUID_SLICE_LENGTH: int = 8 class HuggingFaceServiceConfig(HuggingFaceBaseConfig, ServiceConfig): - """Hugging Face service configurations. - - Attributes: - model_name: the name of the model. - """ - - model_name: str = "default" + """Hugging Face service configurations.""" class HuggingFaceServiceStatus(ServiceStatus): @@ -81,6 +77,35 @@ def __init__(self, config: HuggingFaceServiceConfig, **attrs: Any): """ super().__init__(config=config, **attrs) + def get_token(self) -> str: + """Get the Hugging Face token. + + Raises: + ValueError: If token not found. + + Returns: + Hugging Face token. + """ + client = Client() + token = None + if self.config.secret_name: + secret = client.get_secret(self.config.secret_name) + token = secret.secret_values["token"] + else: + from zenml.integrations.huggingface.model_deployers.huggingface_model_deployer import ( + HuggingFaceModelDeployer, + ) + + model_deployer = client.active_stack.model_deployer + if not isinstance(model_deployer, HuggingFaceModelDeployer): + raise ValueError( + "HuggingFaceModelDeployer is not active in the stack." + ) + token = model_deployer.config.token or None + if not token: + raise ValueError("Token not found.") + return token + @property def hf_endpoint(self) -> InferenceEndpoint: """Get the deployed Hugging Face inference endpoint. @@ -89,22 +114,20 @@ def hf_endpoint(self) -> InferenceEndpoint: Huggingface inference endpoint. """ return get_inference_endpoint( - name=self.config.endpoint_name, - token=self.config.token, + name=self._generate_an_endpoint_name(), + token=self.get_token(), namespace=self.config.namespace, ) @property - def prediction_url(self) -> Any: + def prediction_url(self) -> Optional[str]: """The prediction URI exposed by the prediction service. Returns: The prediction URI exposed by the prediction service, or None if the service is not yet ready. """ - if not self.is_running: - return None - return self.hf_endpoint.url + return self.hf_endpoint.url if self.is_running else None @property def inference_client(self) -> InferenceClient: @@ -123,8 +146,8 @@ def provision(self) -> None: """ try: # Attempt to create and wait for the inference endpoint - _ = create_inference_endpoint( - name=self.config.endpoint_name, + hf_endpoint = create_inference_endpoint( + name=self._generate_an_endpoint_name(), repository=self.config.repository, framework=self.config.framework, accelerator=self.config.accelerator, @@ -139,20 +162,10 @@ def provision(self) -> None: task=self.config.task, custom_image=self.config.custom_image, type=self.config.endpoint_type, + token=self.get_token(), namespace=self.config.namespace, - token=self.config.token, ).wait(timeout=POLLING_TIMEOUT) - # Check if the endpoint URL is available after provisioning - if self.hf_endpoint.url is not None: - logger.info( - "Hugging Face inference endpoint successfully deployed." - ) - else: - logger.error( - "Failed to start Hugging Face inference endpoint service: No URL available." - ) - except Exception as e: self.status.update_state( new_state=ServiceState.ERROR, error=str(e) @@ -162,6 +175,16 @@ def provision(self) -> None: f"An unexpected error occurred while provisioning the Hugging Face inference endpoint: {e}" ) + # Check if the endpoint URL is available after provisioning + if hf_endpoint.url: + logger.info( + f"Hugging Face inference endpoint successfully deployed and available. Endpoint URL: {hf_endpoint.url}" + ) + else: + logger.error( + "Failed to start Hugging Face inference endpoint service: No URL available, please check the Hugging Face console for more details." + ) + def check_status(self) -> Tuple[ServiceState, str]: """Check the the current operational state of the Hugging Face deployment. @@ -170,39 +193,30 @@ def check_status(self) -> Tuple[ServiceState, str]: providing additional information about that state (e.g. a description of the error, if one is encountered). """ - # TODO: Support all different InferenceEndpointStatus try: - _ = self.hf_endpoint.status - except (InferenceEndpointError, HfHubHTTPError): - return (ServiceState.INACTIVE, "") - - if self.hf_endpoint.status == InferenceEndpointStatus.RUNNING: - return ( - ServiceState.ACTIVE, - "Hugging Face Inference Endpoint deployment is available", - ) - - elif self.hf_endpoint.status == InferenceEndpointStatus.SCALED_TO_ZERO: - return ( - ServiceState.ACTIVE, - "Hugging Face Inference Endpoint deployment is scaled to zero", - ) - - elif self.hf_endpoint.status == InferenceEndpointStatus.FAILED: - return ( - ServiceState.ERROR, - "Hugging Face Inference Endpoint deployment failed: ", - ) + status = self.hf_endpoint.status + if status == InferenceEndpointStatus.RUNNING: + return (ServiceState.ACTIVE, "") + + elif status == InferenceEndpointStatus.SCALED_TO_ZERO: + return ( + ServiceState.SCALED_TO_ZERO, + "Hugging Face Inference Endpoint is scaled to zero, but still running. It will be started on demand.", + ) - elif self.hf_endpoint.status == InferenceEndpointStatus.PENDING: + elif status == InferenceEndpointStatus.FAILED: + return ( + ServiceState.ERROR, + "Hugging Face Inference Endpoint deployment is inactive or not found", + ) + elif status == InferenceEndpointStatus.PENDING: + return (ServiceState.PENDING_STARTUP, "") + return (ServiceState.PENDING_STARTUP, "") + except (InferenceEndpointError, HfHubHTTPError): return ( - ServiceState.PENDING_STARTUP, - "Hugging Face Inference Endpoint deployment is being created: ", + ServiceState.INACTIVE, + "Hugging Face Inference Endpoint deployment is inactive or not found", ) - return ( - ServiceState.PENDING_STARTUP, - "Hugging Face Inference Endpoint deployment is being created: ", - ) def deprovision(self, force: bool = False) -> None: """Deprovision the remote Hugging Face deployment instance. @@ -217,7 +231,6 @@ def deprovision(self, force: bool = False) -> None: logger.error( "Hugging Face Inference Endpoint is deleted or cannot be found." ) - pass def predict(self, data: "Any", max_new_tokens: int) -> "Any": """Make a prediction using the service. @@ -238,7 +251,7 @@ def predict(self, data: "Any", max_new_tokens: int) -> "Any": "Hugging Face endpoint inference service is not running. " "Please start the service before making predictions." ) - if self.hf_endpoint.prediction_url is not None: + if self.prediction_url is not None: if self.hf_endpoint.task == "text-generation": result = self.inference_client.task_generation( data, max_new_tokens=max_new_tokens @@ -267,3 +280,13 @@ def get_logs( "your Endpoints through the UI in the “Logs” tab of your Endpoint" ) return # type: ignore + + def _generate_an_endpoint_name(self) -> str: + """Generate a unique name for the Hugging Face Inference Endpoint. + + Returns: + A unique name for the Hugging Face Inference Endpoint. + """ + return ( + f"{self.config.service_name}-{str(self.uuid)[:UUID_SLICE_LENGTH]}" + ) diff --git a/src/zenml/integrations/huggingface/steps/huggingface_deployer.py b/src/zenml/integrations/huggingface/steps/huggingface_deployer.py index fd123e88341..5303d89bda7 100644 --- a/src/zenml/integrations/huggingface/steps/huggingface_deployer.py +++ b/src/zenml/integrations/huggingface/steps/huggingface_deployer.py @@ -58,21 +58,17 @@ def huggingface_model_deployer_step( # get pipeline name, step name and run id context = get_step_context() pipeline_name = context.pipeline.name - run_name = context.pipeline_run.name step_name = context.step_run.name # update the step configuration with the real pipeline runtime information service_config = service_config.copy() service_config.pipeline_name = pipeline_name - service_config.run_name = run_name service_config.pipeline_step_name = step_name # fetch existing services with same pipeline name, step name and # model name existing_services = model_deployer.find_model_server( - pipeline_name=pipeline_name, - pipeline_step_name=step_name, - model_name=service_config.model_name, + config=service_config.dict() ) # even when the deploy decision is negative, if an existing model server @@ -99,7 +95,10 @@ def huggingface_model_deployer_step( service = cast( HuggingFaceDeploymentService, model_deployer.deploy_model( - service_config, replace=True, timeout=timeout + service_config, + replace=True, + timeout=timeout, + service_type=HuggingFaceDeploymentService.SERVICE_TYPE, ), ) diff --git a/src/zenml/integrations/hyperai/orchestrators/hyperai_orchestrator.py b/src/zenml/integrations/hyperai/orchestrators/hyperai_orchestrator.py index 0f90eec205b..18b62a0bacc 100644 --- a/src/zenml/integrations/hyperai/orchestrators/hyperai_orchestrator.py +++ b/src/zenml/integrations/hyperai/orchestrators/hyperai_orchestrator.py @@ -17,7 +17,7 @@ import re import tempfile from shlex import quote -from typing import TYPE_CHECKING, Any, Dict, Optional, Type, cast +from typing import IO, TYPE_CHECKING, Any, Dict, Optional, Type, cast import paramiko import yaml @@ -129,6 +129,36 @@ def _escape_shell_command(self, command: str) -> str: """ return quote(command) + def _scp_to_hyperai_instance( + self, + paramiko_client: paramiko.SSHClient, + f: IO[str], + directory_name: str, + file_name: str, + description: str, + ) -> None: + """Copies a file to a HyperAI instance using SCP. + + Args: + paramiko_client: The SSH client to use for the SCP transfer. + f: The file to transfer. + directory_name: The directory on the HyperAI instance to transfer + the file to. + file_name: The name of the file being transferred. + description: A description of the file being transferred. + + Raises: + RuntimeError: If the file cannot be written to the HyperAI instance. + """ + try: + scp_client = paramiko_client.open_sftp() + scp_client.put(f.name, f"{directory_name}/{file_name}") + scp_client.close() + except FileNotFoundError: + raise RuntimeError( + f"Failed to write {description} to HyperAI instance. Does the user have permissions to write?" + ) + def prepare_or_run_pipeline( self, deployment: "PipelineDeploymentResponse", @@ -230,17 +260,25 @@ def prepare_or_run_pipeline( # Add dependency on upstream steps if applicable upstream_steps = step.spec.upstream_steps - for upstream_step_name in upstream_steps: - upstream_container_name = ( - f"{deployment_id}-{upstream_step_name}" - ) + + if len(upstream_steps) > 0: compose_definition["services"][container_name][ "depends_on" - ] = { - upstream_container_name: { - "condition": "service_completed_successfully" - } - } + ] = {} + + for upstream_step_name in upstream_steps: + upstream_container_name = ( + f"{deployment_id}-{upstream_step_name}" + ) + compose_definition["services"][container_name][ + "depends_on" + ].update( + { + upstream_container_name: { + "condition": "service_completed_successfully" + } + } + ) # Convert into yaml logger.info("Finalizing Docker Compose definition.") @@ -373,14 +411,33 @@ def prepare_or_run_pipeline( f_.write(compose_definition_yaml) # Scp Docker Compose file to HyperAI instance - try: - scp_client = paramiko_client.open_sftp() - scp_client.put(f.name, f"{directory_name}/docker-compose.yaml") - scp_client.close() - except FileNotFoundError: - raise RuntimeError( - "Failed to write Docker Compose file to HyperAI instance. Does the user have permissions to write?" - ) + self._scp_to_hyperai_instance( + paramiko_client, + f, + directory_name, + file_name="docker-compose.yml", + description="Docker Compose file", + ) + + # Create temporary file and write script to it + with tempfile.NamedTemporaryFile(mode="w", delete=True) as f: + # Define bash line and command line + bash_line = "#!/bin/bash\n" + command_line = f'cd {directory_name} && echo {ENV_ZENML_HYPERAI_RUN_ID}="{deployment_id}_$(date +\%s)" > .env && docker compose up -d' + + # Write script to temporary file + with f.file as f_: + f_.write(bash_line) + f_.write(command_line) + + # Scp script to HyperAI instance + self._scp_to_hyperai_instance( + paramiko_client, + f, + directory_name, + file_name="run_pipeline.sh", + description="startup script", + ) # Run or schedule Docker Compose file depending on settings if not deployment.schedule: @@ -413,7 +470,7 @@ def prepare_or_run_pipeline( # Create cron job for scheduled pipeline on HyperAI instance stdin, stdout, stderr = paramiko_client.exec_command( # nosec - f"(crontab -l ; echo '{cron_expression} cd {directory_name} && echo {ENV_ZENML_HYPERAI_RUN_ID}=\"{deployment_id}_$(date +\%s)\" > .env && docker compose up -d') | crontab -" + f"(crontab -l ; echo '{cron_expression} bash {directory_name}/run_pipeline.sh') | crontab -" ) logger.info("Pipeline scheduled successfully.") diff --git a/src/zenml/integrations/kaniko/flavors/kaniko_image_builder_flavor.py b/src/zenml/integrations/kaniko/flavors/kaniko_image_builder_flavor.py index 7bc785dff46..09e13cad35e 100644 --- a/src/zenml/integrations/kaniko/flavors/kaniko_image_builder_flavor.py +++ b/src/zenml/integrations/kaniko/flavors/kaniko_image_builder_flavor.py @@ -16,7 +16,7 @@ import json from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union -from pydantic import validator +from pydantic import PositiveInt, validator from zenml.image_builders import BaseImageBuilderConfig, BaseImageBuilderFlavor from zenml.integrations.kaniko import KANIKO_IMAGE_BUILDER_FLAVOR @@ -29,6 +29,7 @@ DEFAULT_KANIKO_EXECUTOR_IMAGE = ( f"gcr.io/kaniko-project/executor:{KANIKO_EXECUTOR_IMAGE_TAG}" ) +DEFAULT_KANIKO_POD_RUNNING_TIMEOUT = 300 class KanikoImageBuilderConfig(BaseImageBuilderConfig): @@ -47,6 +48,8 @@ class KanikoImageBuilderConfig(BaseImageBuilderConfig): Kaniko pod. This namespace will not be created and must already exist. executor_image: The image of the Kaniko executor to use. + pod_running_timeout: The timeout to wait until the pod is running + in seconds. Defaults to `300`. env: `env` section of the Kubernetes container spec. env_from: `envFrom` section of the Kubernetes container spec. volume_mounts: `volumeMounts` section of the Kubernetes container spec. @@ -67,6 +70,7 @@ class KanikoImageBuilderConfig(BaseImageBuilderConfig): kubernetes_context: str kubernetes_namespace: str = "zenml-kaniko" executor_image: str = DEFAULT_KANIKO_EXECUTOR_IMAGE + pod_running_timeout: PositiveInt = DEFAULT_KANIKO_POD_RUNNING_TIMEOUT env: List[Dict[str, Any]] = [] env_from: List[Dict[str, Any]] = [] diff --git a/src/zenml/integrations/kaniko/image_builders/kaniko_image_builder.py b/src/zenml/integrations/kaniko/image_builders/kaniko_image_builder.py index 314e31f0657..ebb3f09fefa 100644 --- a/src/zenml/integrations/kaniko/image_builders/kaniko_image_builder.py +++ b/src/zenml/integrations/kaniko/image_builders/kaniko_image_builder.py @@ -257,6 +257,8 @@ def _run_kaniko_build( self.config.executor_image, "--overrides", json.dumps(spec_overrides), + "--pod-running-timeout", + f"{self.config.pod_running_timeout}s", ] logger.debug("Running Kaniko build with command: %s", command) with subprocess.Popen( diff --git a/src/zenml/integrations/mlflow/__init__.py b/src/zenml/integrations/mlflow/__init__.py index c7cc82e790d..3cd8a0d146f 100644 --- a/src/zenml/integrations/mlflow/__init__.py +++ b/src/zenml/integrations/mlflow/__init__.py @@ -35,7 +35,7 @@ class MlflowIntegration(Integration): # does not pin it. They fixed this in a later version, so we can probably # remove this once we update the mlflow version. REQUIREMENTS = [ - "mlflow>=2.1.1,<=2.10.2", + "mlflow>=2.1.1,<=2.11.3", "mlserver>=1.3.3", "mlserver-mlflow>=1.3.3", # TODO: remove this requirement once rapidjson is fixed diff --git a/src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py b/src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py index d163ae09a5e..1c0de9b43d2 100644 --- a/src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py +++ b/src/zenml/integrations/mlflow/model_deployers/mlflow_model_deployer.py @@ -15,8 +15,7 @@ import os import shutil -from pathlib import Path -from typing import ClassVar, Dict, List, Optional, Type, cast +from typing import ClassVar, Dict, Optional, Type, cast from uuid import UUID from zenml.config.global_config import GlobalConfiguration @@ -31,8 +30,6 @@ ) from zenml.logger import get_logger from zenml.model_deployers import BaseModelDeployer, BaseModelDeployerFlavor -from zenml.services import ServiceRegistry -from zenml.services.local.local_service import SERVICE_DAEMON_CONFIG_FILE_NAME from zenml.services.service import BaseService, ServiceConfig from zenml.utils.io_utils import create_dir_recursive_if_not_exists @@ -120,12 +117,15 @@ def get_model_server_info( # type: ignore[override] "REGISTRY_MODEL_VERSION": service_instance.config.registry_model_version, "SERVICE_PATH": service_instance.status.runtime_path, "DAEMON_PID": str(service_instance.status.pid), + "HEALTH_CHECK_URL": service_instance.endpoint.monitor.get_healthcheck_uri( + service_instance.endpoint + ), } - def deploy_model( + def perform_deploy_model( self, + id: UUID, config: ServiceConfig, - replace: bool = False, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT, ) -> BaseService: """Create a new MLflow deployment service or update an existing one. @@ -157,10 +157,8 @@ def deploy_model( and the others are deleted. Args: + id: the ID of the MLflow deployment service to be created or updated. config: the configuration of the model to be deployed with MLflow. - replace: set this flag to True to find and update an equivalent - MLflow deployment server with the new model instead of - creating and starting a new deployment server. timeout: the timeout in seconds to wait for the MLflow server to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the MLflow @@ -171,49 +169,11 @@ def deploy_model( interact with the MLflow model server. """ config = cast(MLFlowDeploymentConfig, config) - service = None - - # if replace is True, remove all existing services - if replace is True: - existing_services = self.find_model_server( - pipeline_name=config.pipeline_name, - pipeline_step_name=config.pipeline_step_name, - model_name=config.model_name, - ) - - for existing_service in existing_services: - if service is None: - # keep the most recently created service - service = cast(MLFlowDeploymentService, existing_service) - try: - # delete the older services and don't wait for them to - # be deprovisioned - self._clean_up_existing_service( - existing_service=cast( - MLFlowDeploymentService, existing_service - ), - timeout=timeout, - force=True, - ) - except RuntimeError: - # ignore errors encountered while stopping old services - pass - if service: - logger.info( - f"Updating an existing MLflow deployment service: {service}" - ) - - # set the root runtime path with the stack component's UUID - config.root_runtime_path = self.local_path - service.stop(timeout=timeout, force=True) - service.update(config) - service.start(timeout=timeout) - else: - # create a new MLFlowDeploymentService instance - service = self._create_new_service(timeout, config) - logger.info(f"Created a new MLflow deployment service: {service}") - - return cast(BaseService, service) + service = self._create_new_service( + id=id, timeout=timeout, config=config + ) + logger.info(f"Created a new MLflow deployment service: {service}") + return service def _clean_up_existing_service( self, @@ -232,11 +192,12 @@ def _clean_up_existing_service( # of workers etc.the step implementation will create a new config using # all values from the user and add values like pipeline name, model_uri def _create_new_service( - self, timeout: int, config: MLFlowDeploymentConfig + self, id: UUID, timeout: int, config: MLFlowDeploymentConfig ) -> MLFlowDeploymentService: """Creates a new MLFlowDeploymentService. Args: + id: the ID of the MLflow deployment service to be created or updated. timeout: the timeout in seconds to wait for the MLflow server to be provisioned and successfully started or updated. config: the configuration of the model to be deployed with MLflow. @@ -248,213 +209,61 @@ def _create_new_service( # set the root runtime path with the stack component's UUID config.root_runtime_path = self.local_path # create a new service for the new model - service = MLFlowDeploymentService(config) + service = MLFlowDeploymentService(uuid=id, config=config) service.start(timeout=timeout) return service - def find_model_server( + def perform_stop_model( self, - running: bool = False, - service_uuid: Optional[UUID] = None, - pipeline_name: Optional[str] = None, - run_name: Optional[str] = None, - pipeline_step_name: Optional[str] = None, - model_name: Optional[str] = None, - model_uri: Optional[str] = None, - model_type: Optional[str] = None, - registry_model_name: Optional[str] = None, - registry_model_version: Optional[str] = None, - ) -> List[BaseService]: - """Finds one or more model servers that match the given criteria. - - Args: - running: If true, only running services will be returned. - service_uuid: The UUID of the service that was originally used - to deploy the model. - pipeline_name: Name of the pipeline that the deployed model was part - of. - run_name: Name of the pipeline run which the deployed model - was part of. - pipeline_step_name: The name of the pipeline model deployment step - that deployed the model. - model_name: Name of the deployed model. - model_uri: URI of the deployed model. - model_type: Type/format of the deployed model. Not used in this - MLflow case. - registry_model_name: Name of the registered model that the - deployed model belongs to. - registry_model_version: Version of the registered model that - the deployed model belongs to. - - Returns: - One or more Service objects representing model servers that match - the input search criteria. - - Raises: - TypeError: if any of the input arguments are of an invalid type. - """ - services = [] - config = MLFlowDeploymentConfig( - model_name=model_name or "", - model_uri=model_uri or "", - pipeline_name=pipeline_name or "", - pipeline_run_id=run_name or "", - run_name=run_name or "", - pipeline_step_name=pipeline_step_name or "", - registry_model_name=registry_model_name, - registry_model_version=registry_model_version, - ) - - # find all services that match the input criteria - for root, _, files in os.walk(self.local_path): - if service_uuid and Path(root).name != str(service_uuid): - continue - for file in files: - if file == SERVICE_DAEMON_CONFIG_FILE_NAME: - service_config_path = os.path.join(root, file) - logger.debug( - "Loading service daemon configuration from %s", - service_config_path, - ) - existing_service_config = None - with open(service_config_path, "r") as f: - existing_service_config = f.read() - existing_service = ( - ServiceRegistry().load_service_from_json( - existing_service_config - ) - ) - if not isinstance( - existing_service, MLFlowDeploymentService - ): - raise TypeError( - f"Expected service type MLFlowDeploymentService but got " - f"{type(existing_service)} instead" - ) - existing_service.update_status() - if self._matches_search_criteria(existing_service, config): - if not running or existing_service.is_running: - services.append( - cast(BaseService, existing_service) - ) - - return services - - def _matches_search_criteria( - self, - existing_service: MLFlowDeploymentService, - config: MLFlowDeploymentConfig, - ) -> bool: - """Returns true if a service matches the input criteria. - - If any of the values in the input criteria are None, they are ignored. - This allows listing services just by common pipeline names or step - names, etc. - - Args: - existing_service: The materialized Service instance derived from - the config of the older (existing) service - config: The MLFlowDeploymentConfig object passed to the - deploy_model function holding parameters of the new service - to be created. - - Returns: - True if the service matches the input criteria. - """ - existing_service_config = existing_service.config - # check if the existing service matches the input criteria - if ( - ( - not config.pipeline_name - or existing_service_config.pipeline_name - == config.pipeline_name - ) - and ( - not config.model_name - or existing_service_config.model_name == config.model_name - ) - and ( - not config.pipeline_step_name - or existing_service_config.pipeline_step_name - == config.pipeline_step_name - ) - and ( - not config.run_name - or existing_service_config.run_name == config.run_name - ) - and ( - ( - not config.registry_model_name - and not config.registry_model_version - ) - or ( - existing_service_config.registry_model_name - == config.registry_model_name - and existing_service_config.registry_model_version - == config.registry_model_version - ) - ) - ): - return True - - return False - - def stop_model_server( - self, - uuid: UUID, + service: BaseService, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT, force: bool = False, - ) -> None: + ) -> BaseService: """Method to stop a model server. Args: - uuid: UUID of the model server to stop. + service: The service to stop. timeout: Timeout in seconds to wait for the service to stop. force: If True, force the service to stop. - """ - # get list of all services - existing_services = self.find_model_server(service_uuid=uuid) - # if the service exists, stop it - if existing_services: - existing_services[0].stop(timeout=timeout, force=force) + Returns: + The service that was stopped. + """ + service.stop(timeout=timeout, force=force) + return service - def start_model_server( - self, uuid: UUID, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT - ) -> None: + def perform_start_model( + self, + service: BaseService, + timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT, + ) -> BaseService: """Method to start a model server. Args: - uuid: UUID of the model server to start. + service: The service to start. timeout: Timeout in seconds to wait for the service to start. - """ - # get list of all services - existing_services = self.find_model_server(service_uuid=uuid) - # if the service exists, start it - if existing_services: - existing_services[0].start(timeout=timeout) + Returns: + The service that was started. + """ + service.start(timeout=timeout) + return service - def delete_model_server( + def perform_delete_model( self, - uuid: UUID, + service: BaseService, timeout: int = DEFAULT_SERVICE_START_STOP_TIMEOUT, force: bool = False, ) -> None: """Method to delete all configuration of a model server. Args: - uuid: UUID of the model server to delete. + service: The service to delete. timeout: Timeout in seconds to wait for the service to stop. force: If True, force the service to stop. """ - # get list of all services - existing_services = self.find_model_server(service_uuid=uuid) - - # if the service exists, clean it up - if existing_services: - service = cast(MLFlowDeploymentService, existing_services[0]) - self._clean_up_existing_service( - existing_service=service, timeout=timeout, force=force - ) + service = cast(MLFlowDeploymentService, service) + self._clean_up_existing_service( + existing_service=service, timeout=timeout, force=force + ) diff --git a/src/zenml/integrations/mlflow/services/mlflow_deployment.py b/src/zenml/integrations/mlflow/services/mlflow_deployment.py index 114f7e66a16..2cdccdbbf09 100644 --- a/src/zenml/integrations/mlflow/services/mlflow_deployment.py +++ b/src/zenml/integrations/mlflow/services/mlflow_deployment.py @@ -101,8 +101,6 @@ class MLFlowDeploymentConfig(LocalDaemonServiceConfig): timeout: timeout in seconds for starting and stopping the service """ - # TODO: ServiceConfig should have additional fields such as "pipeline_run_uuid" - # and "pipeline_uuid" to allow for better tracking of the service. model_uri: str model_name: str registry_model_name: Optional[str] = None @@ -128,6 +126,7 @@ class MLFlowDeploymentService(LocalDaemonService, BaseDeploymentService): type="model-serving", flavor="mlflow", description="MLflow prediction service", + logo_url="https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/mlflow.png", ) config: MLFlowDeploymentConfig diff --git a/src/zenml/integrations/mlflow/steps/mlflow_deployer.py b/src/zenml/integrations/mlflow/steps/mlflow_deployer.py index a93a42d75e9..6d0fde1a9df 100644 --- a/src/zenml/integrations/mlflow/steps/mlflow_deployer.py +++ b/src/zenml/integrations/mlflow/steps/mlflow_deployer.py @@ -118,32 +118,30 @@ def mlflow_model_deployer_step( run_id=mlflow_run_id, artifact_path=model_name ) - # Fetch existing services with same pipeline name, step name and model name - existing_services = model_deployer.find_model_server( + predictor_cfg = MLFlowDeploymentConfig( + model_name=model_name or "", + model_uri=model_uri, + workers=workers, + mlserver=mlserver, pipeline_name=pipeline_name, pipeline_step_name=step_name, - model_name=model_name, + timeout=timeout, + ) + + # Fetch existing services with same pipeline name, step name and model name + existing_services = model_deployer.find_model_server( + config=predictor_cfg.dict(), ) # Check whether to deploy a new service if model_uri and deploy_decision: - predictor_cfg = MLFlowDeploymentConfig( - model_name=model_name or "", - model_uri=model_uri, - workers=workers, - mlserver=mlserver, - pipeline_name=pipeline_name, - run_name=run_name, - pipeline_run_id=run_name, - pipeline_step_name=step_name, - timeout=timeout, - ) new_service = cast( MLFlowDeploymentService, model_deployer.deploy_model( replace=True, config=predictor_cfg, timeout=timeout, + service_type=MLFlowDeploymentService.SERVICE_TYPE, ), ) logger.info( @@ -277,26 +275,25 @@ def mlflow_model_registry_deployer_step( f"using this step." ) # fetch existing services with same pipeline name, step name and model name + existing_services = ( model_deployer.find_model_server( - registry_model_name=model_version.registered_model.name, + model_name=registry_model_name, + model_version=model_version.version, ) if replace_existing else None ) - # create a config for the new model service metadata = model_version.metadata or ModelRegistryModelMetadata() predictor_cfg = MLFlowDeploymentConfig( - model_name=model_name or "", + name=model_name or None, + model_name=registry_model_name, + model_version=model_version.version, model_uri=model_version.model_source_uri, - registry_model_name=model_version.registered_model.name, - registry_model_version=model_version.version, - registry_model_stage=model_version.stage.value, workers=workers, mlserver=mlserver, pipeline_name=metadata.zenml_pipeline_name or "", - run_name=metadata.zenml_run_name or "", pipeline_step_name=metadata.zenml_step_name or "", timeout=timeout, ) @@ -308,6 +305,7 @@ def mlflow_model_registry_deployer_step( replace=True, config=predictor_cfg, timeout=timeout, + service_type=MLFlowDeploymentService.SERVICE_TYPE, ), ) diff --git a/src/zenml/integrations/seldon/model_deployers/seldon_model_deployer.py b/src/zenml/integrations/seldon/model_deployers/seldon_model_deployer.py index 9529ccbcc85..8ae9282bee5 100644 --- a/src/zenml/integrations/seldon/model_deployers/seldon_model_deployer.py +++ b/src/zenml/integrations/seldon/model_deployers/seldon_model_deployer.py @@ -15,7 +15,6 @@ import json import re -from datetime import datetime from typing import TYPE_CHECKING, ClassVar, Dict, List, Optional, Type, cast from uuid import UUID @@ -479,10 +478,10 @@ def _delete_kubernetes_secret(self, secret_name: str) -> None: return self.seldon_client.delete_secret(secret_name) - def deploy_model( + def perform_deploy_model( self, + id: UUID, config: ServiceConfig, - replace: bool = False, timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT, ) -> BaseService: """Create a new Seldon Core deployment or update an existing one. @@ -517,11 +516,9 @@ def deploy_model( to be updated and the others are deleted. Args: + id: the UUID of the model server to deploy. config: the configuration of the model to be deployed with Seldon. Core - replace: set this flag to True to find and update an equivalent - Seldon Core deployment server with the new model instead of - starting a new deployment server. timeout: the timeout in seconds to wait for the Seldon Core server to be provisioned and successfully started or updated. If set to 0, the method will return immediately after the Seldon Core @@ -541,31 +538,6 @@ def deploy_model( """ with track_handler(AnalyticsEvent.MODEL_DEPLOYED) as analytics_handler: config = cast(SeldonDeploymentConfig, config) - service = None - - # if replace is True, find equivalent Seldon Core deployments - if replace is True: - equivalent_services = self.find_model_server( - running=False, - pipeline_name=config.pipeline_name, - pipeline_step_name=config.pipeline_step_name, - model_name=config.model_name, - ) - - for equivalent_service in equivalent_services: - if service is None: - # keep the most recently created service - service = equivalent_service - else: - try: - # delete the older services and don't wait for - # them to be deprovisioned - service.stop() - except RuntimeError: - # ignore errors encountered while stopping old - # services - pass - # if a custom Kubernetes secret is not explicitly specified in the # SeldonDeploymentConfig, try to create one from the ZenML secret # configured for the model deployer @@ -573,19 +545,9 @@ def deploy_model( config.secret_name or self._create_or_update_kubernetes_secret() ) - - if service: - # update an equivalent service in place - service.update(config) - logger.info( - f"Updating an existing Seldon deployment service: {service}" - ) - else: - # create a new service - service = SeldonDeploymentService(config=config) - logger.info( - f"Creating a new Seldon deployment service: {service}" - ) + # create a new service + service = SeldonDeploymentService(uuid=id, config=config) + logger.info(f"Creating a new Seldon deployment service: {service}") # start the service which in turn provisions the Seldon Core # deployment server and waits for it to reach a ready state @@ -606,95 +568,16 @@ def deploy_model( return service - def find_model_server( - self, - running: bool = False, - service_uuid: Optional[UUID] = None, - pipeline_name: Optional[str] = None, - run_name: Optional[str] = None, - pipeline_step_name: Optional[str] = None, - model_name: Optional[str] = None, - model_uri: Optional[str] = None, - model_type: Optional[str] = None, - ) -> List[BaseService]: - """Find one or more Seldon Core model services that match the given criteria. - - The Seldon Core deployment services that meet the search criteria are - returned sorted in descending order of their creation time (i.e. more - recent deployments first). - - Args: - running: if true, only running services will be returned. - service_uuid: the UUID of the Seldon Core service that was - originally used to create the Seldon Core deployment resource. - pipeline_name: name of the pipeline that the deployed model was part - of. - run_name: Name of the pipeline run which the deployed model was - part of. - pipeline_step_name: the name of the pipeline model deployment step - that deployed the model. - model_name: the name of the deployed model. - model_uri: URI of the deployed model. - model_type: the Seldon Core server implementation used to serve - the model - - Returns: - One or more Seldon Core service objects representing Seldon Core - model servers that match the input search criteria. - """ - # Use a Seldon deployment service configuration to compute the labels - config = SeldonDeploymentConfig( - pipeline_name=pipeline_name or "", - run_name=run_name or "", - pipeline_run_id=run_name or "", - pipeline_step_name=pipeline_step_name or "", - model_name=model_name or "", - model_uri=model_uri or "", - implementation=model_type or "", - ) - labels = config.get_seldon_deployment_labels() - if service_uuid: - # the service UUID is not a label covered by the Seldon - # deployment service configuration, so we need to add it - # separately - labels["zenml.service_uuid"] = str(service_uuid) - - deployments = self.seldon_client.find_deployments(labels=labels) - # sort the deployments in descending order of their creation time - deployments.sort( - key=lambda deployment: datetime.strptime( - deployment.metadata.creationTimestamp, - "%Y-%m-%dT%H:%M:%SZ", - ) - if deployment.metadata.creationTimestamp - else datetime.min, - reverse=True, - ) - - services: List[BaseService] = [] - for deployment in deployments: - # recreate the Seldon deployment service object from the Seldon - # deployment resource - service = SeldonDeploymentService.create_from_deployment( - deployment=deployment - ) - if running and not service.is_running: - # skip non-running services - continue - services.append(service) - - return services - - def stop_model_server( + def perform_stop_model( self, - uuid: UUID, + service: BaseService, timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT, force: bool = False, - ) -> None: + ) -> BaseService: """Stop a Seldon Core model server. Args: - uuid: UUID of the model server to stop. + service: The service to stop. timeout: timeout in seconds to wait for the service to stop. force: if True, force the service to stop. @@ -707,15 +590,15 @@ def stop_model_server( "deleting the Seldon Core model server instead." ) - def start_model_server( + def perform_start_model( self, - uuid: UUID, + service: BaseService, timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT, - ) -> None: + ) -> BaseService: """Start a Seldon Core model deployment server. Args: - uuid: UUID of the model server to start. + service: The service to start. timeout: timeout in seconds to wait for the service to become active. . If set to 0, the method will return immediately after provisioning the service, without waiting for it to become @@ -729,28 +612,22 @@ def start_model_server( "Starting Seldon Core model servers is not implemented" ) - def delete_model_server( + def perform_delete_model( self, - uuid: UUID, + service: BaseService, timeout: int = DEFAULT_SELDON_DEPLOYMENT_START_STOP_TIMEOUT, force: bool = False, ) -> None: """Delete a Seldon Core model deployment server. Args: - uuid: UUID of the model server to delete. + service: The service to delete. timeout: timeout in seconds to wait for the service to stop. If set to 0, the method will return immediately after deprovisioning the service, without waiting for it to stop. force: if True, force the service to stop. """ - services = self.find_model_server(service_uuid=uuid) - if len(services) == 0: - return - - service = services[0] - - assert isinstance(service, SeldonDeploymentService) + service = cast(SeldonDeploymentService, service) service.stop(timeout=timeout, force=force) if service.config.secret_name: diff --git a/src/zenml/integrations/seldon/services/seldon_deployment.py b/src/zenml/integrations/seldon/services/seldon_deployment.py index 28c6a1c1822..5d3c56a1b04 100644 --- a/src/zenml/integrations/seldon/services/seldon_deployment.py +++ b/src/zenml/integrations/seldon/services/seldon_deployment.py @@ -86,8 +86,6 @@ def get_seldon_deployment_labels(self) -> Dict[str, str]: labels = {} if self.pipeline_name: labels["zenml.pipeline_name"] = self.pipeline_name - if self.run_name: - labels["zenml.run_name"] = self.run_name if self.pipeline_step_name: labels["zenml.pipeline_step_name"] = self.pipeline_step_name if self.model_name: @@ -174,6 +172,7 @@ class SeldonDeploymentService(BaseDeploymentService): type="model-serving", flavor="seldon", description="Seldon Core prediction service", + logo_url="https://public-flavor-logos.s3.eu-central-1.amazonaws.com/model_deployer/seldon.png", ) config: SeldonDeploymentConfig diff --git a/src/zenml/integrations/seldon/steps/seldon_deployer.py b/src/zenml/integrations/seldon/steps/seldon_deployer.py index e89944e9f11..0b527252e4b 100644 --- a/src/zenml/integrations/seldon/steps/seldon_deployer.py +++ b/src/zenml/integrations/seldon/steps/seldon_deployer.py @@ -73,13 +73,11 @@ def seldon_model_deployer_step( # get pipeline name, step name and run id context = get_step_context() pipeline_name = context.pipeline.name - run_name = context.pipeline_run.name step_name = context.step_run.name # update the step configuration with the real pipeline runtime information service_config = service_config.copy() service_config.pipeline_name = pipeline_name - service_config.run_name = run_name service_config.pipeline_step_name = step_name def prepare_service_config(model_uri: str) -> SeldonDeploymentConfig: @@ -143,9 +141,7 @@ def prepare_service_config(model_uri: str) -> SeldonDeploymentConfig: # fetch existing services with same pipeline name, step name and # model name existing_services = model_deployer.find_model_server( - pipeline_name=pipeline_name, - pipeline_step_name=step_name, - model_name=service_config.model_name, + config=service_config.dict() ) # even when the deploy decision is negative, if an existing model server @@ -173,7 +169,10 @@ def prepare_service_config(model_uri: str) -> SeldonDeploymentConfig: service = cast( SeldonDeploymentService, model_deployer.deploy_model( - service_config, replace=True, timeout=timeout + service_config, + replace=True, + timeout=timeout, + service_type=SeldonDeploymentService.SERVICE_TYPE, ), ) @@ -231,21 +230,17 @@ def seldon_custom_model_deployer_step( # get pipeline name, step name, run id context = get_step_context() pipeline_name = context.pipeline.name - run_name = context.pipeline_run.name step_name = context.step_run.name # update the step configuration with the real pipeline runtime information service_config.pipeline_name = pipeline_name - service_config.run_name = run_name service_config.pipeline_step_name = step_name service_config.is_custom_deployment = True # fetch existing services with the same pipeline name, step name and # model name existing_services = model_deployer.find_model_server( - pipeline_name=pipeline_name, - pipeline_step_name=step_name, - model_name=service_config.model_name, + config=service_config.dict() ) # even when the deploy decision is negative if an existing model server # is not running for this pipeline/step, we still have to serve the @@ -325,7 +320,10 @@ def seldon_custom_model_deployer_step( service = cast( SeldonDeploymentService, model_deployer.deploy_model( - service_config, replace=True, timeout=timeout + service_config, + replace=True, + timeout=timeout, + service_type=SeldonDeploymentService.SERVICE_TYPE, ), ) @@ -476,7 +474,10 @@ def seldon_mlflow_registry_deployer_step( service = cast( SeldonDeploymentService, model_deployer.deploy_model( - service_config, replace=True, timeout=timeout + service_config, + replace=True, + timeout=timeout, + service_type=SeldonDeploymentService.SERVICE_TYPE, ), ) diff --git a/src/zenml/integrations/tensorboard/services/tensorboard_service.py b/src/zenml/integrations/tensorboard/services/tensorboard_service.py index adda572fd42..b6c61a8d017 100644 --- a/src/zenml/integrations/tensorboard/services/tensorboard_service.py +++ b/src/zenml/integrations/tensorboard/services/tensorboard_service.py @@ -13,6 +13,7 @@ # permissions and limitations under the License. """Implementation of the TensorBoard service.""" +import uuid from typing import Any, Dict, Union from tensorboard import default, program # type: ignore [import-untyped] @@ -103,7 +104,7 @@ def __init__( ), ) attrs["endpoint"] = endpoint - super().__init__(config=config, **attrs) + super().__init__(config=config, uuid=uuid.uuid4(), **attrs) def run(self) -> None: """Initialize and run the TensorBoard server.""" diff --git a/src/zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py b/src/zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py index 5be0e01d878..bc9b0c20f00 100644 --- a/src/zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +++ b/src/zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py @@ -113,6 +113,7 @@ def visualize( service = TensorboardService( TensorboardServiceConfig( logdir=logdir, + name=f"zenml-tensorboard-{logdir}", ) ) service.start(timeout=60) diff --git a/src/zenml/materializers/service_materializer.py b/src/zenml/materializers/service_materializer.py index 7659cabe0ca..a8294433ab0 100644 --- a/src/zenml/materializers/service_materializer.py +++ b/src/zenml/materializers/service_materializer.py @@ -14,13 +14,13 @@ """Implementation of a materializer to read and write ZenML service instances.""" import os +import uuid from typing import TYPE_CHECKING, Any, ClassVar, Dict, Tuple, Type from zenml.client import Client from zenml.enums import ArtifactType from zenml.materializers.base_materializer import BaseMaterializer -from zenml.services.service import BaseService -from zenml.services.service_registry import ServiceRegistry +from zenml.services.service import BaseDeploymentService, BaseService if TYPE_CHECKING: from zenml.metadata.metadata_types import MetadataType @@ -49,8 +49,11 @@ def load(self, data_type: Type[Any]) -> BaseService: artifact_store = Client().active_stack.artifact_store filepath = os.path.join(self.uri, SERVICE_CONFIG_FILENAME) with artifact_store.open(filepath, "r") as f: - service = ServiceRegistry().load_service_from_json(f.read()) - return service + service_id = f.read().strip() + + client = Client() + service = client.get_service(name_id_or_prefix=uuid.UUID(service_id)) + return BaseDeploymentService.from_model(service) def save(self, service: BaseService) -> None: """Writes a ZenML service. @@ -64,7 +67,7 @@ def save(self, service: BaseService) -> None: artifact_store = Client().active_stack.artifact_store filepath = os.path.join(self.uri, SERVICE_CONFIG_FILENAME) with artifact_store.open(filepath, "w") as f: - f.write(service.json(indent=4)) + f.write(str(service.uuid)) def extract_metadata( self, service: BaseService @@ -79,6 +82,6 @@ def extract_metadata( """ from zenml.metadata.metadata_types import Uri - if service.endpoint and service.endpoint.status.uri: - return {"uri": Uri(service.endpoint.status.uri)} + if prediction_url := service.get_prediction_url() or None: + return {"uri": Uri(prediction_url)} return {} diff --git a/src/zenml/model/model.py b/src/zenml/model/model.py index 260d0d67c08..5374ffe2ce0 100644 --- a/src/zenml/model/model.py +++ b/src/zenml/model/model.py @@ -503,9 +503,13 @@ def _root_validator(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["suppress_class_validation_warnings"] = True return values - def _validate_config_in_runtime(self) -> None: - """Validate that config doesn't conflict with runtime environment.""" - self._get_or_create_model_version() + def _validate_config_in_runtime(self) -> "ModelVersionResponse": + """Validate that config doesn't conflict with runtime environment. + + Returns: + The model version based on configuration. + """ + return self._get_or_create_model_version() def _get_or_create_model(self) -> "ModelResponse": """This method should get or create a model from Model Control Plane. @@ -545,12 +549,10 @@ def _get_or_create_model(self) -> "ModelResponse": ) logger.info(f"New model `{self.name}` was created implicitly.") except EntityExistsError: - # this is backup logic, if model was created somehow in between get and create calls - pass - finally: model = zenml_client.zen_store.get_model( model_name_or_id=self.name ) + self._model_id = model.id return model @@ -722,7 +724,9 @@ def _get_or_create_model_version( retries_made += 1 self.version = model_version.name self.was_created_in_this_run = True + logger.info(f"New model version `{self.version}` was created.") + self._id = model_version.id self._model_id = model_version.model.id self._number = model_version.number diff --git a/src/zenml/model/utils.py b/src/zenml/model/utils.py index 824a4ef71e7..f947d182397 100644 --- a/src/zenml/model/utils.py +++ b/src/zenml/model/utils.py @@ -23,7 +23,10 @@ from zenml.logger import get_logger from zenml.metadata.metadata_types import MetadataType from zenml.model.model import Model -from zenml.models import ModelVersionArtifactRequest +from zenml.models import ( + ModelVersionArtifactRequest, + ServiceUpdate, +) from zenml.new.steps.step_context import get_step_context logger = get_logger(__name__) @@ -219,3 +222,49 @@ def link_artifact_to_model( artifact_version_id=artifact_version_id, model=model, ) + + +def link_service_to_model( + service_id: UUID, + model: Optional["Model"] = None, + model_version_id: Optional[UUID] = None, +) -> None: + """Links a service to a model. + + Args: + service_id: The ID of the service to link to the model. + model: The model to link the service to. + model_version_id: The ID of the model version to link the service to. + + Raises: + RuntimeError: If no model is provided and the model context cannot be + identified. + """ + client = Client() + + # If no model is provided, try to get it from the context + if not model and not model_version_id: + is_issue = False + try: + step_context = get_step_context() + model = step_context.model + except StepContextError: + is_issue = True + + if model is None or is_issue: + raise RuntimeError( + "`link_service_to_model` called without `model` parameter " + "and configured model context cannot be identified. Consider " + "passing the `model` explicitly or configuring it in " + "@step or @pipeline decorator." + ) + + model_version_id = ( + model_version_id or model._get_or_create_model_version().id + if model + else None + ) + update_service = ServiceUpdate(model_version_id=model_version_id) + client.zen_store.update_service( + service_id=service_id, update=update_service + ) diff --git a/src/zenml/model_deployers/base_model_deployer.py b/src/zenml/model_deployers/base_model_deployer.py index ccc3831b850..747c61fd674 100644 --- a/src/zenml/model_deployers/base_model_deployer.py +++ b/src/zenml/model_deployers/base_model_deployer.py @@ -13,9 +13,10 @@ # permissions and limitations under the License. """Base class for all ZenML model deployers.""" +import contextlib from abc import ABC, abstractmethod from typing import ( - TYPE_CHECKING, + Any, ClassVar, Dict, Generator, @@ -27,19 +28,16 @@ from uuid import UUID from zenml.client import Client -from zenml.constants import METADATA_DEPLOYED_MODEL_URL from zenml.enums import StackComponentType -from zenml.metadata.metadata_types import Uri +from zenml.logger import get_logger from zenml.services import BaseService, ServiceConfig from zenml.services.service import BaseDeploymentService +from zenml.services.service_type import ServiceType from zenml.stack import StackComponent from zenml.stack.flavor import Flavor from zenml.stack.stack_component import StackComponentConfig -if TYPE_CHECKING: - from zenml.config.step_run_info import StepRunInfo - from zenml.metadata.metadata_types import MetadataType - +logger = get_logger(__name__) DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT = 300 @@ -125,11 +123,118 @@ def get_active_model_deployer(cls) -> "BaseModelDeployer": return model_deployer - @abstractmethod def deploy_model( self, config: ServiceConfig, + service_type: ServiceType, replace: bool = False, + continuous_deployment_mode: bool = False, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + ) -> BaseService: + """Deploy a model. + + the deploy_model method is the main entry point for deploying models + using the model deployer. It is used to deploy a model to a model server + instance that is running on a remote serving platform or service. The + method is responsible for detecting if there is an existing model server + instance running serving one or more previous versions of the same model + and deploying the model to the serving platform or updating the existing + model server instance to include the new model version. The method + returns a Service object that is a representation of the external model + server instance. The Service object must implement basic operational + state tracking and lifecycle management operations for the model server + (e.g. start, stop, etc.). + + Args: + config: Custom Service configuration parameters for the model + deployer. Can include the pipeline name, the run id, the step + name, the model name, the model uri, the model type etc. + replace: If True, it will replace any existing model server instances + that serve the same model. If False, it does not replace any + existing model server instance. + continuous_deployment_mode: If True, it will replace any existing + model server instances that serve the same model, regardless of + the configuration. If False, it will only replace existing model + server instances that serve the same model if the configuration + is exactly the same. + timeout: The maximum time in seconds to wait for the model server + to start serving the model. + service_type: The type of the service to deploy. If not provided, + the default service type of the model deployer will be used. + + Raises: + RuntimeError: if the model deployment fails. + + Returns: + The deployment Service object. + """ + # Instantiate the client + client = Client() + if not continuous_deployment_mode: + # Find existing model server + services = self.find_model_server( + config=config.dict(), + service_type=service_type, + ) + if len(services) > 0: + logger.info( + f"Existing model server found for {config.name or config.model_name} with the exact same configuration. Returning the existing service named {services[0].config.service_name}." + ) + return services[0] + else: + # Find existing model server + services = self.find_model_server( + pipeline_name=config.pipeline_name, + pipeline_step_name=config.pipeline_step_name, + model_name=config.model_name, + service_type=service_type, + ) + if len(services) > 0: + logger.info( + f"Existing model server found for {config.pipeline_name} and {config.pipeline_step_name}, since continuous deployment mode is enabled, replacing the existing service named {services[0].config.service_name}." + ) + service = services[0] + self.delete_model_server(service.uuid) + logger.info( + f"Deploying model server for {config.model_name} with the following configuration: {config.dict()}" + ) + service_response = client.create_service( + config=config, + service_type=service_type, + model_version_id=get_model_version_id_if_exists( + config.model_name, config.model_version + ), + ) + try: + service = self.perform_deploy_model( + id=service_response.id, + config=config, + timeout=timeout, + ) + except Exception as e: + client.delete_service(service_response.id) + raise RuntimeError( + f"Failed to deploy model server for {config.model_name}: {e}" + ) from e + # Update the service in store + client.update_service( + id=service.uuid, + name=service.config.service_name, + service_source=service.dict().get("type"), + admin_state=service.admin_state, + status=service.status.dict(), + endpoint=service.endpoint.dict() if service.endpoint else None, + # labels=service.config.get_service_labels() # TODO: fix labels in services and config + prediction_url=service.get_prediction_url(), + health_check_url=service.get_healthcheck_url(), + ) + return service + + @abstractmethod + def perform_deploy_model( + self, + id: UUID, + config: ServiceConfig, timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, ) -> BaseService: """Abstract method to deploy a model. @@ -146,12 +251,10 @@ def deploy_model( start, stop, etc.) Args: + id: UUID of the service that was originally used to deploy the model. config: Custom Service configuration parameters for the model deployer. Can include the pipeline name, the run id, the step name, the model name, the model uri, the model type etc. - replace: If True, it will replace any existing model server instances - that serve the same model. If False, it does not replace any - existing model server instance. timeout: The maximum time in seconds to wait for the model server to start serving the model. @@ -173,17 +276,20 @@ def get_model_server_info( A dictionary containing the relevant model server properties. """ - @abstractmethod def find_model_server( self, - running: bool = False, + config: Optional[Dict[str, Any]] = None, + running: Optional[bool] = None, service_uuid: Optional[UUID] = None, pipeline_name: Optional[str] = None, - run_name: Optional[str] = None, pipeline_step_name: Optional[str] = None, + service_name: Optional[str] = None, model_name: Optional[str] = None, - model_uri: Optional[str] = None, - model_type: Optional[str] = None, + model_version: Optional[str] = None, + service_type: Optional[ServiceType] = None, + type: Optional[str] = None, + flavor: Optional[str] = None, + pipeline_run_id: Optional[str] = None, ) -> List[BaseService]: """Abstract method to find one or more a model servers that match the given criteria. @@ -191,23 +297,91 @@ def find_model_server( running: If true, only running services will be returned. service_uuid: The UUID of the service that was originally used to deploy the model. - pipeline_name: name of the pipeline that the deployed model was part - of. - run_name: Name of the pipeline run which the deployed model was - part of. - pipeline_step_name: the name of the pipeline model deployment step - that deployed the model. - model_name: the name of the deployed model. - model_uri: URI of the deployed model. - model_type: the implementation specific type/format of the deployed - model. + pipeline_step_name: The name of the pipeline step that was originally used + to deploy the model. + pipeline_name: The name of the pipeline that was originally used to deploy + the model from the model registry. + model_name: The name of the model that was originally used to deploy + the model from the model registry. + model_version: The version of the model that was originally used to + deploy the model from the model registry. + service_type: The type of the service to find. + type: The type of the service to find. + flavor: The flavor of the service to find. + pipeline_run_id: The UUID of the pipeline run that was originally used + to deploy the model. + config: Custom Service configuration parameters for the model + deployer. Can include the pipeline name, the run id, the step + name, the model name, the model uri, the model type etc. + service_name: The name of the service to find. Returns: One or more Service objects representing model servers that match the input search criteria. """ + client = Client() + service_responses = client.list_services( + sort_by="desc:created", + id=service_uuid, + running=running, + service_name=service_name, + pipeline_name=pipeline_name, + pipeline_step_name=pipeline_step_name, + model_version_id=get_model_version_id_if_exists( + model_name, model_version + ), + pipeline_run_id=pipeline_run_id, + config=config, + type=type or service_type.type if service_type else None, + flavor=flavor or service_type.flavor if service_type else None, + hydrate=True, + ) + services = [] + for service_response in service_responses.items: + if not service_response.service_source: + client.delete_service(service_response.id) + continue + service = BaseDeploymentService.from_model(service_response) + service.update_status() + if service.status.dict() != service_response.status: + client.update_service( + id=service.uuid, + admin_state=service.admin_state, + status=service.status.dict(), + endpoint=service.endpoint.dict() + if service.endpoint + else None, + ) + if running and not service.is_running: + logger.warning( + f"Service {service.uuid} is in an unexpected state. " + f"Expected running={running}, but found running={service.is_running}." + ) + continue + services.append(service) + return services @abstractmethod + def perform_stop_model( + self, + service: BaseService, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + force: bool = False, + ) -> BaseService: + """Abstract method to stop a model server. + + This operation should be reversible. A stopped model server should still + show up in the list of model servers returned by `find_model_server` and + it should be possible to start it again by calling `start_model_server`. + + Args: + service: The service to stop. + timeout: timeout in seconds to wait for the service to stop. If + set to 0, the method will return immediately after + deprovisioning the service, without waiting for it to stop. + force: if True, force the service to stop. + """ + def stop_model_server( self, uuid: UUID, @@ -226,9 +400,43 @@ def stop_model_server( set to 0, the method will return immediately after deprovisioning the service, without waiting for it to stop. force: if True, force the service to stop. + + Raises: + RuntimeError: if the model server is not found. """ + client = Client() + try: + service = self.find_model_server(service_uuid=uuid)[0] + updated_service = self.perform_stop_model(service, timeout, force) + client.update_service( + id=updated_service.uuid, + admin_state=updated_service.admin_state, + status=updated_service.status.dict(), + endpoint=updated_service.endpoint.dict() + if updated_service.endpoint + else None, + ) + except Exception as e: + raise RuntimeError( + f"Failed to stop model server with UUID {uuid}: {e}" + ) from e @abstractmethod + def perform_start_model( + self, + service: BaseService, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + ) -> BaseService: + """Abstract method to start a model server. + + Args: + service: The service to start. + timeout: timeout in seconds to wait for the service to start. If + set to 0, the method will return immediately after + provisioning the service, without waiting for it to become + active. + """ + def start_model_server( self, uuid: UUID, @@ -242,9 +450,47 @@ def start_model_server( set to 0, the method will return immediately after provisioning the service, without waiting for it to become active. + + Raises: + RuntimeError: if the model server is not found. """ + client = Client() + try: + service = self.find_model_server(service_uuid=uuid)[0] + updated_service = self.perform_start_model(service, timeout) + client.update_service( + id=updated_service.uuid, + admin_state=updated_service.admin_state, + status=updated_service.status.dict(), + endpoint=updated_service.endpoint.dict() + if updated_service.endpoint + else None, + ) + except Exception as e: + raise RuntimeError( + f"Failed to start model server with UUID {uuid}: {e}" + ) from e @abstractmethod + def perform_delete_model( + self, + service: BaseService, + timeout: int = DEFAULT_DEPLOYMENT_START_STOP_TIMEOUT, + force: bool = False, + ) -> None: + """Abstract method to delete a model server. + + This operation is irreversible. A deleted model server must no longer + show up in the list of model servers returned by `find_model_server`. + + Args: + service: The service to delete. + timeout: timeout in seconds to wait for the service to stop. If + set to 0, the method will return immediately after + deprovisioning the service, without waiting for it to stop. + force: if True, force the service to stop. + """ + def delete_model_server( self, uuid: UUID, @@ -262,7 +508,19 @@ def delete_model_server( set to 0, the method will return immediately after deprovisioning the service, without waiting for it to stop. force: if True, force the service to stop. + + Raises: + RuntimeError: if the model server is not found. """ + client = Client() + try: + service = self.find_model_server(service_uuid=uuid)[0] + self.perform_delete_model(service, timeout, force) + client.delete_service(uuid) + except Exception as e: + raise RuntimeError( + f"Failed to delete model server with UUID {uuid}: {e}" + ) from e def get_model_server_logs( self, @@ -288,32 +546,21 @@ def get_model_server_logs( raise RuntimeError(f"No model server found with UUID {uuid}") return services[0].get_logs(follow=follow, tail=tail) - def get_step_run_metadata( - self, info: "StepRunInfo" - ) -> Dict[str, "MetadataType"]: - """Get component- and step-specific metadata after a step ran. - - For model deployers, this extracts the prediction URL of the deployed - model. + def load_service( + self, + service_id: UUID, + ) -> BaseService: + """Load a service from a URI. Args: - info: Info about the step that was executed. + service_id: The ID of the service to load. Returns: - A dictionary of metadata. + The loaded service. """ - existing_services = self.find_model_server( - run_name=info.run_name, - ) - if existing_services: - existing_service = existing_services[0] - if ( - isinstance(existing_service, BaseDeploymentService) - and existing_service.is_running - ): - deployed_model_url = existing_service.prediction_url - return {METADATA_DEPLOYED_MODEL_URL: Uri(deployed_model_url)} - return {} + client = Client() + service = client.get_service(service_id) + return BaseDeploymentService.from_model(service) class BaseModelDeployerFlavor(Flavor): @@ -341,3 +588,26 @@ def config_class(self) -> Type[BaseModelDeployerConfig]: @abstractmethod def implementation_class(self) -> Type[BaseModelDeployer]: """The class that implements the model deployer.""" + + +def get_model_version_id_if_exists( + model_name: Optional[str], + model_version: Optional[str], +) -> Optional[UUID]: + """Get the model version id if it exists. + + Args: + model_name: The name of the model. + model_version: The version of the model. + + Returns: + The model version id if it exists. + """ + client = Client() + if model_name: + with contextlib.suppress(KeyError): + return client.get_model_version( + model_name_or_id=model_name, + model_version_name_or_number_or_id=model_version, + ).id + return None diff --git a/src/zenml/models/__init__.py b/src/zenml/models/__init__.py index db5ff386a69..9e9480f5797 100644 --- a/src/zenml/models/__init__.py +++ b/src/zenml/models/__init__.py @@ -89,6 +89,15 @@ ArtifactVisualizationResponseBody, ArtifactVisualizationResponseMetadata, ) +from zenml.models.v2.core.service import ( + ServiceResponse, + ServiceResponseBody, + ServiceResponseMetadata, + ServiceUpdate, + ServiceFilter, + ServiceRequest, + ServiceResponseResources, +) from zenml.models.v2.core.code_reference import ( CodeReferenceRequest, CodeReferenceResponse, @@ -157,6 +166,7 @@ ModelVersionResponseMetadata, ModelVersionFilter, ModelVersionUpdate, + ModelVersionResponseResources, ) from zenml.models.v2.core.model_version_artifact import ( ModelVersionArtifactFilter, @@ -402,6 +412,15 @@ FlavorResponseMetadata.update_forward_refs( WorkspaceResponse=WorkspaceResponse, ) +ServiceResponseBody.update_forward_refs( + UserResponse=UserResponse, +) +ServiceResponseMetadata.update_forward_refs( + WorkspaceResponse=WorkspaceResponse, +) +ServiceResponseResources.update_forward_refs( + ModelVersionResponse=ModelVersionResponse, +) ModelResponseBody.update_forward_refs( UserResponse=UserResponse, TagResponse=TagResponse, @@ -418,6 +437,9 @@ WorkspaceResponse=WorkspaceResponse, RunMetadataResponse=RunMetadataResponse, ) +ModelVersionResponseResources.update_forward_refs( + ServiceResponse=ServiceResponse, +) ModelVersionArtifactResponseBody.update_forward_refs( ArtifactVersionResponse=ArtifactVersionResponse, ) @@ -639,6 +661,7 @@ "ModelVersionResponse", "ModelVersionResponseBody", "ModelVersionResponseMetadata", + "ModelVersionResponseResources", "ModelVersionUpdate", "ModelVersionArtifactFilter", "ModelVersionArtifactRequest", @@ -765,6 +788,13 @@ "WorkspaceResponse", "WorkspaceResponseBody", "WorkspaceResponseMetadata", + "ServiceResponse", + "ServiceResponseBody", + "ServiceResponseMetadata", + "ServiceUpdate", + "ServiceFilter", + "ServiceRequest", + "ServiceResponseResources", # V2 Misc "AuthenticationMethodModel", "ServiceConnectorResourcesModel", diff --git a/src/zenml/models/v2/core/model_version.py b/src/zenml/models/v2/core/model_version.py index 4e6dc97c489..04d8da143de 100644 --- a/src/zenml/models/v2/core/model_version.py +++ b/src/zenml/models/v2/core/model_version.py @@ -21,6 +21,7 @@ from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH from zenml.enums import ModelStages from zenml.models.v2.base.filter import AnyQuery +from zenml.models.v2.base.page import Page from zenml.models.v2.base.scoped import ( WorkspaceScopedRequest, WorkspaceScopedResponse, @@ -29,6 +30,7 @@ WorkspaceScopedResponseResources, WorkspaceScopedTaggableFilter, ) +from zenml.models.v2.core.service import ServiceResponse from zenml.models.v2.core.tag import TagResponse if TYPE_CHECKING: @@ -176,6 +178,10 @@ class ModelVersionResponseMetadata(WorkspaceScopedResponseMetadata): class ModelVersionResponseResources(WorkspaceScopedResponseResources): """Class for all resource models associated with the model version entity.""" + services: Page[ServiceResponse] = Field( + description="Services linked to the model version", + ) + class ModelVersionResponse( WorkspaceScopedResponse[ diff --git a/src/zenml/models/v2/core/service.py b/src/zenml/models/v2/core/service.py new file mode 100644 index 00000000000..b1bbc2c8210 --- /dev/null +++ b/src/zenml/models/v2/core/service.py @@ -0,0 +1,479 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Models representing Services.""" + +from datetime import datetime +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Dict, + List, + Optional, + Type, + Union, +) +from uuid import UUID + +from pydantic import BaseModel, Field +from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList +from sqlmodel import SQLModel + +from zenml.constants import STR_FIELD_MAX_LENGTH +from zenml.models.v2.base.scoped import ( + WorkspaceScopedFilter, + WorkspaceScopedRequest, + WorkspaceScopedResponse, + WorkspaceScopedResponseBody, + WorkspaceScopedResponseMetadata, + WorkspaceScopedResponseResources, + WorkspaceScopedTaggableFilter, +) +from zenml.services.service_status import ServiceState +from zenml.services.service_type import ServiceType + +if TYPE_CHECKING: + pass + +# ------------------ Request Model ------------------ + + +class ServiceRequest(WorkspaceScopedRequest): + """Request model for services.""" + + name: str = Field( + title="The name of the service.", + max_length=STR_FIELD_MAX_LENGTH, + ) + + service_type: ServiceType = Field( + title="The type of the service.", + ) + + service_source: Optional[str] = Field( + title="The class of the service.", + description="The fully qualified class name of the service implementation.", + ) + + admin_state: Optional[ServiceState] = Field( + title="The admin state of the service.", + description="The administrative state of the service, e.g., ACTIVE, INACTIVE.", + ) + + config: Dict[str, Any] = Field( + title="The service config.", + description="A dictionary containing configuration parameters for the service.", + ) + + labels: Optional[Dict[str, str]] = Field( + default=None, + title="The service labels.", + ) + + status: Optional[Dict[str, Any]] = Field( + title="The status of the service.", + ) + + endpoint: Optional[Dict[str, Any]] = Field( + default=None, + title="The service endpoint.", + ) + + prediction_url: Optional[str] = Field( + default=None, + title="The service endpoint URL.", + ) + + health_check_url: Optional[str] = Field( + default=None, + title="The service health check URL.", + ) + + model_version_id: Optional[UUID] = Field( + default=None, + title="The model version id linked to the service.", + ) + pipeline_run_id: Optional[Union[UUID, str]] = Field( + default=None, + description="By the event source this trigger is attached to.", + ) + + +# ------------------ Update Model ------------------ + + +class ServiceUpdate(BaseModel): + """Update model for stack components.""" + + name: Optional[str] = Field( + title="The name of the service.", + max_length=STR_FIELD_MAX_LENGTH, + ) + + admin_state: Optional[ServiceState] = Field( + title="The admin state of the service.", + description="The administrative state of the service, e.g., ACTIVE, INACTIVE.", + ) + + service_source: Optional[str] = Field( + title="The class of the service.", + description="The fully qualified class name of the service implementation.", + ) + + status: Optional[Dict[str, Any]] = Field( + title="The status of the service.", + ) + + endpoint: Optional[Dict[str, Any]] = Field( + title="The service endpoint.", + ) + + prediction_url: Optional[str] = Field( + title="The service endpoint URL.", + ) + + health_check_url: Optional[str] = Field( + title="The service health check URL.", + ) + + labels: Optional[Dict[str, str]] = Field( + default=None, + title="The service labels.", + ) + + model_version_id: Optional[UUID] = Field( + default=None, + title="The model version id linked to the service.", + ) + + +# ------------------ Response Model ------------------ + + +class ServiceResponseBody(WorkspaceScopedResponseBody): + """Response body for services.""" + + service_type: ServiceType = Field( + title="The type of the service.", + ) + labels: Optional[Dict[str, str]] = Field( + default=None, + title="The service labels.", + ) + created: datetime = Field( + title="The timestamp when this component was created." + ) + updated: datetime = Field( + title="The timestamp when this component was last updated.", + ) + state: Optional[ServiceState] = Field( + default=None, + title="The current state of the service.", + ) + + +class ServiceResponseMetadata(WorkspaceScopedResponseMetadata): + """Response metadata for services.""" + + service_source: Optional[str] = Field( + title="The class of the service.", + ) + admin_state: Optional[ServiceState] = Field( + title="The admin state of the service.", + ) + config: Dict[str, Any] = Field( + title="The service config.", + ) + status: Optional[Dict[str, Any]] = Field( + title="The status of the service.", + ) + endpoint: Optional[Dict[str, Any]] = Field( + default=None, + title="The service endpoint.", + ) + prediction_url: Optional[str] = Field( + default=None, + title="The service endpoint URL.", + ) + health_check_url: Optional[str] = Field( + default=None, + title="The service health check URL.", + ) + + +class ServiceResponseResources(WorkspaceScopedResponseResources): + """Class for all resource models associated with the service entity.""" + + +class ServiceResponse( + WorkspaceScopedResponse[ + ServiceResponseBody, ServiceResponseMetadata, ServiceResponseResources + ] +): + """Response model for services.""" + + name: str = Field( + title="The name of the service.", + max_length=STR_FIELD_MAX_LENGTH, + ) + + def get_hydrated_version(self) -> "ServiceResponse": + """Get the hydrated version of this artifact. + + Returns: + an instance of the same entity with the metadata field attached. + """ + from zenml.client import Client + + return Client().zen_store.get_service(self.id) + + # Body and metadata properties + + @property + def service_type(self) -> ServiceType: + """The `service_type` property. + + Returns: + the value of the property. + """ + return self.get_body().service_type + + @property + def labels(self) -> Optional[Dict[str, str]]: + """The `labels` property. + + Returns: + the value of the property. + """ + return self.get_body().labels + + @property + def service_source(self) -> Optional[str]: + """The `service_source` property. + + Returns: + the value of the property. + """ + return self.get_metadata().service_source + + @property + def config(self) -> Dict[str, Any]: + """The `config` property. + + Returns: + the value of the property. + """ + return self.get_metadata().config + + @property + def status(self) -> Optional[Dict[str, Any]]: + """The `status` property. + + Returns: + the value of the property. + """ + return self.get_metadata().status + + @property + def endpoint(self) -> Optional[Dict[str, Any]]: + """The `endpoint` property. + + Returns: + the value of the property. + """ + return self.get_metadata().endpoint + + @property + def created(self) -> datetime: + """The `created` property. + + Returns: + the value of the property. + """ + return self.get_body().created + + @property + def updated(self) -> datetime: + """The `updated` property. + + Returns: + the value of the property. + """ + return self.get_body().updated + + @property + def admin_state(self) -> Optional[ServiceState]: + """The `admin_state` property. + + Returns: + the value of the property. + """ + return self.get_metadata().admin_state + + @property + def prediction_url(self) -> Optional[str]: + """The `prediction_url` property. + + Returns: + the value of the property. + """ + return self.get_metadata().prediction_url + + @property + def health_check_url(self) -> Optional[str]: + """The `health_check_url` property. + + Returns: + the value of the property. + """ + return self.get_metadata().health_check_url + + @property + def state(self) -> Optional[ServiceState]: + """The `state` property. + + Returns: + the value of the property. + """ + return self.get_body().state + + +# ------------------ Filter Model ------------------ + + +class ServiceFilter(WorkspaceScopedFilter): + """Model to enable advanced filtering of services. + + The Service needs additional scoping. As such the `_scope_user` field + can be set to the user that is doing the filtering. The + `generate_filter()` method of the baseclass is overwritten to include the + scoping. + """ + + name: Optional[str] = Field( + description="Name of the service. Use this to filter services by their name.", + ) + workspace_id: Optional[Union[UUID, str]] = Field( + default=None, description="Workspace of the service" + ) + user_id: Optional[Union[UUID, str]] = Field( + default=None, description="User of the service" + ) + type: Optional[str] = Field( + default=None, + description="Type of the service. Filter services by their type.", + ) + flavor: Optional[str] = Field( + default=None, + description="Flavor of the service. Use this to filter services by their flavor.", + ) + config: Optional[bytes] = Field( + default=None, + description="Config of the service. Use this to filter services by their config.", + ) + pipeline_name: Optional[str] = Field( + default=None, + description="Pipeline name responsible for deploying the service", + ) + pipeline_step_name: Optional[str] = Field( + default=None, + description="Pipeline step name responsible for deploying the service", + ) + running: Optional[bool] = Field( + default=None, description="Whether the service is running" + ) + model_version_id: Optional[Union[UUID, str]] = Field( + default=None, + description="By the model version this service is attached to.", + ) + pipeline_run_id: Optional[Union[UUID, str]] = Field( + default=None, + description="By the pipeline run this service is attached to.", + ) + + def set_type(self, type: str) -> None: + """Set the type of the service. + + Args: + type: The type of the service. + """ + self.type = type + + def set_flavor(self, flavor: str) -> None: + """Set the flavor of the service. + + Args: + flavor: The flavor of the service. + """ + self.flavor = flavor + + # Artifact name and type are not DB fields and need to be handled separately + FILTER_EXCLUDE_FIELDS = [ + *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS, + "flavor", + "type", + "pipeline_step_name", + "running", + "pipeline_name", + "config", + ] + CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [ + *WorkspaceScopedTaggableFilter.CLI_EXCLUDE_FIELDS, + "workspace_id", + "user_id", + "flavor", + "type", + "pipeline_step_name", + "running", + "pipeline_name", + ] + + def generate_filter( + self, table: Type["SQLModel"] + ) -> Union["BinaryExpression[Any]", "BooleanClauseList[Any]"]: + """Generate the filter for the query. + + Services can be scoped by type to narrow the search. + + Args: + table: The Table that is being queried from. + + Returns: + The filter expression for the query. + """ + from sqlalchemy import and_ + + base_filter = super().generate_filter(table) + + if self.type: + type_filter = getattr(table, "type") == self.type + base_filter = and_(base_filter, type_filter) + + if self.flavor: + flavor_filter = getattr(table, "flavor") == self.flavor + base_filter = and_(base_filter, flavor_filter) + + if self.pipeline_name: + pipeline_name_filter = ( + getattr(table, "pipeline_name") == self.pipeline_name + ) + base_filter = and_(base_filter, pipeline_name_filter) + + if self.pipeline_step_name: + pipeline_step_name_filter = ( + getattr(table, "pipeline_step_name") == self.pipeline_step_name + ) + base_filter = and_(base_filter, pipeline_step_name_filter) + + return base_filter diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index c7803f2665c..7ff767327fb 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -748,10 +748,6 @@ def _run( run_id=run.id if run else None, ) - deploy_pipeline( - deployment=deployment_model, stack=stack, placeholder_run=run - ) - if run: run_url = dashboard_utils.get_run_url(run) if run_url: @@ -763,6 +759,10 @@ def _run( "`zenml up`." ) + deploy_pipeline( + deployment=deployment_model, stack=stack, placeholder_run=run + ) + return run @staticmethod diff --git a/src/zenml/new/pipelines/run_utils.py b/src/zenml/new/pipelines/run_utils.py index e98caef2792..2b3d750ba6a 100644 --- a/src/zenml/new/pipelines/run_utils.py +++ b/src/zenml/new/pipelines/run_utils.py @@ -28,6 +28,7 @@ from zenml.new.pipelines.model_utils import NewModelRequest from zenml.orchestrators.utils import get_run_name from zenml.stack import Stack +from zenml.utils import cloud_utils if TYPE_CHECKING: from zenml.config.source import Source @@ -232,6 +233,7 @@ def _validate_new_version_requests( new_versions_requested: A dict of new model version request objects. """ + is_cloud_model = True for key, data in new_versions_requested.items(): model_name, model_version = key if len(data.requesters) > 1: @@ -241,4 +243,12 @@ def _validate_new_version_requests( "that `Model` requesting new version is configured only in one " "place of the pipeline." ) - data.model._validate_config_in_runtime() + model_version_response = data.model._validate_config_in_runtime() + is_cloud_model &= cloud_utils.is_cloud_model_version( + model_version_response + ) + if not is_cloud_model: + logger.info( + "Models can be viewed in the dashboard using ZenML Cloud. Sign up " + "for a free trial at https://www.zenml.io/cloud/" + ) diff --git a/src/zenml/orchestrators/local_docker/local_docker_orchestrator.py b/src/zenml/orchestrators/local_docker/local_docker_orchestrator.py index c5cd80bbc43..cd5d254d442 100644 --- a/src/zenml/orchestrators/local_docker/local_docker_orchestrator.py +++ b/src/zenml/orchestrators/local_docker/local_docker_orchestrator.py @@ -38,7 +38,7 @@ ContainerizedOrchestrator, ) from zenml.stack import Stack, StackValidator -from zenml.utils import string_utils +from zenml.utils import docker_utils, string_utils if TYPE_CHECKING: from zenml.models import PipelineDeploymentResponse @@ -117,9 +117,8 @@ def prepare_or_run_pipeline( "and the pipeline will be run immediately." ) - from docker.client import DockerClient + docker_client = docker_utils._try_get_docker_client_from_env() - docker_client = DockerClient.from_env() entrypoint = StepEntrypointConfiguration.get_entrypoint_command() # Add the local stores path as a volume mount diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index 4ab21d96767..b47d0c8aa37 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -44,7 +44,9 @@ from zenml.logger import get_logger from zenml.logging.step_logging import StepLogsStorageContext, redirected from zenml.materializers.base_materializer import BaseMaterializer -from zenml.model.utils import link_step_artifacts_to_model +from zenml.model.utils import ( + link_step_artifacts_to_model, +) from zenml.new.steps.step_context import StepContext, get_step_context from zenml.orchestrators.publish_utils import ( publish_step_run_metadata, diff --git a/src/zenml/service_connectors/docker_service_connector.py b/src/zenml/service_connectors/docker_service_connector.py index 70f6c92f26d..13f9d632035 100644 --- a/src/zenml/service_connectors/docker_service_connector.py +++ b/src/zenml/service_connectors/docker_service_connector.py @@ -37,6 +37,7 @@ AuthenticationConfig, ServiceConnector, ) +from zenml.utils import docker_utils from zenml.utils.enum_utils import StrEnum logger = get_logger(__name__) @@ -258,7 +259,9 @@ def _connect_to_resource( An authenticated python-docker client object. """ assert self.resource_id is not None - docker_client = DockerClient.from_env() + + docker_client = docker_utils._try_get_docker_client_from_env() + self._authorize_client(docker_client, self.resource_id) return docker_client diff --git a/src/zenml/services/__init__.py b/src/zenml/services/__init__.py index 55ef932dc48..95646020d5e 100644 --- a/src/zenml/services/__init__.py +++ b/src/zenml/services/__init__.py @@ -51,7 +51,6 @@ TCPEndpointHealthMonitor, TCPEndpointHealthMonitorConfig, ) -from zenml.services.service_registry import ServiceRegistry from zenml.services.service_status import ServiceState, ServiceStatus from zenml.services.service_type import ServiceType @@ -84,5 +83,4 @@ "LocalDaemonServiceEndpointConfig", "LocalDaemonServiceEndpointStatus", "LocalDaemonServiceEndpoint", - "ServiceRegistry", ] diff --git a/src/zenml/services/container/container_service.py b/src/zenml/services/container/container_service.py index 5c8dcb3b8cb..28089b1bdf9 100644 --- a/src/zenml/services/container/container_service.py +++ b/src/zenml/services/container/container_service.py @@ -33,6 +33,7 @@ ) from zenml.services.service import BaseService, ServiceConfig from zenml.services.service_status import ServiceState, ServiceStatus +from zenml.utils import docker_utils from zenml.utils.io_utils import ( create_dir_recursive_if_not_exists, get_global_config_directory, @@ -177,7 +178,9 @@ def docker_client(self) -> DockerClient: The docker client. """ if self._docker_client is None: - self._docker_client = DockerClient.from_env() + self._docker_client = ( + docker_utils._try_get_docker_client_from_env() + ) return self._docker_client @property diff --git a/src/zenml/services/container/entrypoint.py b/src/zenml/services/container/entrypoint.py index 2f0956a192e..b7476bc19b9 100644 --- a/src/zenml/services/container/entrypoint.py +++ b/src/zenml/services/container/entrypoint.py @@ -19,6 +19,7 @@ import os import sys +from typing import cast import click @@ -50,7 +51,7 @@ def launch_service(service_config_file: str) -> None: # with messages before daemonization is complete from zenml.integrations.registry import integration_registry from zenml.logger import get_logger - from zenml.services import ContainerService, ServiceRegistry + from zenml.services import ContainerService logger = get_logger(__name__) @@ -63,7 +64,7 @@ def launch_service(service_config_file: str) -> None: logger.debug( "Running containerized service with configuration:\n %s", config ) - service = ServiceRegistry().load_service_from_json(config) + service = cast("ContainerService", ContainerService.from_json(config)) if not isinstance(service, ContainerService): raise TypeError( f"Expected service type ContainerService but got " diff --git a/src/zenml/services/local/local_daemon_entrypoint.py b/src/zenml/services/local/local_daemon_entrypoint.py index 3d2cf42f8a3..33d03685cd9 100644 --- a/src/zenml/services/local/local_daemon_entrypoint.py +++ b/src/zenml/services/local/local_daemon_entrypoint.py @@ -18,6 +18,7 @@ """ import os +from typing import cast import click @@ -68,7 +69,7 @@ def launch_service(service_config_file: str) -> None: # with messages before daemonization is complete from zenml.integrations.registry import integration_registry from zenml.logger import get_logger - from zenml.services import LocalDaemonService, ServiceRegistry + from zenml.services import LocalDaemonService logger = get_logger(__name__) @@ -81,7 +82,9 @@ def launch_service(service_config_file: str) -> None: integration_registry.activate_integrations() logger.debug("Running service daemon with configuration:\n %s", config) - service = ServiceRegistry().load_service_from_json(config) + service = cast( + "LocalDaemonService", LocalDaemonService.from_json(config) + ) if not isinstance(service, LocalDaemonService): raise TypeError( f"Expected service type LocalDaemonService but got " diff --git a/src/zenml/services/service.py b/src/zenml/services/service.py index ba2664be586..446cee28709 100644 --- a/src/zenml/services/service.py +++ b/src/zenml/services/service.py @@ -13,10 +13,12 @@ # permissions and limitations under the License. """Implementation of the ZenML Service class.""" +import json import time from abc import abstractmethod from functools import wraps from typing import ( + TYPE_CHECKING, Any, Callable, ClassVar, @@ -26,24 +28,27 @@ Tuple, Type, TypeVar, - cast, ) -from uuid import UUID, uuid4 - -from pydantic import Field +from uuid import UUID from zenml.console import console from zenml.logger import get_logger from zenml.services.service_endpoint import BaseServiceEndpoint -from zenml.services.service_registry import ServiceRegistry +from zenml.services.service_monitor import HTTPEndpointHealthMonitor from zenml.services.service_status import ServiceState, ServiceStatus from zenml.services.service_type import ServiceType -from zenml.utils.typed_model import BaseTypedModel, BaseTypedModelMeta +from zenml.utils import source_utils +from zenml.utils.typed_model import BaseTypedModel logger = get_logger(__name__) T = TypeVar("T", bound=Callable[..., Any]) +if TYPE_CHECKING: + from zenml.models.v2.core.service import ServiceResponse + +ZENM_ENDPOINT_PREFIX = "zenml-" + def update_service_status( pre_status: Optional[ServiceState] = None, @@ -108,107 +113,42 @@ class ServiceConfig(BaseTypedModel): description: str = "" pipeline_name: str = "" pipeline_step_name: str = "" - run_name: str = "" + model_name: str = "" + model_version: str = "" + service_name: str = "" - -class BaseServiceMeta(BaseTypedModelMeta): - """Metaclass responsible for registering different BaseService subclasses. - - This metaclass has two main responsibilities: - 1. register all BaseService types in the service registry. This is relevant - when services are deserialized and instantiated from their JSON or dict - representation, because their type needs to be known beforehand. - 2. ensuring BaseService instance uniqueness by enforcing that no two - service instances have the same UUID value. Implementing this at the - constructor level guarantees that deserializing a service instance from - a JSON representation multiple times always returns the same service object. - """ - - def __new__( - mcs, name: str, bases: Tuple[Type[Any], ...], dct: Dict[str, Any] - ) -> "BaseServiceMeta": - """Creates a BaseService class and registers it in the `ServiceRegistry`. + def __init__(self, **data: Any): + """Initialize the service configuration. Args: - name: name of the class. - bases: tuple of base classes. - dct: dictionary of class attributes. - - Returns: - the created BaseServiceMeta class. + **data: keyword arguments. Raises: - TypeError: if the 'service_type' reserved attribute name is used. + ValueError: if neither 'name' nor 'model_name' is set. """ - service_type = dct.get("SERVICE_TYPE", None) - - # register only classes of concrete service implementations - if service_type: - # add the service type class attribute to the class as a regular - # immutable attribute to include it in the JSON representation - if "service_type" in dct: - raise TypeError( - "`service_type` is a reserved attribute name for BaseService " - "subclasses" - ) - dct.setdefault("__annotations__", dict())["service_type"] = ( - ServiceType + super().__init__(**data) + if self.name or self.model_name: + self.service_name = data.get( + "service_name", + f"{ZENM_ENDPOINT_PREFIX}{self.name or self.model_name}", ) - dct["service_type"] = Field(service_type, allow_mutation=False) - - cls = cast(Type["BaseService"], super().__new__(mcs, name, bases, dct)) - - # register only classes of concrete service implementations - if service_type: - # register the service type in the service registry - ServiceRegistry().register_service_type(cls) - return cls - - def __call__(cls, *args: Any, **kwargs: Any) -> "BaseServiceMeta": - """Validate the creation of a service. + else: + raise ValueError("Either 'name' or 'model_name' must be set.") - Args: - *args: positional arguments. - **kwargs: keyword arguments. + def get_service_labels(self) -> Dict[str, str]: + """Get the service labels. Returns: - the created BaseServiceMeta class. - - Raises: - AttributeError: if the service UUID is untyped. - ValueError: if the service UUID is not a UUID type. + a dictionary of service labels. """ - if not getattr(cls, "SERVICE_TYPE", None): - raise AttributeError( - f"Untyped service instances are not allowed. Please set the " - f"SERVICE_TYPE class attribute for {cls}." - ) - uuid = kwargs.get("uuid", None) - if uuid: - if isinstance(uuid, str): - uuid = UUID(uuid) - if not isinstance(uuid, UUID): - raise ValueError( - f"The `uuid` argument for {cls} must be a UUID instance or a " - f"string representation of a UUID." - ) - - # if a service instance with the same UUID is already registered, - # return the existing instance rather than the newly created one - existing_service = ServiceRegistry().get_service(uuid) - if existing_service: - logger.debug( - f"Reusing existing service '{existing_service}' " - f"instead of creating a new service with the same UUID." - ) - return cast("BaseServiceMeta", existing_service) - - svc = cast("BaseService", super().__call__(*args, **kwargs)) - ServiceRegistry().register_service(svc) - return cast("BaseServiceMeta", svc) + labels = {} + for k, v in self.dict().items(): + label = f"zenml_{k}".upper() + labels[label] = str(v) + return labels -class BaseService(BaseTypedModel, metaclass=BaseServiceMeta): +class BaseService(BaseTypedModel): """Base service class. This class implements generic functionality concerning the life-cycle @@ -227,7 +167,7 @@ class BaseService(BaseTypedModel, metaclass=BaseServiceMeta): SERVICE_TYPE: ClassVar[ServiceType] - uuid: UUID = Field(default_factory=uuid4, allow_mutation=False) + uuid: UUID admin_state: ServiceState = ServiceState.INACTIVE config: ServiceConfig status: ServiceStatus @@ -246,6 +186,49 @@ def __init__( super().__init__(**attrs) self.config.name = self.config.name or self.__class__.__name__ + @classmethod + def from_model(cls, model: "ServiceResponse") -> "BaseService": + """Loads a service from a model. + + Args: + model: The ServiceResponse to load from. + + Returns: + The loaded service object. + + Raises: + ValueError: if the service source is not found in the model. + """ + if not model.service_source: + raise ValueError("Service source not found in the model.") + class_: Type[BaseService] = source_utils.load_and_validate_class( + source=model.service_source, expected_class=BaseService + ) + return class_( + uuid=model.id, + admin_state=model.admin_state, + config=model.config, + status=model.status, + service_type=model.service_type.dict(), + endpoint=model.endpoint, + ) + + @classmethod + def from_json(cls, json_str: str) -> "BaseTypedModel": + """Loads a service from a JSON string. + + Args: + json_str: the JSON string to load from. + + Returns: + The loaded service object. + """ + service_dict = json.loads(json_str) + class_: Type[BaseService] = source_utils.load_and_validate_class( + source=service_dict["type"], expected_class=BaseService + ) + return class_.from_dict(service_dict) + @abstractmethod def check_status(self) -> Tuple[ServiceState, str]: """Check the the current operational state of the external service. @@ -449,19 +432,15 @@ def start(self, timeout: int = 0) -> None: timeout: amount of time to wait for the service to become active. If set to 0, the method will return immediately after checking the service status. - - Raises: - RuntimeError: if the service cannot be started """ with console.status(f"Starting service '{self}'.\n"): self.admin_state = ServiceState.ACTIVE self.provision() - if timeout > 0: - if not self.poll_service_status(timeout): - raise RuntimeError( - f"Failed to start service {self}\n" - + self.get_service_status_message() - ) + if timeout > 0 and not self.poll_service_status(timeout): + logger.error( + f"Failed to start service {self}\n" + + self.get_service_status_message() + ) @update_service_status( pre_status=ServiceState.PENDING_SHUTDOWN, @@ -476,9 +455,6 @@ def stop(self, timeout: int = 0, force: bool = False) -> None: the service status. force: if True, the service will be stopped even if it is not currently running. - - Raises: - RuntimeError: if the service cannot be stopped """ with console.status(f"Stopping service '{self}'.\n"): self.admin_state = ServiceState.INACTIVE @@ -486,12 +462,40 @@ def stop(self, timeout: int = 0, force: bool = False) -> None: if timeout > 0: self.poll_service_status(timeout) if not self.is_stopped: - raise RuntimeError( + logger.error( f"Failed to stop service {self}. Last state: " f"'{self.status.state.value}'. Last error: " f"'{self.status.last_error}'" ) + def get_prediction_url(self) -> Optional[str]: + """Gets the prediction URL for the endpoint. + + Returns: + the prediction URL for the endpoint + """ + prediction_url = None + if isinstance(self, BaseDeploymentService) and self.prediction_url: + prediction_url = self.prediction_url + elif self.endpoint: + prediction_url = ( + self.endpoint.status.uri if self.endpoint.status else None + ) + return prediction_url + + def get_healthcheck_url(self) -> Optional[str]: + """Gets the healthcheck URL for the endpoint. + + Returns: + the healthcheck URL for the endpoint + """ + return ( + self.endpoint.monitor.get_healthcheck_uri(self.endpoint) + if (self.endpoint and self.endpoint.monitor) + and isinstance(self.endpoint.monitor, HTTPEndpointHealthMonitor) + else None + ) + def __repr__(self) -> str: """String representation of the service. @@ -529,3 +533,12 @@ def prediction_url(self) -> Optional[str]: the prediction URL for the endpoint """ return None + + @property + def healthcheck_url(self) -> Optional[str]: + """Gets the healthcheck URL for the endpoint. + + Returns: + the healthcheck URL for the endpoint + """ + return None diff --git a/src/zenml/services/service_registry.py b/src/zenml/services/service_registry.py deleted file mode 100644 index c88cdab8b5b..00000000000 --- a/src/zenml/services/service_registry.py +++ /dev/null @@ -1,214 +0,0 @@ -# Copyright (c) ZenML GmbH 2022. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at: -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing -# permissions and limitations under the License. -"""Implementation of the ZenML service registry.""" - -import json -from typing import TYPE_CHECKING, Any, Dict, Optional, Type, cast -from uuid import UUID - -from zenml.logger import get_logger -from zenml.services.service_type import ServiceType -from zenml.utils.singleton import SingletonMetaClass - -logger = get_logger(__name__) - -if TYPE_CHECKING: - from zenml.services.service import BaseService - - -class ServiceRegistry(metaclass=SingletonMetaClass): - """Registry of service types and service instances. - - The service registry provides a central place to register service types - as well as service instances. - """ - - def __init__(self) -> None: - """Initialize the service registry.""" - self.service_types: Dict[ServiceType, Type["BaseService"]] = {} - self.services: Dict[UUID, "BaseService"] = {} - - def register_service_type(self, cls: Type["BaseService"]) -> None: - """Registers a new service type. - - Args: - cls: a BaseService subclass. - - Raises: - TypeError: if the service type is already registered. - """ - service_type = cls.SERVICE_TYPE - if service_type not in self.service_types: - self.service_types[service_type] = cls - logger.debug( - f"Registered service class {cls} for " - f"service type `{service_type}`" - ) - else: - raise TypeError( - f"Found existing service type for {service_type}: " - f"{self.service_types[service_type]}. Skipping registration " - f"of {cls}." - ) - - def get_service_type( - self, service_type: ServiceType - ) -> Optional[Type["BaseService"]]: - """Get the service class registered for a service type. - - Args: - service_type: service type. - - Returns: - `BaseService` subclass that was registered for the service type or - None, if no service class was registered for the service type. - """ - return self.service_types.get(service_type) - - def get_service_types( - self, - ) -> Dict[ServiceType, Type["BaseService"]]: - """Get all registered service types. - - Returns: - Dictionary of service types indexed by their service type. - """ - return self.service_types.copy() - - def service_type_is_registered(self, service_type: ServiceType) -> bool: - """Check if a service type is registered. - - Args: - service_type: service type. - - Returns: - True, if a service type is registered for the service type, False - otherwise. - """ - return service_type in self.service_types - - def register_service(self, service: "BaseService") -> None: - """Registers a new service instance. - - Args: - service: a BaseService instance. - - Raises: - TypeError: if the service instance has a service type that is not - registered. - Exception: if a preexisting service is found for that UUID. - """ - service_type = service.SERVICE_TYPE - if service_type not in self.service_types: - raise TypeError( - f"Service type `{service_type}` is not registered." - ) - - if service.uuid not in self.services: - self.services[service.uuid] = service - logger.debug(f"Registered service {service}") - else: - existing_service = self.services[service.uuid] - raise Exception( - f"Found existing service {existing_service} for UUID: " - f"{service.uuid}. Skipping registration for service " - f"{service}." - ) - - def get_service(self, uuid: UUID) -> Optional["BaseService"]: - """Get the service instance registered for a UUID. - - Args: - uuid: service instance identifier. - - Returns: - `BaseService` instance that was registered for the UUID or - None, if no matching service instance was found. - """ - return self.services.get(uuid) - - def get_services(self) -> Dict[UUID, "BaseService"]: - """Get all service instances currently registered. - - Returns: - Dictionary of `BaseService` instances indexed by their UUID with - all services that are currently registered. - """ - return self.services.copy() - - def service_is_registered(self, uuid: UUID) -> bool: - """Check if a service instance is registered. - - Args: - uuid: service instance identifier. - - Returns: - True, if a service instance is registered for the UUID, False - otherwise. - """ - return uuid in self.services - - def load_service_from_dict( - self, service_dict: Dict[str, Any] - ) -> "BaseService": - """Load a service instance from its dict representation. - - Creates, registers and returns a service instantiated from the dict - representation of the service configuration and last known status - information. - - If an existing service instance with the same UUID is already - present in the service registry, it is returned instead. - - Args: - service_dict: dict representation of the service configuration and - last known status - - Returns: - A new or existing ZenML service instance. - - Raises: - TypeError: if the service type is not registered. - ValueError: if the service type is not valid. - """ - service_type = service_dict.get("service_type") - if not service_type: - raise ValueError( - "Service type not present in the service dictionary" - ) - service_type = ServiceType.parse_obj(service_type) - service_class = self.get_service_type(service_type) - if not service_class: - raise TypeError( - f"Cannot load service with unregistered service " - f"type: {service_type}" - ) - service = cast("BaseService", service_class.from_dict(service_dict)) - return service - - def load_service_from_json(self, json_str: str) -> "BaseService": - """Load a service instance from its JSON representation. - - Creates and returns a service instantiated from the JSON serialized - service configuration and last known status information. - - Args: - json_str: JSON string representation of the service configuration - and last known status - - Returns: - A ZenML service instance. - """ - service_dict = json.loads(json_str) - return self.load_service_from_dict(service_dict) diff --git a/src/zenml/services/service_status.py b/src/zenml/services/service_status.py index 368a96f90f0..fc21e3f328e 100644 --- a/src/zenml/services/service_status.py +++ b/src/zenml/services/service_status.py @@ -25,11 +25,12 @@ class ServiceState(StrEnum): """Possible states for the service and service endpoint.""" + INACTIVE = "inactive" ACTIVE = "active" PENDING_STARTUP = "pending_startup" - INACTIVE = "inactive" PENDING_SHUTDOWN = "pending_shutdown" ERROR = "error" + SCALED_TO_ZERO = "scaled_to_zero" class ServiceStatus(BaseTypedModel): diff --git a/src/zenml/services/service_type.py b/src/zenml/services/service_type.py index 8942c87bbda..a83539d336d 100644 --- a/src/zenml/services/service_type.py +++ b/src/zenml/services/service_type.py @@ -24,12 +24,14 @@ class ServiceType(BaseModel): flavor: service flavor name: name of the service type description: description of the service type + logo_url: logo of the service type """ type: str flavor: str name: str = "" description: str = "" + logo_url: str = "" class Config: """Pydantic configuration class.""" diff --git a/src/zenml/utils/cloud_utils.py b/src/zenml/utils/cloud_utils.py new file mode 100644 index 00000000000..cad6b9dcb98 --- /dev/null +++ b/src/zenml/utils/cloud_utils.py @@ -0,0 +1,40 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Utilities for ZenML Cloud.""" + +from zenml.logger import get_logger +from zenml.models.v2.core.model_version import ModelVersionResponse +from zenml.utils.dashboard_utils import get_model_version_url + +logger = get_logger(__name__) + + +def is_cloud_model_version(model_version: ModelVersionResponse) -> bool: + """Check if a model version is from a ZenML Cloud server. + + Args: + model_version: The model version to check. + + Returns: + True if the model version is from a ZenML Cloud server, else False. + """ + model_version_url = get_model_version_url(model_version.id) + if model_version_url: + logger.info( + f"Dashboard URL for Model Version with name {model_version.name} " + f": {model_version_url}" + ) + return True + else: + return False diff --git a/src/zenml/utils/dashboard_utils.py b/src/zenml/utils/dashboard_utils.py index 172dfc5805b..23b59bdc3cc 100644 --- a/src/zenml/utils/dashboard_utils.py +++ b/src/zenml/utils/dashboard_utils.py @@ -14,6 +14,7 @@ """Utility class to help with interacting with the dashboard.""" from typing import Optional +from uuid import UUID from zenml import constants from zenml.client import Client @@ -34,6 +35,17 @@ def get_base_url() -> Optional[str]: client = Client() if client.zen_store.type == StoreType.REST: + # if the server config has a base URL use that + server_model = client.zen_store.get_store_info() + if server_model.base_url: + url = server_model.base_url + # if the base url has cloud.zenml.io in it, then it is a cloud + # deployment and there isn't a workspace in the URL + if "cloud.zenml.io" in url: + return url + return ( + url + f"{constants.WORKSPACES}/{client.active_workspace.name}" + ) url = ( client.zen_store.url + f"{constants.WORKSPACES}/{client.active_workspace.name}" @@ -85,8 +97,13 @@ def get_run_url(run: PipelineRunResponse) -> Optional[str]: Returns: the URL to the pipeline run if the dashboard is available, else None. """ + client = Client() base_url = get_base_url() if base_url: + server_model = client.zen_store.get_store_info() + # if the server is a zenml cloud tenant, use a different URL + if server_model.metadata.get("organization_id"): + return f"{base_url}{constants.RUNS}/{run.id}" if run.pipeline: return f"{base_url}{constants.PIPELINES}/{run.pipeline.id}{constants.RUNS}/{run.id}/dag" else: @@ -94,6 +111,28 @@ def get_run_url(run: PipelineRunResponse) -> Optional[str]: return None +def get_model_version_url(model_version_id: UUID) -> Optional[str]: + """Function to get the dashboard URL of a given model version. + + Args: + model_version_id: the id of the model version. + + Returns: + the URL to the model version if the dashboard is available, else None. + """ + client = Client() + server_model = client.zen_store.get_store_info() + # if organization_id exists as key in server_config.metadata + # only then output a URL. + if server_model.metadata.get("organization_id"): + base_url = get_base_url() + if base_url: + # TODO MODEL_VERSIONS resolves to /model_versions but on the + # cloud, the URL is /model-versions. This should be fixed? + return f"{base_url}/model-versions/{str(model_version_id)}" + return None + + def show_dashboard(url: str) -> None: """Show the ZenML dashboard at the given URL. diff --git a/src/zenml/utils/dict_utils.py b/src/zenml/utils/dict_utils.py index 5c14b548968..fe5e9fb6dfe 100644 --- a/src/zenml/utils/dict_utils.py +++ b/src/zenml/utils/dict_utils.py @@ -13,8 +13,12 @@ # permissions and limitations under the License. """Util functions for dictionaries.""" +import base64 +import json from typing import Any, Dict +from pydantic.json import pydantic_encoder + def recursive_update( original: Dict[str, Any], update: Dict[str, Any] @@ -69,3 +73,21 @@ def _maybe_recurse(value: Any) -> Any: return value return {k: _maybe_recurse(v) for k, v in dict_.items() if v is not None} + + +def dict_to_bytes(dict_: Dict[str, Any]) -> bytes: + """Converts a dictionary to bytes. + + Args: + dict_: The dictionary to convert. + + Returns: + The dictionary as bytes. + """ + return base64.b64encode( + json.dumps( + dict_, + sort_keys=False, + default=pydantic_encoder, + ).encode("utf-8") + ) diff --git a/src/zenml/utils/docker_utils.py b/src/zenml/utils/docker_utils.py index 225a1ab5e09..4b8097542dc 100644 --- a/src/zenml/utils/docker_utils.py +++ b/src/zenml/utils/docker_utils.py @@ -29,6 +29,7 @@ ) from docker.client import DockerClient +from docker.errors import DockerException from docker.utils import build as docker_build_utils from zenml.io import fileio @@ -227,7 +228,8 @@ def build_image( logger.info("Building the image might take a while...") - docker_client = DockerClient.from_env() + docker_client = _try_get_docker_client_from_env() + # We use the client api directly here, so we can stream the logs output_stream = docker_client.images.client.api.build( fileobj=build_context, @@ -258,7 +260,7 @@ def push_image( RuntimeError: If fetching the repository digest of the image failed. """ logger.info("Pushing Docker image `%s`.", image_name) - docker_client = docker_client or DockerClient.from_env() + docker_client = _try_get_docker_client_from_env() output_stream = docker_client.images.push(image_name, stream=True) aux_info = _process_stream(output_stream) logger.info("Finished pushing Docker image.") @@ -283,7 +285,7 @@ def tag_image(image_name: str, target: str) -> None: image_name: The name of the image to tag. target: The full target name including a tag. """ - docker_client = DockerClient.from_env() + docker_client = _try_get_docker_client_from_env() image = docker_client.images.get(image_name) image.tag(target) @@ -298,7 +300,8 @@ def get_image_digest(image_name: str) -> Optional[str]: Returns the repo digest for the given image if there exists exactly one. If there are zero or multiple repo digests, returns `None`. """ - docker_client = DockerClient.from_env() + docker_client = _try_get_docker_client_from_env() + image = docker_client.images.get(image_name) repo_digests = image.attrs["RepoDigests"] if len(repo_digests) == 1: @@ -321,7 +324,7 @@ def is_local_image(image_name: str) -> bool: Returns: `True` if the image was pulled from a registry, `False` otherwise. """ - docker_client = DockerClient.from_env() + docker_client = _try_get_docker_client_from_env() images = docker_client.images.list(name=image_name) if images: # An image with this name is available locally -> now check whether it @@ -333,6 +336,23 @@ def is_local_image(image_name: str) -> bool: return False +def _try_get_docker_client_from_env() -> DockerClient: + """Tries to create a Docker client from the environment. + + Raises: + RuntimeError: If creating a Docker client from the environment failed. + + Returns: + A Docker client created from the environment. + """ + try: + return DockerClient.from_env() + except DockerException as e: + raise RuntimeError( + "Could not create a Docker client from the environment. Is your Docker daemon running?" + ) from e + + def _process_stream(stream: Iterable[bytes]) -> List[Dict[str, Any]]: """Processes the output stream of a docker command call. diff --git a/src/zenml/utils/pipeline_docker_image_builder.py b/src/zenml/utils/pipeline_docker_image_builder.py index 32de37d04a5..88a090b8599 100644 --- a/src/zenml/utils/pipeline_docker_image_builder.py +++ b/src/zenml/utils/pipeline_docker_image_builder.py @@ -626,25 +626,24 @@ def _generate_zenml_pipeline_dockerfile( f"--no-install-recommends {apt_packages}" ) + if ( + docker_settings.python_package_installer + == PythonPackageInstaller.PIP + ): + install_command = "pip install --default-timeout=60" + elif ( + docker_settings.python_package_installer + == PythonPackageInstaller.UV + ): + lines.append("RUN pip install uv") + install_command = "uv pip install --system" + else: + raise ValueError("Unsupported python package installer.") + for file, _, options in requirements_files: lines.append(f"COPY {file} .") - option_string = " ".join(options) - if ( - docker_settings.python_package_installer - == PythonPackageInstaller.PIP - ): - install_command = "pip install --default-timeout=60" - elif ( - docker_settings.python_package_installer - == PythonPackageInstaller.UV - ): - lines.append("RUN pip install uv") - install_command = "uv pip install --system" - else: - raise ValueError("Unsupported python package installer.") - lines.append( f"RUN {install_command} --no-cache-dir " f"{option_string} -r {file}" diff --git a/src/zenml/utils/source_utils.py b/src/zenml/utils/source_utils.py index ea1fcfe6d0d..ba9e9e3e91f 100644 --- a/src/zenml/utils/source_utils.py +++ b/src/zenml/utils/source_utils.py @@ -231,7 +231,9 @@ def get_source_root() -> str: raise RuntimeError( "Unable to determine source root because the main module does not " "have an associated file. This could be because you're running in " - "an interactive Python environment." + "an interactive Python environment. If you are trying to run from " + "within a Jupyter notebook, please run `zenml init` from the root " + "where your notebook is located and restart your notebook server. " ) path = Path(main_module.__file__).resolve().parent diff --git a/src/zenml/zen_server/cloud_utils.py b/src/zenml/zen_server/cloud_utils.py new file mode 100644 index 00000000000..eabac1396de --- /dev/null +++ b/src/zenml/zen_server/cloud_utils.py @@ -0,0 +1,201 @@ +"""Utils concerning anything concerning the cloud control plane backend.""" + +import os +from typing import Any, Dict, Optional + +import requests +from pydantic import BaseModel, validator +from requests.adapters import HTTPAdapter, Retry + +from zenml.exceptions import SubscriptionUpgradeRequiredError + +ZENML_CLOUD_RBAC_ENV_PREFIX = "ZENML_CLOUD_" + + +class ZenMLCloudConfiguration(BaseModel): + """ZenML Cloud RBAC configuration.""" + + api_url: str + + oauth2_client_id: str + oauth2_client_secret: str + oauth2_audience: str + auth0_domain: str + + @validator("api_url") + def _strip_trailing_slashes_url(cls, url: str) -> str: + """Strip any trailing slashes on the API URL. + + Args: + url: The API URL. + + Returns: + The API URL with potential trailing slashes removed. + """ + return url.rstrip("/") + + @classmethod + def from_environment(cls) -> "ZenMLCloudConfiguration": + """Get the RBAC configuration from environment variables. + + Returns: + The RBAC configuration. + """ + env_config: Dict[str, Any] = {} + for k, v in os.environ.items(): + if v == "": + continue + if k.startswith(ZENML_CLOUD_RBAC_ENV_PREFIX): + env_config[k[len(ZENML_CLOUD_RBAC_ENV_PREFIX) :].lower()] = v + + return ZenMLCloudConfiguration(**env_config) + + class Config: + """Pydantic configuration class.""" + + # Allow extra attributes from configs of previous ZenML versions to + # permit downgrading + extra = "allow" + + +class ZenMLCloudSession: + """Class to use for communication between server and control plane.""" + + def __init__(self) -> None: + """Initialize the RBAC component.""" + self._config = ZenMLCloudConfiguration.from_environment() + self._session: Optional[requests.Session] = None + + def _get( + self, endpoint: str, params: Optional[Dict[str, Any]] + ) -> requests.Response: + """Send a GET request using the active session. + + Args: + endpoint: The endpoint to send the request to. This will be appended + to the base URL. + params: Parameters to include in the request. + + Raises: + RuntimeError: If the request failed. + SubscriptionUpgradeRequiredError: In case the current subscription + tier is insufficient for the attempted operation. + + Returns: + The response. + """ + url = self._config.api_url + endpoint + + response = self.session.get(url=url, params=params, timeout=7) + if response.status_code == 401: + # Refresh the auth token and try again + self._clear_session() + response = self.session.get(url=url, params=params, timeout=7) + + try: + response.raise_for_status() + except requests.HTTPError: + if response.status_code == 402: + raise SubscriptionUpgradeRequiredError(response.json()) + else: + raise RuntimeError( + f"Failed with the following error {response.json()}" + ) + + return response + + def _post( + self, + endpoint: str, + params: Optional[Dict[str, Any]] = None, + data: Optional[Dict[str, Any]] = None, + ) -> requests.Response: + """Send a POST request using the active session. + + Args: + endpoint: The endpoint to send the request to. This will be appended + to the base URL. + params: Parameters to include in the request. + data: Data to include in the request. + + Raises: + RuntimeError: If the request failed. + + Returns: + The response. + """ + url = self._config.api_url + endpoint + + response = self.session.post( + url=url, params=params, json=data, timeout=7 + ) + if response.status_code == 401: + # Refresh the auth token and try again + self._clear_session() + response = self.session.post( + url=url, params=params, json=data, timeout=7 + ) + + try: + response.raise_for_status() + except requests.HTTPError as e: + raise RuntimeError( + f"Failed while trying to contact the central zenml cloud " + f"service: {e}" + ) + + return response + + @property + def session(self) -> requests.Session: + """Authenticate to the ZenML Cloud API. + + Returns: + A requests session with the authentication token. + """ + if self._session is None: + self._session = requests.Session() + token = self._fetch_auth_token() + self._session.headers.update({"Authorization": "Bearer " + token}) + + retries = Retry(total=5, backoff_factor=0.1) + self._session.mount("https://", HTTPAdapter(max_retries=retries)) + + return self._session + + def _clear_session(self) -> None: + """Clear the authentication session.""" + self._session = None + + def _fetch_auth_token(self) -> str: + """Fetch an auth token for the Cloud API from auth0. + + Raises: + RuntimeError: If the auth token can't be fetched. + + Returns: + Auth token. + """ + # Get an auth token from auth0 + auth0_url = f"https://{self._config.auth0_domain}/oauth/token" + headers = {"content-type": "application/x-www-form-urlencoded"} + payload = { + "client_id": self._config.oauth2_client_id, + "client_secret": self._config.oauth2_client_secret, + "audience": self._config.oauth2_audience, + "grant_type": "client_credentials", + } + try: + response = requests.post( + auth0_url, headers=headers, data=payload, timeout=7 + ) + response.raise_for_status() + except Exception as e: + raise RuntimeError(f"Error fetching auth token from auth0: {e}") + + access_token = response.json().get("access_token", "") + + if not access_token or not isinstance(access_token, str): + raise RuntimeError("Could not fetch auth token from auth0.") + + return str(access_token) diff --git a/src/zenml/zen_server/deploy/docker/docker_provider.py b/src/zenml/zen_server/deploy/docker/docker_provider.py index aae7060bc96..2353ecf30ba 100644 --- a/src/zenml/zen_server/deploy/docker/docker_provider.py +++ b/src/zenml/zen_server/deploy/docker/docker_provider.py @@ -15,6 +15,7 @@ import shutil from typing import ClassVar, List, Optional, Tuple, Type, cast +from uuid import uuid4 from zenml.enums import ServerProviderType from zenml.logger import get_logger @@ -131,7 +132,9 @@ def _create_service( config=monitor_cfg, ), ) - service = DockerZenServer(config=service_config, endpoint=endpoint) + service = DockerZenServer( + uuid=uuid4(), config=service_config, endpoint=endpoint + ) service.start(timeout=timeout) return service diff --git a/src/zenml/zen_server/deploy/docker/docker_zen_server.py b/src/zenml/zen_server/deploy/docker/docker_zen_server.py index 188aed6f15f..58c02165833 100644 --- a/src/zenml/zen_server/deploy/docker/docker_zen_server.py +++ b/src/zenml/zen_server/deploy/docker/docker_zen_server.py @@ -132,14 +132,11 @@ def get_service(cls) -> Optional["DockerZenServer"]: The docker ZenML server service or None, if the docker server deployment is not found. """ - from zenml.services import ServiceRegistry - config_filename = os.path.join(cls.config_path(), "service.json") try: with open(config_filename, "r") as f: return cast( - DockerZenServer, - ServiceRegistry().load_service_from_json(f.read()), + "DockerZenServer", DockerZenServer.from_json(f.read()) ) except FileNotFoundError: return None diff --git a/src/zenml/zen_server/deploy/helm/Chart.yaml b/src/zenml/zen_server/deploy/helm/Chart.yaml index 673505615ab..e6bc01d1a2b 100644 --- a/src/zenml/zen_server/deploy/helm/Chart.yaml +++ b/src/zenml/zen_server/deploy/helm/Chart.yaml @@ -1,6 +1,6 @@ apiVersion: v2 name: zenml -version: "0.55.5" +version: "0.56.2" description: Open source MLOps framework for portable production ready ML pipelines keywords: - mlops diff --git a/src/zenml/zen_server/deploy/helm/README.md b/src/zenml/zen_server/deploy/helm/README.md index 2b228e3f33e..0f678b869bf 100644 --- a/src/zenml/zen_server/deploy/helm/README.md +++ b/src/zenml/zen_server/deploy/helm/README.md @@ -20,8 +20,8 @@ ZenML is an open-source MLOps framework designed to help you create robust, main To install the ZenML chart directly from Amazon ECR, use the following command: ```bash -# example command for version 0.55.5 -helm install my-zenml oci://public.ecr.aws/zenml/zenml --version 0.55.5 +# example command for version 0.56.2 +helm install my-zenml oci://public.ecr.aws/zenml/zenml --version 0.56.2 ``` Note: Ensure you have OCI support enabled in your Helm client and that you are authenticated with Amazon ECR. diff --git a/src/zenml/zen_server/deploy/helm/templates/server-db-job.yaml b/src/zenml/zen_server/deploy/helm/templates/server-db-job.yaml index 9af1688e6ea..e71d4fd1d49 100644 --- a/src/zenml/zen_server/deploy/helm/templates/server-db-job.yaml +++ b/src/zenml/zen_server/deploy/helm/templates/server-db-job.yaml @@ -110,16 +110,16 @@ spec: envFrom: - secretRef: name: {{ include "zenml.fullname" . }}-db-migration - {{- with .Values.resources }} - resources: - {{- toYaml . | nindent 12 }} + {{- with .Values.resources }} + resources: + {{- toYaml . | nindent 12 }} + {{- end }} + {{- with .Values.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} {{- end }} - {{- with .Values.tolerations }} - tolerations: - {{- toYaml . | nindent 8 }} - {{- end }} - {{- with .Values.nodeSelector }} - nodeSelector: - {{- toYaml . | nindent 8 }} - {{- end }} {{- end }} \ No newline at end of file diff --git a/src/zenml/zen_server/deploy/local/local_provider.py b/src/zenml/zen_server/deploy/local/local_provider.py index 3d8a9b8fe45..b380017ae7e 100644 --- a/src/zenml/zen_server/deploy/local/local_provider.py +++ b/src/zenml/zen_server/deploy/local/local_provider.py @@ -15,6 +15,7 @@ import shutil from typing import ClassVar, List, Optional, Tuple, Type, cast +from uuid import uuid4 from zenml import __version__ from zenml.enums import ServerProviderType @@ -61,6 +62,7 @@ def check_local_server_dependencies() -> None: try: # Make sure the ZenML Server dependencies are installed import fastapi # noqa + import fastapi_utils # noqa import jwt # noqa import multipart # noqa import uvicorn # noqa @@ -92,7 +94,6 @@ def _get_service_configuration( The service, service endpoint and endpoint monitor configuration. """ assert isinstance(server_config, LocalServerDeploymentConfig) - return ( LocalZenServerConfig( root_runtime_path=LocalZenServer.config_path(), @@ -156,7 +157,9 @@ def _create_service( config=monitor_cfg, ), ) - service = LocalZenServer(config=service_config, endpoint=endpoint) + service = LocalZenServer( + uuid=uuid4(), config=service_config, endpoint=endpoint + ) service.start(timeout=timeout) return service diff --git a/src/zenml/zen_server/deploy/local/local_zen_server.py b/src/zenml/zen_server/deploy/local/local_zen_server.py index 6425b2829bc..8f5041d9de1 100644 --- a/src/zenml/zen_server/deploy/local/local_zen_server.py +++ b/src/zenml/zen_server/deploy/local/local_zen_server.py @@ -127,14 +127,11 @@ def get_service(cls) -> Optional["LocalZenServer"]: The local ZenML server service or None, if the local server deployment is not found. """ - from zenml.services import ServiceRegistry - config_filename = os.path.join(cls.config_path(), "service.json") try: with open(config_filename, "r") as f: return cast( - LocalZenServer, - ServiceRegistry().load_service_from_json(f.read()), + "LocalZenServer", LocalZenServer.from_json(f.read()) ) except FileNotFoundError: return None diff --git a/src/zenml/zen_server/deploy/terraform/providers/terraform_provider.py b/src/zenml/zen_server/deploy/terraform/providers/terraform_provider.py index 0215e7d929c..7f25e4fb87d 100644 --- a/src/zenml/zen_server/deploy/terraform/providers/terraform_provider.py +++ b/src/zenml/zen_server/deploy/terraform/providers/terraform_provider.py @@ -15,6 +15,7 @@ import os from typing import ClassVar, List, Optional, Tuple, Type, cast +from uuid import uuid4 from zenml.config.global_config import GlobalConfiguration from zenml.logger import get_logger @@ -153,7 +154,7 @@ def _create_service( monitor_cfg, ) = self._get_service_configuration(config) - service = TerraformZenServer(config=service_config) + service = TerraformZenServer(uuid=uuid4(), config=service_config) service.start(timeout=timeout) return service diff --git a/src/zenml/zen_server/deploy/terraform/terraform_zen_server.py b/src/zenml/zen_server/deploy/terraform/terraform_zen_server.py index 1b1441ddaf0..61b838afdd9 100644 --- a/src/zenml/zen_server/deploy/terraform/terraform_zen_server.py +++ b/src/zenml/zen_server/deploy/terraform/terraform_zen_server.py @@ -184,13 +184,10 @@ def get_service(cls) -> Optional["TerraformZenServer"]: The terraform ZenML server service or None, if the terraform server deployment is not found. """ - from zenml.services import ServiceRegistry - try: with open(TERRAFORM_ZENML_SERVER_CONFIG_FILENAME, "r") as f: return cast( - TerraformZenServer, - ServiceRegistry().load_service_from_json(f.read()), + TerraformZenServer, TerraformZenServer.from_json(f.read()) ) except FileNotFoundError: return None diff --git a/src/zenml/zen_server/exceptions.py b/src/zenml/zen_server/exceptions.py index 31d3464d82d..0a3d379fc93 100644 --- a/src/zenml/zen_server/exceptions.py +++ b/src/zenml/zen_server/exceptions.py @@ -27,6 +27,7 @@ SecretExistsError, StackComponentExistsError, StackExistsError, + SubscriptionUpgradeRequiredError, ValidationError, ZenKeyError, ) @@ -77,6 +78,8 @@ class ErrorModel(BaseModel): (IllegalOperationError, 403), # 401 Unauthorized (AuthorizationException, 401), + # 402 Payment required + (SubscriptionUpgradeRequiredError, 402), # 404 Not Found (DoesNotExistException, 404), (ZenKeyError, 404), diff --git a/src/zenml/zen_server/feature_gate/__init__.py b/src/zenml/zen_server/feature_gate/__init__.py new file mode 100644 index 00000000000..b6bdfa91873 --- /dev/null +++ b/src/zenml/zen_server/feature_gate/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. \ No newline at end of file diff --git a/src/zenml/zen_server/feature_gate/endpoint_utils.py b/src/zenml/zen_server/feature_gate/endpoint_utils.py new file mode 100644 index 00000000000..3b509e9a494 --- /dev/null +++ b/src/zenml/zen_server/feature_gate/endpoint_utils.py @@ -0,0 +1,59 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""All endpoint utils for the feature gate implementations.""" + +from uuid import UUID + +from zenml.zen_server.rbac.models import ResourceType +from zenml.zen_server.utils import feature_gate, server_config + + +def check_entitlement(resource_type: ResourceType) -> None: + """Queries the feature gate to see if the operation falls within the tenants entitlements. + + Raises an exception if the user is not entitled to create an instance of the + resource. Otherwise, simply returns. + + Args: + resource_type: The type of resource to check for. + """ + if not server_config().feature_gate_enabled: + return + return feature_gate().check_entitlement(resource=resource_type) + + +def report_usage(resource_type: ResourceType, resource_id: UUID) -> None: + """Reports the creation/usage of a feature/resource. + + Args: + resource_type: The type of resource to report a usage for + resource_id: ID of the resource that was created. + """ + if not server_config().feature_gate_enabled: + return + feature_gate().report_event( + resource=resource_type, resource_id=resource_id + ) + + +def report_decrement(resource_type: ResourceType, resource_id: UUID) -> None: + """Reports the deletion/deactivation of a feature/resource. + + Args: + resource_type: The type of resource to report a decrement in count for. + resource_id: ID of the resource that was deleted. + """ + feature_gate().report_event( + resource=resource_type, resource_id=resource_id, is_decrement=True + ) diff --git a/src/zenml/zen_server/feature_gate/feature_gate_interface.py b/src/zenml/zen_server/feature_gate/feature_gate_interface.py new file mode 100644 index 00000000000..df4a5d3fc70 --- /dev/null +++ b/src/zenml/zen_server/feature_gate/feature_gate_interface.py @@ -0,0 +1,49 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Definition of the feature gate interface.""" + +from abc import ABC, abstractmethod +from uuid import UUID + +from zenml.zen_server.rbac.models import ResourceType + + +class FeatureGateInterface(ABC): + """RBAC interface definition.""" + + @abstractmethod + def check_entitlement(self, resource: ResourceType) -> None: + """Checks if a user is entitled to create a resource. + + Args: + resource: The resource the user wants to create + + Raises: + UpgradeRequiredError in case a subscription limit is reached + """ + + @abstractmethod + def report_event( + self, + resource: ResourceType, + resource_id: UUID, + is_decrement: bool = False, + ) -> None: + """Reports the usage of a feature to the aggregator backend. + + Args: + resource: The resource the user created + resource_id: ID of the resource that was created/deleted. + is_decrement: In case this event reports an actual decrement of usage + """ diff --git a/src/zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py b/src/zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py new file mode 100644 index 00000000000..f928539ad4b --- /dev/null +++ b/src/zenml/zen_server/feature_gate/zenml_cloud_feature_gate.py @@ -0,0 +1,119 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""ZenML Cloud implementation of the feature gate.""" + +from typing import Any, Dict +from uuid import UUID + +from pydantic import BaseModel, Field + +from zenml.config.server_config import ServerConfiguration +from zenml.exceptions import SubscriptionUpgradeRequiredError +from zenml.logger import get_logger +from zenml.zen_server.cloud_utils import ZenMLCloudSession +from zenml.zen_server.feature_gate.feature_gate_interface import ( + FeatureGateInterface, +) +from zenml.zen_server.rbac.models import ResourceType + +logger = get_logger(__name__) + +server_config = ServerConfiguration.get_server_config() + +ORGANIZATION_ID = server_config.metadata.get("organization_id", "unknown") + +USAGE_EVENT_ENDPOINT = "/usage-event" +ENTITLEMENT_ENDPOINT = f"/organizations/{ORGANIZATION_ID}/entitlement" + + +class RawUsageEvent(BaseModel): + """Model for reporting raw usage of a feature. + + In case of consumables the UsageReport allows the Pricing Backend to + increment the usage per time-frame by 1. + """ + + organization_id: str = Field( + description="The organization that this usage can be attributed to.", + ) + feature: ResourceType = Field( + description="The feature whose usage is being reported.", + ) + total: int = Field( + description="The total amount of entities of this type." + ) + metadata: Dict[str, Any] = Field( + default={}, + description="Allows attaching additional metadata to events.", + ) + + +class ZenMLCloudFeatureGateInterface(FeatureGateInterface, ZenMLCloudSession): + """Feature Gate interface definition.""" + + def check_entitlement(self, resource: ResourceType) -> None: + """Checks if a user is entitled to create a resource. + + Args: + resource: The resource the user wants to create + + Raises: + SubscriptionUpgradeRequiredError: in case a subscription limit is reached + """ + try: + response = self._get( + endpoint=ENTITLEMENT_ENDPOINT + "/" + resource, params=None + ) + except SubscriptionUpgradeRequiredError: + raise SubscriptionUpgradeRequiredError( + f"Your subscription reached its `{resource}` limit. Please " + f"upgrade your subscription or reach out to us." + ) + + if response.status_code != 200: + logger.warning( + "Unexpected response status code from entitlement " + f"endpoint: {response.status_code}. Message: " + f"{response.json()}" + ) + + def report_event( + self, + resource: ResourceType, + resource_id: UUID, + is_decrement: bool = False, + ) -> None: + """Reports the usage of a feature to the aggregator backend. + + Args: + resource: The resource the user created + resource_id: ID of the resource that was created/deleted. + is_decrement: In case this event reports an actual decrement of usage + """ + data = RawUsageEvent( + organization_id=ORGANIZATION_ID, + feature=resource, + total=1 if not is_decrement else -1, + metadata={ + "tenant_id": str(server_config.external_server_id), + "resource_id": str(resource_id), + }, + ).dict() + response = self._post(endpoint=USAGE_EVENT_ENDPOINT, data=data) + if response.status_code != 200: + logger.error( + "Usage report not accepted by upstream backend. " + f"Status Code: {response.status_code}, Message: " + f"{response.json()}." + ) diff --git a/src/zenml/zen_server/rate_limit.py b/src/zenml/zen_server/rate_limit.py new file mode 100644 index 00000000000..520025778d6 --- /dev/null +++ b/src/zenml/zen_server/rate_limit.py @@ -0,0 +1,184 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Rate limiting for the ZenML Server.""" + +import inspect +import time +from collections import defaultdict +from functools import wraps +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + TypeVar, + cast, +) + +from starlette.requests import Request + +from zenml.logger import get_logger +from zenml.zen_server.utils import server_config + +logger = get_logger(__name__) +F = TypeVar("F", bound=Callable[..., Any]) + + +class RequestLimiter: + """Simple in-memory rate limiter.""" + + def __init__( + self, + day_limit: Optional[int] = None, + minute_limit: Optional[int] = None, + ): + """Initializes the limiter. + + Args: + day_limit: The number of requests allowed per day. + minute_limit: The number of requests allowed per minute. + + Raises: + ValueError: If both day_limit and minute_limit are None. + """ + self.limiting_enabled = server_config().rate_limit_enabled + if not self.limiting_enabled: + return + if day_limit is None and minute_limit is None: + raise ValueError("Pass either day or minuter limits, or both.") + self.day_limit = day_limit + self.minute_limit = minute_limit + self.limiter: Dict[str, List[float]] = defaultdict(list) + + def hit_limiter(self, request: Request) -> None: + """Increase the number of hits in the limiter. + + Args: + request: Request object. + + Raises: + HTTPException: If the request limit is exceeded. + """ + if not self.limiting_enabled: + return + from fastapi import HTTPException + + requester = self._get_ipaddr(request) + now = time.time() + minute_ago = now - 60 + day_ago = now - 60 * 60 * 24 + self.limiter[requester].append(now) + + from bisect import bisect_left + + # remove failures older than a day + older_index = bisect_left(self.limiter[requester], day_ago) + self.limiter[requester] = self.limiter[requester][older_index:] + + if self.day_limit and len(self.limiter[requester]) > self.day_limit: + raise HTTPException( + status_code=429, detail="Daily request limit exceeded." + ) + minute_requests = len( + [ + limiter_hit + for limiter_hit in self.limiter[requester][::-1] + if limiter_hit >= minute_ago + ] + ) + if self.minute_limit and minute_requests > self.minute_limit: + raise HTTPException( + status_code=429, detail="Minute request limit exceeded." + ) + + def reset_limiter(self, request: Request) -> None: + """Resets the limiter on successful request. + + Args: + request: Request object. + """ + if self.limiting_enabled: + requester = self._get_ipaddr(request) + if requester in self.limiter: + del self.limiter[requester] + + def _get_ipaddr(self, request: Request) -> str: + """Returns the IP address for the current request. + + Based on the X-Forwarded-For headers or client information. + + Args: + request: The request object. + + Returns: + The ip address for the current request (or 127.0.0.1 if none found). + """ + if "X_FORWARDED_FOR" in request.headers: + return request.headers["X_FORWARDED_FOR"] + else: + if not request.client or not request.client.host: + return "127.0.0.1" + + return request.client.host + + +def rate_limit_requests( + day_limit: Optional[int] = None, + minute_limit: Optional[int] = None, +) -> Callable[..., Any]: + """Decorator to handle exceptions in the API. + + Args: + day_limit: Number of requests allowed per day. + minute_limit: Number of requests allowed per minute. + + Returns: + Decorated function. + """ + limiter = RequestLimiter(day_limit=day_limit, minute_limit=minute_limit) + + def decorator(func: F) -> F: + request_arg, request_kwarg = None, None + parameters = inspect.signature(func).parameters + for arg_num, arg_name in enumerate(parameters): + if parameters[arg_name].annotation == Request: + request_arg = arg_num + request_kwarg = arg_name + break + if request_arg is None or request_kwarg is None: + raise ValueError( + "Rate limiting APIs must have argument of `Request` type." + ) + + @wraps(func) + def decorated( + *args: Any, + **kwargs: Any, + ) -> Any: + if request_kwarg in kwargs: + request = kwargs[request_kwarg] + else: + request = args[request_arg] + limiter.hit_limiter(request) + + ret = func(*args, **kwargs) + + # if request was successful - reset limiter + limiter.reset_limiter(request) + return ret + + return cast(F, decorated) + + return decorator diff --git a/src/zenml/zen_server/rbac/endpoint_utils.py b/src/zenml/zen_server/rbac/endpoint_utils.py index 6cc78ddcc97..1f8abe8d6ea 100644 --- a/src/zenml/zen_server/rbac/endpoint_utils.py +++ b/src/zenml/zen_server/rbac/endpoint_utils.py @@ -1,3 +1,16 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. """High-level helper functions to write endpoints with RBAC.""" from typing import Any, Callable, TypeVar, Union @@ -5,6 +18,10 @@ from pydantic import BaseModel +from zenml.constants import ( + REPORTABLE_RESOURCES, + REQUIRES_CUSTOM_RESOURCE_REPORTING, +) from zenml.exceptions import IllegalOperationError from zenml.models import ( BaseFilter, @@ -14,6 +31,10 @@ UserScopedRequest, ) from zenml.zen_server.auth import get_auth_context +from zenml.zen_server.feature_gate.endpoint_utils import ( + check_entitlement, + report_usage, +) from zenml.zen_server.rbac.models import Action, ResourceType from zenml.zen_server.rbac.utils import ( dehydrate_page, @@ -58,12 +79,21 @@ def verify_permissions_and_create_entity( f"Not allowed to create resource '{resource_type}' for a " "different user." ) + verify_permission(resource_type=resource_type, action=Action.CREATE) - verify_permission( - resource_type=resource_type, - action=Action.CREATE, + needs_usage_increment = ( + resource_type in REPORTABLE_RESOURCES + and resource_type not in REQUIRES_CUSTOM_RESOURCE_REPORTING ) - return create_method(request_model) + if needs_usage_increment: + check_entitlement(resource_type) + + created = create_method(request_model) + + if needs_usage_increment: + report_usage(resource_type, resource_id=created.id) + + return created def verify_permissions_and_get_entity( @@ -141,18 +171,23 @@ def verify_permissions_and_delete_entity( id: UUIDOrStr, get_method: Callable[[UUIDOrStr], AnyResponse], delete_method: Callable[[UUIDOrStr], None], -) -> None: +) -> AnyResponse: """Verify permissions and delete an entity. Args: id: The ID of the entity to delete. get_method: The method to fetch the entity. delete_method: The method to delete the entity. + + Returns: + The deleted entity. """ model = get_method(id) verify_permission_for_model(model, action=Action.DELETE) delete_method(model.id) + return model + def verify_permissions_and_prune_entities( resource_type: ResourceType, diff --git a/src/zenml/zen_server/rbac/models.py b/src/zenml/zen_server/rbac/models.py index 4a7459db1a5..eb136685e77 100644 --- a/src/zenml/zen_server/rbac/models.py +++ b/src/zenml/zen_server/rbac/models.py @@ -58,6 +58,7 @@ class ResourceType(StrEnum): PIPELINE_DEPLOYMENT = "pipeline_deployment" PIPELINE_BUILD = "pipeline_build" USER = "user" + SERVICE = "service" RUN_METADATA = "run_metadata" SECRET = "secret" SERVICE_ACCOUNT = "service_account" diff --git a/src/zenml/zen_server/rbac/utils.py b/src/zenml/zen_server/rbac/utils.py index da64e417899..692b7f8d89c 100644 --- a/src/zenml/zen_server/rbac/utils.py +++ b/src/zenml/zen_server/rbac/utils.py @@ -400,6 +400,7 @@ def get_resource_type_for_model( SecretResponse, ServiceAccountResponse, ServiceConnectorResponse, + ServiceResponse, StackResponse, TagResponse, UserResponse, @@ -429,6 +430,7 @@ def get_resource_type_for_model( PipelineRunResponse: ResourceType.PIPELINE_RUN, TagResponse: ResourceType.TAG, ServiceAccountResponse: ResourceType.SERVICE_ACCOUNT, + ServiceResponse: ResourceType.SERVICE, } return mapping.get(type(model)) @@ -536,6 +538,7 @@ def get_schema_for_resource_type( RunMetadataSchema, SecretSchema, ServiceConnectorSchema, + ServiceSchema, StackComponentSchema, StackSchema, TagSchema, @@ -555,6 +558,7 @@ def get_schema_for_resource_type( ResourceType.ARTIFACT: ArtifactSchema, ResourceType.ARTIFACT_VERSION: ArtifactVersionSchema, ResourceType.SECRET: SecretSchema, + ResourceType.SERVICE: ServiceSchema, ResourceType.TAG: TagSchema, ResourceType.SERVICE_ACCOUNT: UserSchema, ResourceType.WORKSPACE: WorkspaceSchema, diff --git a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py index deeed246c51..fd534b313a9 100644 --- a/src/zenml/zen_server/rbac/zenml_cloud_rbac.py +++ b/src/zenml/zen_server/rbac/zenml_cloud_rbac.py @@ -13,13 +13,9 @@ # permissions and limitations under the License. """Cloud RBAC implementation.""" -import os -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple - -import requests -from pydantic import BaseModel, validator -from requests.adapters import HTTPAdapter, Retry +from typing import TYPE_CHECKING, Dict, List, Set, Tuple +from zenml.zen_server.cloud_utils import ZenMLCloudSession from zenml.zen_server.rbac.models import Action, Resource from zenml.zen_server.rbac.rbac_interface import RBACInterface from zenml.zen_server.utils import server_config @@ -28,7 +24,6 @@ from zenml.models import UserResponse -ZENML_CLOUD_RBAC_ENV_PREFIX = "ZENML_CLOUD_" PERMISSIONS_ENDPOINT = "/rbac/check_permissions" ALLOWED_RESOURCE_IDS_ENDPOINT = "/rbac/allowed_resource_ids" RESOURCE_MEMBERSHIP_ENDPOINT = "/rbac/resource_members" @@ -79,60 +74,9 @@ def _convert_from_cloud_resource(cloud_resource: str) -> Resource: return Resource(type=resource_type_and_id) -class ZenMLCloudRBACConfiguration(BaseModel): - """ZenML Cloud RBAC configuration.""" - - api_url: str - - oauth2_client_id: str - oauth2_client_secret: str - oauth2_audience: str - auth0_domain: str - - @validator("api_url") - def _strip_trailing_slashes_url(cls, url: str) -> str: - """Strip any trailing slashes on the API URL. - - Args: - url: The API URL. - - Returns: - The API URL with potential trailing slashes removed. - """ - return url.rstrip("/") - - @classmethod - def from_environment(cls) -> "ZenMLCloudRBACConfiguration": - """Get the RBAC configuration from environment variables. - - Returns: - The RBAC configuration. - """ - env_config: Dict[str, Any] = {} - for k, v in os.environ.items(): - if v == "": - continue - if k.startswith(ZENML_CLOUD_RBAC_ENV_PREFIX): - env_config[k[len(ZENML_CLOUD_RBAC_ENV_PREFIX) :].lower()] = v - - return ZenMLCloudRBACConfiguration(**env_config) - - class Config: - """Pydantic configuration class.""" - - # Allow extra attributes from configs of previous ZenML versions to - # permit downgrading - extra = "allow" - - -class ZenMLCloudRBAC(RBACInterface): +class ZenMLCloudRBAC(RBACInterface, ZenMLCloudSession): """RBAC implementation that uses the ZenML Cloud API as a backend.""" - def __init__(self) -> None: - """Initialize the RBAC component.""" - self._config = ZenMLCloudRBACConfiguration.from_environment() - self._session: Optional[requests.Session] = None - def check_permissions( self, user: "UserResponse", resources: Set[Resource], action: Action ) -> Dict[Resource, bool]: @@ -234,129 +178,3 @@ def update_resource_membership( "actions": [str(action) for action in actions], } self._post(endpoint=RESOURCE_MEMBERSHIP_ENDPOINT, data=data) - - def _get(self, endpoint: str, params: Dict[str, Any]) -> requests.Response: - """Send a GET request using the active session. - - Args: - endpoint: The endpoint to send the request to. This will be appended - to the base URL. - params: Parameters to include in the request. - - Raises: - RuntimeError: If the request failed. - - Returns: - The response. - """ - url = self._config.api_url + endpoint - - response = self.session.get(url=url, params=params, timeout=7) - if response.status_code == 401: - # Refresh the auth token and try again - self._clear_session() - response = self.session.get(url=url, params=params, timeout=7) - - try: - response.raise_for_status() - except requests.HTTPError as e: - raise RuntimeError( - f"Failed while trying to contact RBAC service: {e}" - ) - - return response - - def _post( - self, - endpoint: str, - params: Optional[Dict[str, Any]] = None, - data: Optional[Dict[str, Any]] = None, - ) -> requests.Response: - """Send a POST request using the active session. - - Args: - endpoint: The endpoint to send the request to. This will be appended - to the base URL. - params: Parameters to include in the request. - data: Data to include in the request. - - Raises: - RuntimeError: If the request failed. - - Returns: - The response. - """ - url = self._config.api_url + endpoint - - response = self.session.post( - url=url, params=params, json=data, timeout=7 - ) - if response.status_code == 401: - # Refresh the auth token and try again - self._clear_session() - response = self.session.post( - url=url, params=params, json=data, timeout=7 - ) - - try: - response.raise_for_status() - except requests.HTTPError as e: - raise RuntimeError( - f"Failed while trying to contact RBAC service: {e}" - ) - - return response - - @property - def session(self) -> requests.Session: - """Authenticate to the ZenML Cloud API. - - Returns: - A requests session with the authentication token. - """ - if self._session is None: - self._session = requests.Session() - token = self._fetch_auth_token() - self._session.headers.update({"Authorization": "Bearer " + token}) - - retries = Retry(total=5, backoff_factor=0.1) - self._session.mount("https://", HTTPAdapter(max_retries=retries)) - - return self._session - - def _clear_session(self) -> None: - """Clear the authentication session.""" - self._session = None - - def _fetch_auth_token(self) -> str: - """Fetch an auth token for the Cloud API from auth0. - - Raises: - RuntimeError: If the auth token can't be fetched. - - Returns: - Auth token. - """ - # Get an auth token from auth0 - auth0_url = f"https://{self._config.auth0_domain}/oauth/token" - headers = {"content-type": "application/x-www-form-urlencoded"} - payload = { - "client_id": self._config.oauth2_client_id, - "client_secret": self._config.oauth2_client_secret, - "audience": self._config.oauth2_audience, - "grant_type": "client_credentials", - } - try: - response = requests.post( - auth0_url, headers=headers, data=payload, timeout=7 - ) - response.raise_for_status() - except Exception as e: - raise RuntimeError(f"Error fetching auth token from auth0: {e}") - - access_token = response.json().get("access_token", "") - - if not access_token or not isinstance(access_token, str): - raise RuntimeError("Could not fetch auth token from auth0.") - - return str(access_token) diff --git a/src/zenml/zen_server/routers/auth_endpoints.py b/src/zenml/zen_server/routers/auth_endpoints.py index f6b2d289ecb..41137a1e18a 100644 --- a/src/zenml/zen_server/routers/auth_endpoints.py +++ b/src/zenml/zen_server/routers/auth_endpoints.py @@ -65,6 +65,7 @@ ) from zenml.zen_server.exceptions import error_response from zenml.zen_server.jwt import JWTToken +from zenml.zen_server.rate_limit import rate_limit_requests from zenml.zen_server.rbac.models import Action, ResourceType from zenml.zen_server.rbac.utils import verify_permission from zenml.zen_server.utils import ( @@ -255,6 +256,10 @@ def generate_access_token( LOGIN, response_model=Union[OAuthTokenResponse, OAuthRedirectResponse], ) +@rate_limit_requests( + day_limit=server_config().login_rate_limit_day, + minute_limit=server_config().login_rate_limit_minute, +) @handle_exceptions def token( request: Request, diff --git a/src/zenml/zen_server/routers/models_endpoints.py b/src/zenml/zen_server/routers/models_endpoints.py index c43026d0bdf..124660e5cfc 100644 --- a/src/zenml/zen_server/routers/models_endpoints.py +++ b/src/zenml/zen_server/routers/models_endpoints.py @@ -22,6 +22,7 @@ API, MODEL_VERSIONS, MODELS, + REPORTABLE_RESOURCES, VERSION_1, ) from zenml.models import ( @@ -34,6 +35,7 @@ ) from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.feature_gate.endpoint_utils import report_decrement from zenml.zen_server.rbac.endpoint_utils import ( verify_permissions_and_delete_entity, verify_permissions_and_get_entity, @@ -48,6 +50,7 @@ from zenml.zen_server.utils import ( handle_exceptions, make_dependable, + server_config, zen_store, ) @@ -160,12 +163,16 @@ def delete_model( Args: model_name_or_id: The name or ID of the model to delete. """ - verify_permissions_and_delete_entity( + model = verify_permissions_and_delete_entity( id=model_name_or_id, get_method=zen_store().get_model, delete_method=zen_store().delete_model, ) + if server_config().feature_gate_enabled: + if ResourceType.MODEL in REPORTABLE_RESOURCES: + report_decrement(ResourceType.MODEL, resource_id=model.id) + ################# # Model Versions diff --git a/src/zenml/zen_server/routers/pipelines_endpoints.py b/src/zenml/zen_server/routers/pipelines_endpoints.py index f4e5ad27808..fb1510ac772 100644 --- a/src/zenml/zen_server/routers/pipelines_endpoints.py +++ b/src/zenml/zen_server/routers/pipelines_endpoints.py @@ -18,7 +18,14 @@ from fastapi import APIRouter, Depends, Security from zenml.config.pipeline_spec import PipelineSpec -from zenml.constants import API, PIPELINE_SPEC, PIPELINES, RUNS, VERSION_1 +from zenml.constants import ( + API, + PIPELINE_SPEC, + PIPELINES, + REPORTABLE_RESOURCES, + RUNS, + VERSION_1, +) from zenml.models import ( Page, PipelineFilter, @@ -31,6 +38,7 @@ ) from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.feature_gate.endpoint_utils import report_decrement from zenml.zen_server.rbac.endpoint_utils import ( verify_permissions_and_delete_entity, verify_permissions_and_get_entity, @@ -154,12 +162,20 @@ def delete_pipeline( Args: pipeline_id: ID of the pipeline to delete. """ - verify_permissions_and_delete_entity( + pipeline = verify_permissions_and_delete_entity( id=pipeline_id, get_method=zen_store().get_pipeline, delete_method=zen_store().delete_pipeline, ) + should_decrement = ( + ResourceType.PIPELINE in REPORTABLE_RESOURCES + and zen_store().count_pipelines(PipelineFilter(name=pipeline.name)) + == 0 + ) + if should_decrement: + report_decrement(ResourceType.PIPELINE, resource_id=pipeline_id) + @router.get( "/{pipeline_id}" + RUNS, diff --git a/src/zenml/zen_server/routers/service_endpoints.py b/src/zenml/zen_server/routers/service_endpoints.py new file mode 100644 index 00000000000..1d7925494df --- /dev/null +++ b/src/zenml/zen_server/routers/service_endpoints.py @@ -0,0 +1,180 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Endpoint definitions for services.""" + +from uuid import UUID + +from fastapi import APIRouter, Depends, Security + +from zenml.constants import API, SERVICES, VERSION_1 +from zenml.models import ( + Page, + ServiceFilter, + ServiceResponse, + ServiceUpdate, +) +from zenml.models.v2.core.service import ServiceRequest +from zenml.zen_server.auth import AuthContext, authorize +from zenml.zen_server.exceptions import error_response +from zenml.zen_server.rbac.endpoint_utils import ( + verify_permissions_and_create_entity, + verify_permissions_and_delete_entity, + verify_permissions_and_get_entity, + verify_permissions_and_list_entities, + verify_permissions_and_update_entity, +) +from zenml.zen_server.rbac.models import ResourceType +from zenml.zen_server.utils import ( + handle_exceptions, + make_dependable, + zen_store, +) + +router = APIRouter( + prefix=API + VERSION_1 + SERVICES, + tags=["services"], + responses={401: error_response, 403: error_response}, +) + + +@router.post( + "", + response_model=ServiceResponse, + responses={401: error_response, 422: error_response}, +) +@handle_exceptions +def create_service( + service: ServiceRequest, + _: AuthContext = Security(authorize), +) -> ServiceResponse: + """Creates a new service. + + Args: + service: The model containing the attributes of the new service. + + Returns: + The created service object. + """ + return verify_permissions_and_create_entity( + request_model=service, + create_method=zen_store().create_service, + resource_type=ResourceType.SERVICE, + ) + + +@router.get( + "", + response_model=Page[ServiceResponse], + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def list_services( + filter_model: ServiceFilter = Depends(make_dependable(ServiceFilter)), + hydrate: bool = False, + _: AuthContext = Security(authorize), +) -> Page[ServiceResponse]: + """Gets a page of service objects. + + Args: + filter_model: Filter model used for pagination, sorting, + filtering. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + Page of service objects. + """ + return verify_permissions_and_list_entities( + filter_model=filter_model, + resource_type=ResourceType.SERVICE, + list_method=zen_store().list_services, + hydrate=hydrate, + ) + + +@router.get( + "/{service_id}", + response_model=ServiceResponse, + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def get_service( + service_id: UUID, + hydrate: bool = True, + _: AuthContext = Security(authorize), +) -> ServiceResponse: + """Gets a specific service using its unique ID. + + Args: + service_id: The ID of the service to get. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + A specific service object. + """ + return verify_permissions_and_get_entity( + id=service_id, + get_method=zen_store().get_service, + hydrate=hydrate, + ) + + +@router.put( + "/{service_id}", + response_model=ServiceResponse, + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def update_service( + service_id: UUID, + update: ServiceUpdate, + _: AuthContext = Security(authorize), +) -> ServiceResponse: + """Updates a service. + + Args: + service_id: The ID of the service to update. + update: The model containing the attributes to update. + + Returns: + The updated service object. + """ + return verify_permissions_and_update_entity( + id=service_id, + update_model=update, + get_method=zen_store().get_service, + update_method=zen_store().update_service, + ) + + +@router.delete( + "/{service_id}", + responses={401: error_response, 404: error_response, 422: error_response}, +) +@handle_exceptions +def delete_service( + service_id: UUID, + _: AuthContext = Security(authorize), +) -> None: + """Deletes a specific service. + + Args: + service_id: The ID of the service to delete. + """ + verify_permissions_and_delete_entity( + id=service_id, + get_method=zen_store().get_service, + delete_method=zen_store().delete_service, + ) diff --git a/src/zenml/zen_server/routers/webhook_endpoints.py b/src/zenml/zen_server/routers/webhook_endpoints.py index 2f115a535b1..9cb167a5d34 100644 --- a/src/zenml/zen_server/routers/webhook_endpoints.py +++ b/src/zenml/zen_server/routers/webhook_endpoints.py @@ -13,9 +13,10 @@ # permissions and limitations under the License. """Endpoint definitions for webhooks.""" +from typing import Dict from uuid import UUID -from fastapi import APIRouter, Depends, Request +from fastapi import APIRouter, BackgroundTasks, Depends, Request from zenml.constants import API, VERSION_1, WEBHOOKS from zenml.enums import PluginSubType, PluginType @@ -52,20 +53,26 @@ async def get_body(request: Request) -> bytes: @router.post( "/{event_source_id}", + response_model=Dict[str, str], ) @handle_exceptions def webhook( event_source_id: UUID, request: Request, + background_tasks: BackgroundTasks, raw_body: bytes = Depends(get_body), -) -> None: +) -> Dict[str, str]: """Webhook to receive events from external event sources. Args: event_source_id: The event_source_id request: The request object + background_tasks: Background task handler raw_body: The raw request body + Returns: + Static dict stating that event is received. + Raises: AuthorizationException: If the Event Source does not exist. KeyError: If no appropriate Plugin found in the plugin registry @@ -111,8 +118,11 @@ def webhook( ) # Pass the raw event and headers to the plugin - plugin.process_webhook_event( + background_tasks.add_task( + plugin.process_webhook_event, event_source=event_source, raw_body=raw_body, headers=dict(request.headers.items()), ) + + return {"status": "Event Received."} diff --git a/src/zenml/zen_server/routers/workspaces_endpoints.py b/src/zenml/zen_server/routers/workspaces_endpoints.py index 042565b6275..4b747ae6f0c 100644 --- a/src/zenml/zen_server/routers/workspaces_endpoints.py +++ b/src/zenml/zen_server/routers/workspaces_endpoints.py @@ -28,12 +28,14 @@ PIPELINE_BUILDS, PIPELINE_DEPLOYMENTS, PIPELINES, + REPORTABLE_RESOURCES, RUN_METADATA, RUNS, SCHEDULES, SECRETS, SERVICE_CONNECTOR_RESOURCES, SERVICE_CONNECTORS, + SERVICES, STACK_COMPONENTS, STACKS, STATISTICS, @@ -80,6 +82,8 @@ ServiceConnectorRequest, ServiceConnectorResourcesModel, ServiceConnectorResponse, + ServiceRequest, + ServiceResponse, StackFilter, StackRequest, StackResponse, @@ -90,6 +94,10 @@ ) from zenml.zen_server.auth import AuthContext, authorize from zenml.zen_server.exceptions import error_response +from zenml.zen_server.feature_gate.endpoint_utils import ( + check_entitlement, + report_usage, +) from zenml.zen_server.rbac.endpoint_utils import ( verify_permissions_and_create_entity, verify_permissions_and_delete_entity, @@ -509,12 +517,30 @@ def create_pipeline( f"not supported." ) - return verify_permissions_and_create_entity( + # We limit pipeline namespaces, not pipeline versions + needs_usage_increment = ( + ResourceType.PIPELINE in REPORTABLE_RESOURCES + and zen_store().count_pipelines(PipelineFilter(name=pipeline.name)) + == 0 + ) + + if needs_usage_increment: + check_entitlement(ResourceType.PIPELINE) + + pipeline_response = verify_permissions_and_create_entity( request_model=pipeline, resource_type=ResourceType.PIPELINE, create_method=zen_store().create_pipeline, ) + if needs_usage_increment: + report_usage( + resource_type=ResourceType.PIPELINE, + resource_id=pipeline_response.id, + ) + + return pipeline_response + @router.get( WORKSPACES + "/{workspace_name_or_id}" + PIPELINE_BUILDS, @@ -1431,3 +1457,44 @@ def create_model_version_pipeline_run_link( model_version_pipeline_run_link ) return mv + + +@router.post( + WORKSPACES + "/{workspace_name_or_id}" + SERVICES, + response_model=ServiceResponse, + responses={401: error_response, 409: error_response, 422: error_response}, +) +@handle_exceptions +def create_service( + workspace_name_or_id: Union[str, UUID], + service: ServiceRequest, + _: AuthContext = Security(authorize), +) -> ServiceResponse: + """Create a new service. + + Args: + workspace_name_or_id: Name or ID of the workspace. + service: The service to create. + + Returns: + The created service. + + Raises: + IllegalOperationError: If the workspace or user specified in the + model does not match the current workspace or authenticated + user. + """ + workspace = zen_store().get_workspace(workspace_name_or_id) + + if service.workspace != workspace.id: + raise IllegalOperationError( + "Creating models outside of the workspace scope " + f"of this endpoint `{workspace_name_or_id}` is " + f"not supported." + ) + + return verify_permissions_and_create_entity( + request_model=service, + resource_type=ResourceType.SERVICE, + create_method=zen_store().create_service, + ) diff --git a/src/zenml/zen_server/utils.py b/src/zenml/zen_server/utils.py index bc6be68e24b..2d6d9af132e 100644 --- a/src/zenml/zen_server/utils.py +++ b/src/zenml/zen_server/utils.py @@ -16,7 +16,15 @@ import inspect import os from functools import wraps -from typing import Any, Callable, Optional, Tuple, Type, TypeVar, cast +from typing import ( + Any, + Callable, + Optional, + Tuple, + Type, + TypeVar, + cast, +) from urllib.parse import urlparse from pydantic import BaseModel, ValidationError @@ -35,6 +43,9 @@ LocalServerDeploymentConfig, ) from zenml.zen_server.exceptions import http_exception_from_error +from zenml.zen_server.feature_gate.feature_gate_interface import ( + FeatureGateInterface, +) from zenml.zen_server.pipeline_deployment.workload_manager_interface import ( WorkloadManagerInterface, ) @@ -45,6 +56,7 @@ _zen_store: Optional["SqlZenStore"] = None _rbac: Optional[RBACInterface] = None +_feature_gate: Optional[FeatureGateInterface] = None _workload_manager: Optional[WorkloadManagerInterface] = None _plugin_flavor_registry: Optional[PluginFlavorRegistry] = None @@ -92,6 +104,50 @@ def rbac() -> RBACInterface: return _rbac +def initialize_rbac() -> None: + """Initialize the RBAC component.""" + global _rbac + + if rbac_source := server_config().rbac_implementation_source: + from zenml.utils import source_utils + + implementation_class = source_utils.load_and_validate_class( + rbac_source, expected_class=RBACInterface + ) + _rbac = implementation_class() + + +def feature_gate() -> FeatureGateInterface: + """Return the initialized Feature Gate component. + + Raises: + RuntimeError: If the RBAC component is not initialized. + + Returns: + The RBAC component. + """ + global _feature_gate + if _feature_gate is None: + raise RuntimeError("Feature gate component not initialized.") + return _feature_gate + + +def initialize_feature_gate() -> None: + """Initialize the Feature Gate component.""" + global _feature_gate + + if ( + feature_gate_source + := server_config().feature_gate_implementation_source + ): + from zenml.utils import source_utils + + implementation_class = source_utils.load_and_validate_class( + feature_gate_source, expected_class=FeatureGateInterface + ) + _feature_gate = implementation_class() + + def workload_manager() -> WorkloadManagerInterface: """Return the initialized workload manager component. @@ -107,19 +163,6 @@ def workload_manager() -> WorkloadManagerInterface: return _workload_manager -def initialize_rbac() -> None: - """Initialize the RBAC component.""" - global _rbac - - if rbac_source := server_config().rbac_implementation_source: - from zenml.utils import source_utils - - implementation_class = source_utils.load_and_validate_class( - rbac_source, expected_class=RBACInterface - ) - _rbac = implementation_class() - - def initialize_workload_manager() -> None: """Initialize the workload manager component. diff --git a/src/zenml/zen_server/zen_server_api.py b/src/zenml/zen_server/zen_server_api.py index 79c9c01f7c0..b5f01f940da 100644 --- a/src/zenml/zen_server/zen_server_api.py +++ b/src/zenml/zen_server/zen_server_api.py @@ -52,6 +52,7 @@ server_endpoints, service_accounts_endpoints, service_connectors_endpoints, + service_endpoints, stack_components_endpoints, stacks_endpoints, steps_endpoints, @@ -62,6 +63,7 @@ workspaces_endpoints, ) from zenml.zen_server.utils import ( + initialize_feature_gate, initialize_plugins, initialize_rbac, initialize_workload_manager, @@ -158,6 +160,7 @@ def initialize() -> None: # race conditions initialize_zen_store() initialize_rbac() + initialize_feature_gate() initialize_workload_manager() initialize_plugins() @@ -234,6 +237,7 @@ def dashboard(request: Request) -> Any: app.include_router(service_accounts_endpoints.router) app.include_router(service_connectors_endpoints.router) app.include_router(service_connectors_endpoints.types_router) +app.include_router(service_endpoints.router) app.include_router(stacks_endpoints.router) app.include_router(stack_components_endpoints.router) app.include_router(stack_components_endpoints.types_router) diff --git a/src/zenml/zen_stores/migrations/utils.py b/src/zenml/zen_stores/migrations/utils.py index f1300946ee5..6ee4af7b3b5 100644 --- a/src/zenml/zen_stores/migrations/utils.py +++ b/src/zenml/zen_stores/migrations/utils.py @@ -236,9 +236,17 @@ def backup_database_to_storage( # correct order, since some tables have inner foreign key # constraints. if "created" in table.columns: - order_by = table.columns["created"] + order_by = [table.columns["created"]] else: - order_by = None + order_by = [] + if "id" in table.columns: + # If the table has an `id` column, we also use it to sort + # the rows in the table, even if we already use "created" + # to sort the rows. We need a unique field to sort the rows, + # to break the tie between rows with the same "created" + # date, otherwise the same entry might end up multiple times + # in subsequent pages. + order_by.append(table.columns["id"]) # Fetch the number of rows in the table row_count = conn.scalar( @@ -250,7 +258,7 @@ def backup_database_to_storage( for i in range(0, row_count, batch_size): rows = conn.execute( table.select() - .order_by(order_by) + .order_by(*order_by) .limit(batch_size) .offset(i) ).fetchall() diff --git a/src/zenml/zen_stores/migrations/versions/0.56.0_release.py b/src/zenml/zen_stores/migrations/versions/0.56.0_release.py new file mode 100644 index 00000000000..85dc2ccdf5e --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/0.56.0_release.py @@ -0,0 +1,23 @@ +"""Release [0.56.0]. + +Revision ID: 0.56.0 +Revises: 1a9a9d2a836d +Create Date: 2024-03-20 13:30:40.013587 + +""" + +# revision identifiers, used by Alembic. +revision = "0.56.0" +down_revision = "1a9a9d2a836d" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + pass + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + pass diff --git a/src/zenml/zen_stores/migrations/versions/0.56.1_release.py b/src/zenml/zen_stores/migrations/versions/0.56.1_release.py new file mode 100644 index 00000000000..d1eb6c0c982 --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/0.56.1_release.py @@ -0,0 +1,23 @@ +"""Release [0.56.1]. + +Revision ID: 0.56.1 +Revises: 0.56.0 +Create Date: 2024-03-21 14:50:20.869911 + +""" + +# revision identifiers, used by Alembic. +revision = "0.56.1" +down_revision = "0.56.0" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + pass + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + pass diff --git a/src/zenml/zen_stores/migrations/versions/0.56.2_release.py b/src/zenml/zen_stores/migrations/versions/0.56.2_release.py new file mode 100644 index 00000000000..47431e949fe --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/0.56.2_release.py @@ -0,0 +1,23 @@ +"""Release [0.56.2]. + +Revision ID: 0.56.2 +Revises: 0701da9951a0 +Create Date: 2024-03-25 14:49:49.021147 + +""" + +# revision identifiers, used by Alembic. +revision = "0.56.2" +down_revision = "0701da9951a0" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + pass + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + pass diff --git a/src/zenml/zen_stores/migrations/versions/0701da9951a0_added_service_table.py b/src/zenml/zen_stores/migrations/versions/0701da9951a0_added_service_table.py new file mode 100644 index 00000000000..b32a6fe8b72 --- /dev/null +++ b/src/zenml/zen_stores/migrations/versions/0701da9951a0_added_service_table.py @@ -0,0 +1,94 @@ +"""Added service table [0701da9951a0]. + +Revision ID: 0701da9951a0 +Revises: 0.56.1 +Create Date: 2024-03-25 12:24:32.928543 + +""" + +import sqlalchemy as sa +import sqlmodel +from alembic import op +from sqlalchemy.engine.reflection import Inspector + +# revision identifiers, used by Alembic. +revision = "0701da9951a0" +down_revision = "0.56.1" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + """Upgrade database schema and/or data, creating a new revision.""" + # If the tables already exist, skip this migration. + conn = op.get_bind() + inspector = Inspector.from_engine(conn) + tables = inspector.get_table_names() + if "service" in tables: + return + + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "service", + sa.Column( + "workspace_id", sqlmodel.sql.sqltypes.GUID(), nullable=False + ), + sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=True), + sa.Column("service_source", sa.TEXT(), nullable=True), + sa.Column("service_type", sa.TEXT(), nullable=False), + sa.Column("type", sa.TEXT(), nullable=False), + sa.Column("flavor", sa.TEXT(), nullable=False), + sa.Column("admin_state", sa.TEXT(), nullable=True), + sa.Column("state", sa.TEXT(), nullable=True), + sa.Column("prediction_url", sa.TEXT(), nullable=True), + sa.Column("health_check_url", sa.TEXT(), nullable=True), + sa.Column("pipeline_name", sa.TEXT(), nullable=True), + sa.Column("pipeline_step_name", sa.TEXT(), nullable=True), + sa.Column( + "model_version_id", sqlmodel.sql.sqltypes.GUID(), nullable=True + ), + sa.Column( + "pipeline_run_id", sqlmodel.sql.sqltypes.GUID(), nullable=True + ), + sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False), + sa.Column("created", sa.DateTime(), nullable=False), + sa.Column("updated", sa.DateTime(), nullable=False), + sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False), + sa.Column("labels", sa.LargeBinary(), nullable=True), + sa.Column("config", sa.LargeBinary(), nullable=False), + sa.Column("status", sa.LargeBinary(), nullable=True), + sa.Column("endpoint", sa.LargeBinary(), nullable=True), + sa.ForeignKeyConstraint( + ["model_version_id"], + ["model_version.id"], + name="fk_service_model_version_id_model_version", + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["pipeline_run_id"], + ["pipeline_run.id"], + name="fk_service_pipeline_run_id_pipeline_run", + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + name="fk_service_user_id_user", + ondelete="SET NULL", + ), + sa.ForeignKeyConstraint( + ["workspace_id"], + ["workspace.id"], + name="fk_service_workspace_id_workspace", + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + """Downgrade database schema and/or data back to the previous revision.""" + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("service") + # ### end Alembic commands ### diff --git a/src/zenml/zen_stores/rest_zen_store.py b/src/zenml/zen_stores/rest_zen_store.py index 8a52010daf3..d03cec9f79b 100644 --- a/src/zenml/zen_stores/rest_zen_store.py +++ b/src/zenml/zen_stores/rest_zen_store.py @@ -80,6 +80,7 @@ SERVICE_CONNECTOR_TYPES, SERVICE_CONNECTOR_VERIFY, SERVICE_CONNECTORS, + SERVICES, STACK_COMPONENTS, STACKS, STEPS, @@ -189,6 +190,10 @@ ServiceConnectorResponse, ServiceConnectorTypeModel, ServiceConnectorUpdate, + ServiceFilter, + ServiceRequest, + ServiceResponse, + ServiceUpdate, StackFilter, StackRequest, StackResponse, @@ -590,6 +595,93 @@ def delete_api_key( route=f"{SERVICE_ACCOUNTS}/{str(service_account_id)}{API_KEYS}", ) + # ----------------------------- Services ----------------------------- + + def create_service( + self, service_request: ServiceRequest + ) -> ServiceResponse: + """Create a new service. + + Args: + service_request: The service to create. + + Returns: + The created service. + """ + return self._create_resource( + resource=service_request, + response_model=ServiceResponse, + route=SERVICES, + ) + + def get_service( + self, service_id: UUID, hydrate: bool = True + ) -> ServiceResponse: + """Get a service. + + Args: + service_id: The ID of the service to get. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + The service. + """ + return self._get_resource( + resource_id=service_id, + route=SERVICES, + response_model=ServiceResponse, + params={"hydrate": hydrate}, + ) + + def list_services( + self, filter_model: ServiceFilter, hydrate: bool = False + ) -> Page[ServiceResponse]: + """List all services matching the given filter criteria. + + Args: + filter_model: All filter parameters including pagination + params. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + A list of all services matching the filter criteria. + """ + return self._list_paginated_resources( + route=SERVICES, + response_model=ServiceResponse, + filter_model=filter_model, + params={"hydrate": hydrate}, + ) + + def update_service( + self, service_id: UUID, update: ServiceUpdate + ) -> ServiceResponse: + """Update a service. + + Args: + service_id: The ID of the service to update. + update: The update to be applied to the service. + + Returns: + The updated service. + """ + return self._update_resource( + resource_id=service_id, + resource_update=update, + response_model=ServiceResponse, + route=SERVICES, + ) + + def delete_service(self, service_id: UUID) -> None: + """Delete a service. + + Args: + service_id: The ID of the service to delete. + """ + self._delete_resource(resource_id=service_id, route=SERVICES) + # ----------------------------- Artifacts ----------------------------- def create_artifact(self, artifact: ArtifactRequest) -> ArtifactResponse: @@ -3816,6 +3908,7 @@ def _create_resource( The created resource. """ response_body = self.post(f"{route}", body=resource, params=params) + return response_model.parse_obj(response_body) def _create_workspace_scoped_resource( diff --git a/src/zenml/zen_stores/schemas/__init__.py b/src/zenml/zen_stores/schemas/__init__.py index 0ec208fff81..5957605c0c7 100644 --- a/src/zenml/zen_stores/schemas/__init__.py +++ b/src/zenml/zen_stores/schemas/__init__.py @@ -41,6 +41,7 @@ from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema from zenml.zen_stores.schemas.schedule_schema import ScheduleSchema from zenml.zen_stores.schemas.secret_schemas import SecretSchema +from zenml.zen_stores.schemas.service_schemas import ServiceSchema from zenml.zen_stores.schemas.service_connector_schemas import ( ServiceConnectorSchema, ) @@ -90,6 +91,7 @@ "ScheduleSchema", "SecretSchema", "ServiceConnectorSchema", + "ServiceSchema", "StackComponentSchema", "StackCompositionSchema", "StackSchema", diff --git a/src/zenml/zen_stores/schemas/artifact_schemas.py b/src/zenml/zen_stores/schemas/artifact_schemas.py index 40bb32d2a57..6aeebd556ca 100644 --- a/src/zenml/zen_stores/schemas/artifact_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_schemas.py @@ -171,7 +171,7 @@ class ArtifactVersionSchema(BaseSchema, table=True): # Fields version: str version_number: Optional[int] - type: str + type: ArtifactType uri: str = Field(sa_column=Column(TEXT, nullable=False)) materializer: str = Field(sa_column=Column(TEXT, nullable=False)) data_type: str = Field(sa_column=Column(TEXT, nullable=False)) @@ -277,7 +277,7 @@ def from_request( artifact_store_id=artifact_version_request.artifact_store_id, workspace_id=artifact_version_request.workspace, user_id=artifact_version_request.user, - type=artifact_version_request.type.value, + type=artifact_version_request.type, uri=artifact_version_request.uri, materializer=artifact_version_request.materializer.json(), data_type=artifact_version_request.data_type.json(), @@ -328,7 +328,7 @@ def to_model( version=self.version_number or self.version, user=self.user.to_model() if self.user else None, uri=self.uri, - type=ArtifactType(self.type), + type=self.type, materializer=materializer, data_type=data_type, created=self.created, diff --git a/src/zenml/zen_stores/schemas/artifact_visualization_schemas.py b/src/zenml/zen_stores/schemas/artifact_visualization_schemas.py index 79862fc0376..6447bf93d25 100644 --- a/src/zenml/zen_stores/schemas/artifact_visualization_schemas.py +++ b/src/zenml/zen_stores/schemas/artifact_visualization_schemas.py @@ -37,7 +37,7 @@ class ArtifactVisualizationSchema(BaseSchema, table=True): __tablename__ = "artifact_visualization" # Fields - type: str + type: VisualizationType uri: str = Field(sa_column=Column(TEXT, nullable=False)) # Foreign Keys @@ -71,7 +71,7 @@ def from_model( The `ArtifactVisualizationSchema`. """ return cls( - type=artifact_visualization_request.type.value, + type=artifact_visualization_request.type, uri=artifact_visualization_request.uri, artifact_version_id=artifact_version_id, ) @@ -95,7 +95,7 @@ def to_model( The `Visualization`. """ body = ArtifactVisualizationResponseBody( - type=VisualizationType(self.type), + type=self.type, uri=self.uri, created=self.created, updated=self.updated, diff --git a/src/zenml/zen_stores/schemas/component_schemas.py b/src/zenml/zen_stores/schemas/component_schemas.py index ca50ac29e37..f3b5b44ea72 100644 --- a/src/zenml/zen_stores/schemas/component_schemas.py +++ b/src/zenml/zen_stores/schemas/component_schemas.py @@ -49,7 +49,7 @@ class StackComponentSchema(NamedSchema, table=True): __tablename__ = "stack_component" - type: str + type: StackComponentType flavor: str configuration: bytes labels: Optional[bytes] @@ -127,8 +127,6 @@ def update( self.labels = base64.b64encode( json.dumps(component_update.labels).encode("utf-8") ) - elif field == "type": - self.type = component_update.type.value else: setattr(self, field, value) @@ -153,7 +151,7 @@ def to_model( A `ComponentModel` """ body = ComponentResponseBody( - type=StackComponentType(self.type), + type=self.type, flavor=self.flavor, user=self.user.to_model() if self.user else None, created=self.created, diff --git a/src/zenml/zen_stores/schemas/device_schemas.py b/src/zenml/zen_stores/schemas/device_schemas.py index abb9e6551c7..93ebc69556f 100644 --- a/src/zenml/zen_stores/schemas/device_schemas.py +++ b/src/zenml/zen_stores/schemas/device_schemas.py @@ -44,7 +44,7 @@ class OAuthDeviceSchema(BaseSchema, table=True): client_id: UUID user_code: str device_code: str - status: str + status: OAuthDeviceStatus failed_auth_attempts: int = 0 expires: Optional[datetime] = None last_login: Optional[datetime] = None @@ -121,7 +121,7 @@ def from_request( client_id=request.client_id, user_code=hashed_user_code, device_code=hashed_device_code, - status=OAuthDeviceStatus.PENDING.value, + status=OAuthDeviceStatus.PENDING, failed_auth_attempts=0, expires=now + timedelta(seconds=request.expires_in), os=request.os, @@ -153,9 +153,9 @@ def update(self, device_update: OAuthDeviceUpdate) -> "OAuthDeviceSchema": setattr(self, field, value) if device_update.locked is True: - self.status = OAuthDeviceStatus.LOCKED.value + self.status = OAuthDeviceStatus.LOCKED elif device_update.locked is False: - self.status = OAuthDeviceStatus.ACTIVE.value + self.status = OAuthDeviceStatus.ACTIVE self.updated = datetime.utcnow() return self @@ -233,7 +233,7 @@ def to_model( client_id=self.client_id, expires=self.expires, trusted_device=self.trusted_device, - status=OAuthDeviceStatus(self.status), + status=self.status, os=self.os, ip_address=self.ip_address, hostname=self.hostname, diff --git a/src/zenml/zen_stores/schemas/flavor_schemas.py b/src/zenml/zen_stores/schemas/flavor_schemas.py index 7ace6f97716..edb9c3d8b37 100644 --- a/src/zenml/zen_stores/schemas/flavor_schemas.py +++ b/src/zenml/zen_stores/schemas/flavor_schemas.py @@ -46,7 +46,7 @@ class FlavorSchema(NamedSchema, table=True): __tablename__ = "flavor" - type: str + type: StackComponentType source: str config_schema: str = Field(sa_column=Column(TEXT, nullable=False)) integration: Optional[str] = Field(default="") @@ -98,8 +98,6 @@ def update(self, flavor_update: "FlavorUpdate") -> "FlavorSchema": ).items(): if field == "config_schema": setattr(self, field, json.dumps(value)) - elif field == "type": - setattr(self, field, value.value) else: setattr(self, field, value) @@ -125,7 +123,7 @@ def to_model( """ body = FlavorResponseBody( user=self.user.to_model() if self.user else None, - type=StackComponentType(self.type), + type=self.type, integration=self.integration, logo_url=self.logo_url, created=self.created, diff --git a/src/zenml/zen_stores/schemas/model_schemas.py b/src/zenml/zen_stores/schemas/model_schemas.py index 52e0355b9eb..4658d094281 100644 --- a/src/zenml/zen_stores/schemas/model_schemas.py +++ b/src/zenml/zen_stores/schemas/model_schemas.py @@ -14,7 +14,7 @@ """SQLModel implementation of model tables.""" from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast from uuid import UUID from sqlalchemy import BOOLEAN, INTEGER, TEXT, Column @@ -38,6 +38,8 @@ ModelVersionResponse, ModelVersionResponseBody, ModelVersionResponseMetadata, + ModelVersionResponseResources, + Page, ) from zenml.zen_stores.schemas.artifact_schemas import ArtifactVersionSchema from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema @@ -46,8 +48,12 @@ from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field from zenml.zen_stores.schemas.tag_schemas import TagResourceSchema from zenml.zen_stores.schemas.user_schemas import UserSchema +from zenml.zen_stores.schemas.utils import get_page_from_list from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema +if TYPE_CHECKING: + from zenml.zen_stores.schemas import ServiceSchema + class ModelSchema(NamedSchema, table=True): """SQL Model for model.""" @@ -263,6 +269,10 @@ class ModelVersionSchema(NamedSchema, table=True): ), ) + services: List["ServiceSchema"] = Relationship( + back_populates="model_version", + ) + number: int = Field(sa_column=Column(INTEGER, nullable=False)) description: str = Field(sa_column=Column(TEXT, nullable=True)) stage: str = Field(sa_column=Column(TEXT, nullable=True)) @@ -315,6 +325,8 @@ def to_model( Returns: The created `ModelVersionResponse`. """ + from zenml.models import ServiceResponse + # Construct {name: {version: id}} dicts for all linked artifacts model_artifact_ids: Dict[str, Dict[str, UUID]] = {} deployment_artifact_ids: Dict[str, Dict[str, UUID]] = {} @@ -347,7 +359,6 @@ def to_model( pipeline_run_ids[pipeline_run.name] = pipeline_run.id metadata = None - if include_metadata: metadata = ModelVersionResponseMetadata( workspace=self.workspace.to_model(), @@ -358,6 +369,21 @@ def to_model( }, ) + resources = None + if include_resources: + services = cast( + Page[ServiceResponse], + get_page_from_list( + items_list=self.services, + response_model=ServiceResponse, + include_resources=include_resources, + include_metadata=include_metadata, + ), + ) + resources = ModelVersionResponseResources( + services=services, + ) + body = ModelVersionResponseBody( user=self.user.to_model() if self.user else None, created=self.created, @@ -377,6 +403,7 @@ def to_model( name=self.name, body=body, metadata=metadata, + resources=resources, ) def update( diff --git a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py index 966952d4416..c27cc34bb74 100644 --- a/src/zenml/zen_stores/schemas/pipeline_run_schemas.py +++ b/src/zenml/zen_stores/schemas/pipeline_run_schemas.py @@ -49,6 +49,7 @@ ModelVersionPipelineRunSchema, ) from zenml.zen_stores.schemas.run_metadata_schemas import RunMetadataSchema + from zenml.zen_stores.schemas.service_schemas import ServiceSchema from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema @@ -68,7 +69,7 @@ class PipelineRunSchema(NamedSchema, table=True): orchestrator_run_id: Optional[str] = Field(nullable=True) start_time: Optional[datetime] = Field(nullable=True) end_time: Optional[datetime] = Field(nullable=True, default=None) - status: str = Field(nullable=False) + status: ExecutionStatus = Field(nullable=False) orchestrator_environment: Optional[str] = Field( sa_column=Column(TEXT, nullable=True) ) @@ -182,6 +183,10 @@ class PipelineRunSchema(NamedSchema, table=True): pipeline: Optional["PipelineSchema"] = Relationship(back_populates="runs") trigger_execution: Optional["TriggerExecutionSchema"] = Relationship() + services: List["ServiceSchema"] = Relationship( + back_populates="pipeline_run", + ) + @classmethod def from_request( cls, request: "PipelineRunRequest" @@ -203,7 +208,7 @@ def from_request( orchestrator_run_id=request.orchestrator_run_id, orchestrator_environment=orchestrator_environment, start_time=request.start_time, - status=request.status.value, + status=request.status, pipeline_id=request.pipeline, deployment_id=request.deployment, trigger_execution_id=request.trigger_execution_id, @@ -277,7 +282,7 @@ def to_model( body = PipelineRunResponseBody( user=self.user.to_model() if self.user else None, - status=ExecutionStatus(self.status), + status=self.status, stack=stack, pipeline=pipeline, build=build, @@ -322,7 +327,7 @@ def update(self, run_update: "PipelineRunUpdate") -> "PipelineRunSchema": The updated `PipelineRunSchema`. """ if run_update.status: - self.status = run_update.status.value + self.status = run_update.status self.end_time = run_update.end_time self.updated = datetime.utcnow() @@ -367,7 +372,7 @@ def update_placeholder( self.orchestrator_run_id = request.orchestrator_run_id self.orchestrator_environment = orchestrator_environment - self.status = request.status.value + self.status = request.status self.updated = datetime.utcnow() diff --git a/src/zenml/zen_stores/schemas/run_metadata_schemas.py b/src/zenml/zen_stores/schemas/run_metadata_schemas.py index f84e210d97d..ade0bb1449a 100644 --- a/src/zenml/zen_stores/schemas/run_metadata_schemas.py +++ b/src/zenml/zen_stores/schemas/run_metadata_schemas.py @@ -109,7 +109,7 @@ class RunMetadataSchema(BaseSchema, table=True): key: str value: str = Field(sa_column=Column(TEXT, nullable=False)) - type: str + type: MetadataTypeEnum def to_model( self, @@ -134,7 +134,7 @@ def to_model( created=self.created, updated=self.updated, value=json.loads(self.value), - type=MetadataTypeEnum(self.type), + type=self.type, ) metadata = None if include_metadata: diff --git a/src/zenml/zen_stores/schemas/secret_schemas.py b/src/zenml/zen_stores/schemas/secret_schemas.py index 94059c6b102..468318c87c8 100644 --- a/src/zenml/zen_stores/schemas/secret_schemas.py +++ b/src/zenml/zen_stores/schemas/secret_schemas.py @@ -55,7 +55,7 @@ class SecretSchema(NamedSchema, table=True): __tablename__ = "secret" - scope: str + scope: SecretScope values: Optional[bytes] = Field(sa_column=Column(TEXT, nullable=True)) @@ -177,7 +177,7 @@ def from_request( assert secret.user is not None, "User must be set for secret creation." return cls( name=secret.name, - scope=secret.scope.value, + scope=secret.scope, workspace_id=secret.workspace, user_id=secret.user, # Don't store secret values implicitly in the secret. The @@ -204,10 +204,7 @@ def update( for field, value in secret_update.dict( exclude_unset=True, exclude={"workspace", "user", "values"} ).items(): - if field == "scope": - setattr(self, field, value.value) - else: - setattr(self, field, value) + setattr(self, field, value) self.updated = datetime.utcnow() return self @@ -242,7 +239,7 @@ def to_model( user=self.user.to_model() if self.user else None, created=self.created, updated=self.updated, - scope=SecretScope(self.scope), + scope=self.scope, ) return SecretResponse( id=self.id, diff --git a/src/zenml/zen_stores/schemas/service_schemas.py b/src/zenml/zen_stores/schemas/service_schemas.py new file mode 100644 index 00000000000..a38c0b68425 --- /dev/null +++ b/src/zenml/zen_stores/schemas/service_schemas.py @@ -0,0 +1,249 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""SQLModel implementation of service table.""" + +import base64 +import json +from datetime import datetime +from typing import Any, Optional +from uuid import UUID + +from sqlalchemy import TEXT, Column +from sqlmodel import Field, Relationship + +from zenml.models.v2.core.service import ( + ServiceRequest, + ServiceResponse, + ServiceResponseBody, + ServiceResponseMetadata, + ServiceResponseResources, + ServiceUpdate, +) +from zenml.utils.dict_utils import dict_to_bytes +from zenml.zen_stores.schemas.base_schemas import NamedSchema +from zenml.zen_stores.schemas.model_schemas import ModelVersionSchema +from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema +from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field +from zenml.zen_stores.schemas.user_schemas import UserSchema +from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema + + +class ServiceSchema(NamedSchema, table=True): + """SQL Model for service.""" + + __tablename__ = "service" + + workspace_id: UUID = build_foreign_key_field( + source=__tablename__, + target=WorkspaceSchema.__tablename__, + source_column="workspace_id", + target_column="id", + ondelete="CASCADE", + nullable=False, + ) + workspace: "WorkspaceSchema" = Relationship(back_populates="services") + + user_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target=UserSchema.__tablename__, + source_column="user_id", + target_column="id", + ondelete="SET NULL", + nullable=True, + ) + user: Optional["UserSchema"] = Relationship(back_populates="services") + service_source: Optional[str] = Field( + sa_column=Column(TEXT, nullable=True) + ) + service_type: str = Field(sa_column=Column(TEXT, nullable=False)) + type: str = Field(sa_column=Column(TEXT, nullable=False)) + flavor: str = Field(sa_column=Column(TEXT, nullable=False)) + admin_state: Optional[str] = Field(sa_column=Column(TEXT, nullable=True)) + state: Optional[str] = Field(sa_column=Column(TEXT, nullable=True)) + labels: Optional[bytes] + config: bytes + status: Optional[bytes] + endpoint: Optional[bytes] + prediction_url: Optional[str] = Field( + sa_column=Column(TEXT, nullable=True) + ) + health_check_url: Optional[str] = Field( + sa_column=Column(TEXT, nullable=True) + ) + pipeline_name: Optional[str] = Field(sa_column=Column(TEXT, nullable=True)) + pipeline_step_name: Optional[str] = Field( + sa_column=Column(TEXT, nullable=True) + ) + model_version_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target=ModelVersionSchema.__tablename__, + source_column="model_version_id", + target_column="id", + ondelete="SET NULL", + nullable=True, + ) + model_version: Optional["ModelVersionSchema"] = Relationship( + back_populates="services", + ) + pipeline_run_id: Optional[UUID] = build_foreign_key_field( + source=__tablename__, + target="pipeline_run", + source_column="pipeline_run_id", + target_column="id", + ondelete="SET NULL", + nullable=True, + ) + pipeline_run: Optional["PipelineRunSchema"] = Relationship( + back_populates="services", + ) + + def to_model( + self, + include_metadata: bool = False, + include_resources: bool = False, + **kwargs: Any, + ) -> ServiceResponse: + """Convert an `ServiceSchema` to an `ServiceResponse`. + + Args: + include_metadata: Whether to include metadata in the response. + include_resources: Whether to include resources in the response. + kwargs: Additional keyword arguments. + + Returns: + The created `ServiceResponse`. + """ + body = ServiceResponseBody( + user=self.user.to_model() if self.user else None, + workspace=self.workspace.to_model(), + created=self.created, + updated=self.updated, + service_type=json.loads(self.service_type), + labels=json.loads(base64.b64decode(self.labels).decode()) + if self.labels + else None, + state=self.state, + ) + metadata = None + if include_metadata: + metadata = ServiceResponseMetadata( + workspace=self.workspace.to_model(), + service_source=self.service_source, + config=json.loads(base64.b64decode(self.config).decode()), + status=json.loads(base64.b64decode(self.status).decode()) + if self.status + else None, + endpoint=json.loads(base64.b64decode(self.endpoint).decode()) + if self.endpoint + else None, + admin_state=self.admin_state or None, + prediction_url=self.prediction_url or None, + health_check_url=self.health_check_url, + ) + resources = None + if include_resources: + resources = ServiceResponseResources( + model_version=self.model_version.to_model() + if self.model_version + else None, + pipeline_run=self.pipeline_run.to_model() + if self.pipeline_run + else None, + ) + return ServiceResponse( + id=self.id, + name=self.name, + body=body, + metadata=metadata, + resources=resources, + ) + + def update( + self, + update: ServiceUpdate, + ) -> "ServiceSchema": + """Updates a `ServiceSchema` from a `ServiceUpdate`. + + Args: + update: The `ServiceUpdate` to update from. + + Returns: + The updated `ServiceSchema`. + """ + for field, value in update.dict( + exclude_unset=True, exclude_none=True + ).items(): + if field == "labels": + self.labels = ( + dict_to_bytes(update.labels) if update.labels else None + ) + elif field == "status": + self.status = ( + dict_to_bytes(update.status) if update.status else None + ) + self.state = ( + update.status.get("state") if update.status else None + ) + elif field == "endpoint": + self.endpoint = ( + dict_to_bytes(update.endpoint) if update.endpoint else None + ) + else: + setattr(self, field, value) + self.updated = datetime.utcnow() + return self + + @classmethod + def from_request( + cls, service_request: "ServiceRequest" + ) -> "ServiceSchema": + """Convert a `ServiceRequest` to a `ServiceSchema`. + + Args: + service_request: The request model to convert. + + Returns: + The converted schema. + """ + return cls( + name=service_request.name, + workspace_id=service_request.workspace, + user_id=service_request.user, + service_source=service_request.service_source, + service_type=service_request.service_type.json(), + type=service_request.service_type.type, + flavor=service_request.service_type.flavor, + admin_state=service_request.admin_state, + config=dict_to_bytes(service_request.config), + labels=dict_to_bytes(service_request.labels) + if service_request.labels + else None, + status=dict_to_bytes(service_request.status) + if service_request.status + else None, + endpoint=dict_to_bytes(service_request.endpoint) + if service_request.endpoint + else None, + state=service_request.status.get("state") + if service_request.status + else None, + model_version_id=service_request.model_version_id, + pipeline_run_id=service_request.pipeline_run_id, + prediction_url=service_request.prediction_url, + health_check_url=service_request.health_check_url, + pipeline_name=service_request.config.get("pipeline_name"), + pipeline_step_name=service_request.config.get( + "pipeline_step_name" + ), + ) diff --git a/src/zenml/zen_stores/schemas/step_run_schemas.py b/src/zenml/zen_stores/schemas/step_run_schemas.py index 4ae1d111f90..8ba628fc92a 100644 --- a/src/zenml/zen_stores/schemas/step_run_schemas.py +++ b/src/zenml/zen_stores/schemas/step_run_schemas.py @@ -27,6 +27,8 @@ from zenml.enums import ( ExecutionStatus, MetadataResourceTypes, + StepRunInputArtifactType, + StepRunOutputArtifactType, ) from zenml.models import ( StepRunRequest, @@ -58,7 +60,7 @@ class StepRunSchema(NamedSchema, table=True): # Fields start_time: Optional[datetime] = Field(nullable=True) end_time: Optional[datetime] = Field(nullable=True) - status: str = Field(nullable=False) + status: ExecutionStatus = Field(nullable=False) docstring: Optional[str] = Field(sa_column=Column(TEXT, nullable=True)) cache_key: Optional[str] = Field(nullable=True) @@ -163,7 +165,7 @@ def from_request(cls, request: StepRunRequest) -> "StepRunSchema": user_id=request.user, start_time=request.start_time, end_time=request.end_time, - status=request.status.value, + status=request.status, original_step_run_id=request.original_step_run_id, pipeline_run_id=request.pipeline_run_id, deployment_id=request.deployment, @@ -223,7 +225,7 @@ def to_model( body = StepRunResponseBody( user=self.user.to_model() if self.user else None, - status=ExecutionStatus(self.status), + status=self.status, inputs=input_artifacts, outputs=output_artifacts, created=self.created, @@ -268,7 +270,7 @@ def update(self, step_update: "StepRunUpdate") -> "StepRunSchema": exclude_unset=True, exclude_none=True ).items(): if key == "status": - self.status = value.value + self.status = value if key == "end_time": self.end_time = value @@ -310,7 +312,7 @@ class StepRunInputArtifactSchema(SQLModel, table=True): # Fields name: str = Field(nullable=False, primary_key=True) - type: str + type: StepRunInputArtifactType # Foreign keys step_id: UUID = build_foreign_key_field( @@ -346,7 +348,7 @@ class StepRunOutputArtifactSchema(SQLModel, table=True): # Fields name: str - type: str + type: StepRunOutputArtifactType # Foreign keys step_id: UUID = build_foreign_key_field( diff --git a/src/zenml/zen_stores/schemas/tag_schemas.py b/src/zenml/zen_stores/schemas/tag_schemas.py index 1cfbfc29c55..803a53805c5 100644 --- a/src/zenml/zen_stores/schemas/tag_schemas.py +++ b/src/zenml/zen_stores/schemas/tag_schemas.py @@ -108,11 +108,7 @@ def update(self, update: TagUpdate) -> "TagSchema": The updated `TagSchema`. """ for field, value in update.dict(exclude_unset=True).items(): - if field == "color": - setattr(self, field, value.value) - else: - setattr(self, field, value) - + setattr(self, field, value) self.updated = datetime.utcnow() return self diff --git a/src/zenml/zen_stores/schemas/user_schemas.py b/src/zenml/zen_stores/schemas/user_schemas.py index 610e45d4c18..72737ffa9b8 100644 --- a/src/zenml/zen_stores/schemas/user_schemas.py +++ b/src/zenml/zen_stores/schemas/user_schemas.py @@ -54,6 +54,7 @@ ScheduleSchema, SecretSchema, ServiceConnectorSchema, + ServiceSchema, StackComponentSchema, StackSchema, StepRunSchema, @@ -124,6 +125,7 @@ class UserSchema(NamedSchema, table=True): code_repositories: List["CodeRepositorySchema"] = Relationship( back_populates="user", ) + services: List["ServiceSchema"] = Relationship(back_populates="user") service_connectors: List["ServiceConnectorSchema"] = Relationship( back_populates="user", ) diff --git a/src/zenml/zen_stores/schemas/workspace_schemas.py b/src/zenml/zen_stores/schemas/workspace_schemas.py index aa9fd28f16c..3da451ac6c1 100644 --- a/src/zenml/zen_stores/schemas/workspace_schemas.py +++ b/src/zenml/zen_stores/schemas/workspace_schemas.py @@ -45,6 +45,7 @@ ScheduleSchema, SecretSchema, ServiceConnectorSchema, + ServiceSchema, StackComponentSchema, StackSchema, StepRunSchema, @@ -120,6 +121,10 @@ class WorkspaceSchema(NamedSchema, table=True): back_populates="workspace", sa_relationship_kwargs={"cascade": "delete"}, ) + services: List["ServiceSchema"] = Relationship( + back_populates="workspace", + sa_relationship_kwargs={"cascade": "delete"}, + ) service_connectors: List["ServiceConnectorSchema"] = Relationship( back_populates="workspace", sa_relationship_kwargs={"cascade": "delete"}, diff --git a/src/zenml/zen_stores/sql_zen_store.py b/src/zenml/zen_stores/sql_zen_store.py index 5feb10d93a2..7cf37dcad84 100644 --- a/src/zenml/zen_stores/sql_zen_store.py +++ b/src/zenml/zen_stores/sql_zen_store.py @@ -36,11 +36,11 @@ TypeVar, Union, cast, - get_origin, ) from uuid import UUID from pydantic import Field, SecretStr, root_validator, validator +from pydantic.json import pydantic_encoder from sqlalchemy import asc, desc, func from sqlalchemy.engine import URL, Engine, make_url from sqlalchemy.exc import ( @@ -48,7 +48,7 @@ IntegrityError, NoResultFound, ) -from sqlalchemy.orm import Mapped, noload +from sqlalchemy.orm import noload from sqlmodel import ( Session, SQLModel, @@ -209,6 +209,10 @@ ServiceConnectorResponse, ServiceConnectorTypeModel, ServiceConnectorUpdate, + ServiceFilter, + ServiceRequest, + ServiceResponse, + ServiceUpdate, StackFilter, StackRequest, StackResponse, @@ -299,6 +303,7 @@ ArtifactVisualizationSchema, ) from zenml.zen_stores.schemas.logs_schemas import LogsSchema +from zenml.zen_stores.schemas.service_schemas import ServiceSchema from zenml.zen_stores.schemas.trigger_schemas import TriggerSchema from zenml.zen_stores.secrets_stores.base_secrets_store import BaseSecretsStore from zenml.zen_stores.secrets_stores.sql_secrets_store import ( @@ -861,18 +866,23 @@ def filter_and_paginate( custom_fetch_result = custom_fetch(session, query, filter_model) total = len(custom_fetch_result) else: - total = ( - session.query(func.count()) - .select_from(query.options(noload("*")).subquery()) - .scalar() + total = session.scalar( + select([func.count("*")]).select_from( + query.options(noload("*")).subquery() + ) ) # Sorting column, operand = filter_model.sorting_params if operand == SorterOps.DESCENDING: - query = query.order_by(desc(getattr(table, column))) + sort_clause = desc(getattr(table, column)) else: - query = query.order_by(asc(getattr(table, column))) + sort_clause = asc(getattr(table, column)) + + # We always add the `id` column as a tiebreaker to ensure a stable, + # repeatable order of items, otherwise subsequent pages might contain + # the same items. + query = query.order_by(sort_clause, asc(table.id)) # Get the total amount of pages in the database for a given query if total == 0: @@ -1363,7 +1373,9 @@ def migrate_database(self) -> None: # identity table with needed info. logger.info("Creating database tables") with self.engine.begin() as conn: - SQLModel.metadata.create_all(conn) + conn.run_callable( + SQLModel.metadata.create_all # type: ignore[arg-type] + ) with Session(self.engine) as session: session.add( IdentitySchema( @@ -1757,6 +1769,175 @@ def delete_api_key( session.delete(api_key) session.commit() + # -------------------- Services -------------------- + + @staticmethod + def _fail_if_service_with_config_exists( + service_request: ServiceRequest, session: Session + ) -> None: + """Raise an exception if a service with same name/config exists. + + Args: + service_request: The service to check for. + session: The database session to use for the query. + + Raises: + EntityExistsError: If a service with the given name and + type already exists. + """ + # Check if service with the same domain key (name, config, workspace) + # already exists + + existing_domain_service = session.exec( + select(ServiceSchema).where( + ServiceSchema.config + == base64.b64encode( + json.dumps( + service_request.config, + sort_keys=False, + default=pydantic_encoder, + ).encode("utf-8") + ) + ) + ).first() + + if existing_domain_service: + raise EntityExistsError( + f"Unable to create service '{service_request.name}' with the given configuration: " + "A service with the same configuration already exists." + ) + + def create_service(self, service: ServiceRequest) -> ServiceResponse: + """Create a new service. + + Args: + service: The service to create. + + Returns: + The newly created service. + """ + with Session(self.engine) as session: + # Check if a service with the given name already exists + self._fail_if_service_with_config_exists( + service_request=service, + session=session, + ) + + # Create the service. + service_schema = ServiceSchema.from_request(service) + logger.debug("Creating service: %s", service_schema) + session.add(service_schema) + session.commit() + + return service_schema.to_model( + include_metadata=True, include_resources=True + ) + + def get_service( + self, service_id: UUID, hydrate: bool = True + ) -> ServiceResponse: + """Get a service. + + Args: + service_id: The ID of the service to get. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + The service. + + Raises: + KeyError: if the service doesn't exist. + """ + with Session(self.engine) as session: + service = session.exec( + select(ServiceSchema).where(ServiceSchema.id == service_id) + ).first() + if service is None: + raise KeyError( + f"Unable to get service with ID {service_id}: No " + "service with this ID found." + ) + return service.to_model( + include_metadata=hydrate, include_resources=hydrate + ) + + def list_services( + self, filter_model: ServiceFilter, hydrate: bool = False + ) -> Page[ServiceResponse]: + """List all services matching the given filter criteria. + + Args: + filter_model: All filter parameters including pagination + params. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + A list of all services matching the filter criteria. + """ + with Session(self.engine) as session: + query = select(ServiceSchema) + return self.filter_and_paginate( + session=session, + query=query, + table=ServiceSchema, + filter_model=filter_model, + hydrate=hydrate, + ) + + def update_service( + self, service_id: UUID, update: ServiceUpdate + ) -> ServiceResponse: + """Update a service. + + Args: + service_id: The ID of the service to update. + update: The update to be applied to the service. + + Returns: + The updated service. + + Raises: + KeyError: if the service doesn't exist. + """ + with Session(self.engine) as session: + existing_service = session.exec( + select(ServiceSchema).where(ServiceSchema.id == service_id) + ).first() + if not existing_service: + raise KeyError(f"Service with ID {service_id} not found.") + + # Update the schema itself. + existing_service.update(update=update) + logger.debug("Updated service: %s", existing_service) + session.add(existing_service) + session.commit() + session.refresh(existing_service) + return existing_service.to_model( + include_metadata=True, include_resources=True + ) + + def delete_service(self, service_id: UUID) -> None: + """Delete a service. + + Args: + service_id: The ID of the service to delete. + + Raises: + KeyError: if the service doesn't exist. + """ + with Session(self.engine) as session: + existing_service = session.exec( + select(ServiceSchema).where(ServiceSchema.id == service_id) + ).first() + if not existing_service: + raise KeyError(f"Service with ID {service_id} not found.") + + # Delete the service + session.delete(existing_service) + session.commit() + # -------------------- Artifacts -------------------- def create_artifact(self, artifact: ArtifactRequest) -> ArtifactResponse: @@ -2566,9 +2747,7 @@ def update_stack_component( if existing_component.name != component_update.name: self._fail_if_component_with_name_type_exists( name=component_update.name, - component_type=StackComponentType( - existing_component.type - ), + component_type=existing_component.type, workspace_id=existing_component.workspace_id, session=session, ) @@ -3320,6 +3499,7 @@ def _custom_fetch( PipelineRunSchema.created == max_date_subquery.c.max_created, ) + .order_by(desc(PipelineRunSchema.updated)) ) return self.filter_and_paginate( @@ -6865,9 +7045,7 @@ def _update_pipeline_run_status( assert pipeline_run.deployment num_steps = len(pipeline_run.deployment.to_model().step_configurations) new_status = get_pipeline_run_status( - step_statuses=[ - ExecutionStatus(step_run.status) for step_run in step_runs - ], + step_statuses=[step_run.status for step_run in step_runs], num_steps=num_steps, ) @@ -7277,8 +7455,6 @@ def _get_resource_references( for resource_attr in resource_attrs: # Extract the target schema from the annotation annotation = UserSchema.__annotations__[resource_attr] - if get_origin(annotation) == Mapped: - annotation = annotation.__args__[0] # The annotation must be of the form # `typing.List[ForwardRef('')]` @@ -7336,13 +7512,11 @@ def _account_owns_resources( resource_attrs = self._get_resource_references() for schema, resource_attr in resource_attrs: # Check if the user owns any resources of this type - count = ( - session.query(func.count()) + count = session.scalar( + select([func.count("*")]) .select_from(schema) .where(getattr(schema, resource_attr) == account.id) - .scalar() ) - if count > 0: logger.debug( f"User {account.name} owns {count} resources of type " @@ -8421,7 +8595,9 @@ def get_model_version( f"`{model_version_id}`: No model version with this " f"ID found." ) - return model_version.to_model(include_metadata=hydrate) + return model_version.to_model( + include_metadata=hydrate, include_resources=hydrate + ) def list_model_versions( self, @@ -8570,7 +8746,9 @@ def update_model_version( session.commit() session.refresh(existing_model_version) - return existing_model_version.to_model(include_metadata=True) + return existing_model_version.to_model( + include_metadata=True, include_resources=True + ) # ------------------------ Model Versions Artifacts ------------------------ diff --git a/src/zenml/zen_stores/zen_store_interface.py b/src/zenml/zen_stores/zen_store_interface.py index 7914a5681bd..7163936d506 100644 --- a/src/zenml/zen_stores/zen_store_interface.py +++ b/src/zenml/zen_stores/zen_store_interface.py @@ -104,6 +104,10 @@ ServiceConnectorResponse, ServiceConnectorTypeModel, ServiceConnectorUpdate, + ServiceFilter, + ServiceRequest, + ServiceResponse, + ServiceUpdate, StackFilter, StackRequest, StackResponse, @@ -359,6 +363,87 @@ def delete_api_key( for the given service account. """ + # -------------------- Services -------------------- + + @abstractmethod + def create_service( + self, + service: ServiceRequest, + ) -> ServiceResponse: + """Create a new service. + + Args: + service: The service to create. + + Returns: + The newly created service. + + Raises: + EntityExistsError: If a service with the same name already exists. + """ + + @abstractmethod + def get_service( + self, service_id: UUID, hydrate: bool = True + ) -> ServiceResponse: + """Get a service by ID. + + Args: + service_id: The ID of the service to get. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + The service. + + Raises: + KeyError: if the service doesn't exist. + """ + + @abstractmethod + def list_services( + self, filter_model: ServiceFilter, hydrate: bool = False + ) -> Page[ServiceResponse]: + """List all services matching the given filter criteria. + + Args: + filter_model: All filter parameters including pagination + params. + hydrate: Flag deciding whether to hydrate the output model(s) + by including metadata fields in the response. + + Returns: + A list of all services matching the filter criteria. + """ + + @abstractmethod + def update_service( + self, service_id: UUID, update: ServiceUpdate + ) -> ServiceResponse: + """Update an existing service. + + Args: + service_id: The ID of the service to update. + update: The update to be applied to the service. + + Returns: + The updated service. + + Raises: + KeyError: if the service doesn't exist. + """ + + @abstractmethod + def delete_service(self, service_id: UUID) -> None: + """Delete a service. + + Args: + service_id: The ID of the service to delete. + + Raises: + KeyError: if the service doesn't exist. + """ + # -------------------- Artifacts -------------------- @abstractmethod diff --git a/tests/integration/examples/bentoml/steps/prediction_service_loader.py b/tests/integration/examples/bentoml/steps/prediction_service_loader.py index 3871fe1e8ab..1fa9e669a9c 100644 --- a/tests/integration/examples/bentoml/steps/prediction_service_loader.py +++ b/tests/integration/examples/bentoml/steps/prediction_service_loader.py @@ -29,7 +29,7 @@ def bentoml_prediction_service_loader( """Get the BentoML prediction service started by the deployment pipeline. Args: - pipeline_name: name of the pipeline that deployed the model. + pipeline_name: name of the pipeline_name that deployed the model. step_name: the name of the step that deployed the model. model_name: the name of the model that was deployed. """ diff --git a/tests/integration/examples/huggingface/steps/prediction_service_loader/prediction_service_loader.py b/tests/integration/examples/huggingface/steps/prediction_service_loader/prediction_service_loader.py index 49a763bdca1..91ea669c9cc 100644 --- a/tests/integration/examples/huggingface/steps/prediction_service_loader/prediction_service_loader.py +++ b/tests/integration/examples/huggingface/steps/prediction_service_loader/prediction_service_loader.py @@ -43,19 +43,16 @@ def prediction_service_loader( # get the Huggingface model deployer stack component model_deployer = HuggingFaceModelDeployer.get_active_model_deployer() - # fetch existing services with same pipeline name, step name and model name - services = model_deployer.find_model_server( + if services := model_deployer.find_model_server( pipeline_name=pipeline_name, pipeline_step_name=pipeline_step_name, model_name=model_name, running=running, - ) - - if not services: + ): + return cast(HuggingFaceDeploymentService, services[0]) + else: raise RuntimeError( f"No Huggingface inference endpoint deployed by step " f"'{pipeline_step_name}' in pipeline '{pipeline_name}' with name " f"'{model_name}' is currently running." ) - - return cast(HuggingFaceDeploymentService, services[0]) diff --git a/tests/integration/examples/mlflow/pipelines/deployment_pipelines/deployment_inference_pipeline.py b/tests/integration/examples/mlflow/pipelines/deployment_pipelines/deployment_inference_pipeline.py index fad0bd06331..29bb1b57887 100644 --- a/tests/integration/examples/mlflow/pipelines/deployment_pipelines/deployment_inference_pipeline.py +++ b/tests/integration/examples/mlflow/pipelines/deployment_pipelines/deployment_inference_pipeline.py @@ -36,6 +36,5 @@ def mlflow_deployment_inference_pipeline( model_deployment_service = prediction_service_loader( pipeline_name=pipeline_name, pipeline_step_name=pipeline_step_name, - running=False, ) predictor(model_deployment_service, inference_data) diff --git a/tests/integration/examples/mlflow/steps/prediction_service_loader_step.py b/tests/integration/examples/mlflow/steps/prediction_service_loader_step.py index 36067d0dfbb..4e4e8427b37 100644 --- a/tests/integration/examples/mlflow/steps/prediction_service_loader_step.py +++ b/tests/integration/examples/mlflow/steps/prediction_service_loader_step.py @@ -24,7 +24,6 @@ def prediction_service_loader( pipeline_name: str, pipeline_step_name: str, running: bool = True, - model_name: str = "model", ) -> MLFlowDeploymentService: """Get the prediction service started by the deployment pipeline. @@ -40,19 +39,13 @@ def prediction_service_loader( model_deployer = MLFlowModelDeployer.get_active_model_deployer() # fetch existing services with same pipeline name, step name and model name - existing_services = model_deployer.find_model_server( - pipeline_name=pipeline_name, - pipeline_step_name=pipeline_step_name, - model_name=model_name, - running=running, - ) + existing_services = model_deployer.find_model_server() if not existing_services: raise RuntimeError( f"No MLflow prediction service deployed by the " f"{pipeline_step_name} step in the {pipeline_name} " - f"pipeline for the '{model_name}' model is currently " - f"running." + f"pipeline" ) return existing_services[0] diff --git a/tests/integration/functional/zen_server/test_zen_server.py b/tests/integration/functional/zen_server/test_zen_server.py index 93290aa22d8..322635c8e8b 100644 --- a/tests/integration/functional/zen_server/test_zen_server.py +++ b/tests/integration/functional/zen_server/test_zen_server.py @@ -17,8 +17,13 @@ import pytest import requests +from zenml.client import Client +from zenml.constants import DEFAULT_USERNAME +from zenml.enums import StoreType from zenml.utils.networking_utils import scan_for_available_port from zenml.zen_server.deploy import ServerDeployer, ServerDeploymentConfig +from zenml.zen_server.utils import server_config +from zenml.zen_stores.rest_zen_store import RestZenStore SERVER_START_STOP_TIMEOUT = 60 @@ -73,3 +78,17 @@ def test_server_up_down(clean_client, mocker): print(line) raise assert deployer.list_servers() == [] + + +def test_rate_limit_is_not_impacted_by_successful_requests(): + zen_store = Client().zen_store + if zen_store.type == StoreType.SQL: + pytest.skip("SQL ZenStore does not support rate limiting.") + + assert Client().active_user.name == DEFAULT_USERNAME + zen_store: RestZenStore = zen_store + + repeat = server_config().login_rate_limit_minute * 2 + for _ in range(repeat): + zen_store.clear_session() + zen_store.session diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 930d8a48929..df139355033 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -69,10 +69,17 @@ WorkspaceResponseBody, WorkspaceResponseMetadata, ) +from zenml.models.v2.core.service import ( + ServiceResponse, + ServiceResponseBody, + ServiceResponseMetadata, +) from zenml.new.pipelines.pipeline import Pipeline from zenml.orchestrators.base_orchestrator import BaseOrchestratorConfig from zenml.orchestrators.local.local_orchestrator import LocalOrchestrator from zenml.pipelines import pipeline +from zenml.services.service_status import ServiceState +from zenml.services.service_type import ServiceType from zenml.stack.stack import Stack from zenml.stack.stack_component import ( StackComponentConfig, @@ -693,3 +700,64 @@ def sample_hub_plugin_response_model() -> HubPluginResponseModel: updated=datetime.now(), requirements=["ploogin==0.0.1", "zenml>=0.1.0"], ) + + +# Test data +service_id = "12345678-1234-5678-1234-567812345678" +service_name = "test_service" +service_type = ServiceType( + type="model-serving", flavor="test_flavor", name="test_name" +) +service_source = "tests.unit.services.test_service.TestService" +admin_state = ServiceState.ACTIVE +config = { + "type": "zenml.services.service.ServiceConfig", + "name": "test_service", + "description": "", + "pipeline_name": "", + "pipeline_step_name": "", + "model_name": "", + "model_version": "", + "service_name": "zenml-test_service", +} +labels = {"label1": "value1", "label2": "value2"} +status = { + "type": "zenml.services.service_status.ServiceStatus", + "state": ServiceState.ACTIVE, + "last_state": ServiceState.INACTIVE, + "last_error": "", +} +endpoint = None +prediction_url = "http://example.com/predict" +health_check_url = "http://example.com/health" +created_time = datetime(2024, 3, 14, 10, 30) +updated_time = datetime(2024, 3, 14, 11, 45) + + +@pytest.fixture +def service_response( + sample_workspace_model, +): + body = ServiceResponseBody( + service_type=service_type, + labels=labels, + created=created_time, + updated=updated_time, + state=admin_state, + ) + metadata = ServiceResponseMetadata( + service_source=service_source, + admin_state=admin_state, + config=config, + status=status, + endpoint=endpoint, + prediction_url=prediction_url, + health_check_url=health_check_url, + workspace=sample_workspace_model, + ) + return ServiceResponse( + id=service_id, + name=service_name, + body=body, + metadata=metadata, + ) diff --git a/tests/unit/models/test_service_models.py b/tests/unit/models/test_service_models.py new file mode 100644 index 00000000000..2148d576222 --- /dev/null +++ b/tests/unit/models/test_service_models.py @@ -0,0 +1,130 @@ +# Copyright (c) ZenML GmbH 2023. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from datetime import datetime + +import pytest + +from zenml.constants import STR_FIELD_MAX_LENGTH +from zenml.models import ( + ServiceRequest, + ServiceResponse, + ServiceResponseBody, + ServiceResponseMetadata, +) +from zenml.services.service_status import ServiceState +from zenml.services.service_type import ServiceType + +# Test data +service_id = "12345678-1234-5678-1234-567812345678" +service_name = "test_service" +service_type = ServiceType( + type="model-serving", flavor="test_flavor", name="test_name" +) +service_source = "tests.unit.services.test_service.TestService" +admin_state = ServiceState.ACTIVE +config = { + "type": "zenml.services.service.ServiceConfig", + "name": "test_service", + "description": "", + "pipeline_name": "", + "pipeline_step_name": "", + "model_name": "", + "model_version": "", + "service_name": "zenml-test_service", +} +labels = {"label1": "value1", "label2": "value2"} +status = { + "type": "zenml.services.service_status.ServiceStatus", + "state": ServiceState.ACTIVE, + "last_state": ServiceState.INACTIVE, + "last_error": "", +} +endpoint = None +prediction_url = "http://example.com/predict" +health_check_url = "http://example.com/health" +created_time = datetime(2023, 3, 14, 10, 30) +updated_time = datetime(2023, 3, 14, 11, 45) + + +@pytest.fixture +def service_response( + sample_workspace_model, +): + body = ServiceResponseBody( + service_type=service_type, + labels=labels, + created=created_time, + updated=updated_time, + state=admin_state, + ) + metadata = ServiceResponseMetadata( + service_source=service_source, + admin_state=admin_state, + config=config, + status=status, + endpoint=endpoint, + prediction_url=prediction_url, + health_check_url=health_check_url, + workspace=sample_workspace_model, + ) + return ServiceResponse( + id=service_id, + name=service_name, + body=body, + metadata=metadata, + ) + + +def test_service_response_properties(service_response): + assert service_response.service_type == service_type + assert service_response.labels == labels + assert service_response.service_source == service_source + assert service_response.config == config + assert service_response.status == status + assert service_response.endpoint == endpoint + assert service_response.created == created_time + assert service_response.updated == updated_time + assert service_response.admin_state == admin_state + assert service_response.prediction_url == prediction_url + assert service_response.health_check_url == health_check_url + assert service_response.state == admin_state + + +def test_service_request_name_too_long(): + # Test that the service name cannot be longer than the maximum allowed length + long_name = "a" * (STR_FIELD_MAX_LENGTH + 1) + with pytest.raises(ValueError): + ServiceRequest( + name=long_name, + service_type=ServiceType( + type="model-serving", flavor="test_flavor", name="test_name" + ), + service_source="path.to.ServiceClass", + admin_state=ServiceState.ACTIVE, + config={"param1": "value1"}, + ) + + +def test_service_request_invalid_service_type(): + # Test that an invalid service type raises an error + invalid_service_type = "invalid_type" + with pytest.raises(ValueError): + ServiceRequest( + name="test_service", + service_type=invalid_service_type, + service_source="path.to.ServiceClass", + admin_state=ServiceState.ACTIVE, + config={"param1": "value1"}, + ) diff --git a/tests/unit/services/__init__.py b/tests/unit/services/__init__.py new file mode 100644 index 00000000000..cd90a82cfc2 --- /dev/null +++ b/tests/unit/services/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. diff --git a/tests/unit/services/test_service.py b/tests/unit/services/test_service.py new file mode 100644 index 00000000000..b0875e9d62d --- /dev/null +++ b/tests/unit/services/test_service.py @@ -0,0 +1,112 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +from typing import Generator, Optional, Tuple +from uuid import UUID + +import pytest + +from zenml.services import ( + BaseService, + ServiceConfig, + ServiceState, + ServiceStatus, +) +from zenml.services.service import ZENM_ENDPOINT_PREFIX + + +# Create a concrete subclass of BaseService +class TestService(BaseService): + """Test service class for testing BaseService.""" + + SERVICE_TYPE = { + "type": "model-serving", + "flavor": "test_flavor", + "name": "test_name", + } + + @property + def is_running(self): + return True + + @property + def is_stopped(self): + return not self.is_running + + @property + def is_failed(self): + return False + + def check_status(self) -> Tuple[ServiceState, str]: + return ServiceState.ACTIVE, "Service is running" + + def get_logs( + self, follow: bool = False, tail: Optional[int] = None + ) -> Generator[str, bool, None]: + return (f"log line {i}" for i in range(5)) + + +# Modify the base_service fixture to use the TestService subclass +@pytest.fixture +def base_service(): + return TestService( + uuid=UUID("12345678-1234-5678-1234-567812345678"), + admin_state=ServiceState.ACTIVE, + config=ServiceConfig(name="test_service", param1="value1", param2=2), + status=ServiceStatus( + state=ServiceState.ACTIVE, + last_error="", + last_status=ServiceState.INACTIVE, + ), + endpoint=None, + ) + + +# Update the test_from_model to handle the case when service_source is missing +def test_from_model(service_response): + service = BaseService.from_model(service_response) + assert isinstance(service, TestService) + assert service.uuid == service_response.id + assert service.admin_state == service_response.admin_state + assert service.config == service_response.config + assert service.status == service_response.status + assert service.SERVICE_TYPE["type"] == service_response.service_type.type + assert ( + service.SERVICE_TYPE["flavor"] == service_response.service_type.flavor + ) + assert service.endpoint == service_response.endpoint + + +def test_update_status(base_service, monkeypatch): + def mock_check_status(self): + return ServiceState.ACTIVE, "Service is running" + + monkeypatch.setattr(BaseService, "check_status", mock_check_status) + base_service.update_status() + + assert base_service.status.state == ServiceState.ACTIVE + assert base_service.status.last_error == "Service is running" + + +def test_service_config_init_without_name_or_model_name(): + """Test initialization without name or model_name.""" + with pytest.raises(ValueError) as excinfo: + ServiceConfig() + assert "Either 'name' or 'model_name' must be set." in str(excinfo.value) + + +def test_service_config_init_with_name(): + """Test initialization with name.""" + config = ServiceConfig(name="test-service") + assert config.name == "test-service" + assert config.service_name == f"{ZENM_ENDPOINT_PREFIX}test-service" diff --git a/tests/unit/test_constants.py b/tests/unit/test_constants.py index 1a5e76faa3b..78ab52076fa 100644 --- a/tests/unit/test_constants.py +++ b/tests/unit/test_constants.py @@ -12,19 +12,50 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. -import os -from zenml.constants import handle_int_env_var +from zenml.constants import handle_int_env_var, handle_json_env_var -def test_handle_int_env_var(): +def test_handle_int_env_var(monkeypatch): """Check handle_int_env_var in all cases.""" env_var = "ZENML_TEST_HANDLE_INT_ENV_VAR" # check value error (when it can't be converted to int) - os.environ[env_var] = "test" + monkeypatch.setenv(env_var, "test") assert 0 == handle_int_env_var(env_var, 0) # check if it isn't there (in case it doesn't exist) - del os.environ[env_var] + monkeypatch.delenv(env_var, raising=False) assert 0 == handle_int_env_var(env_var, 0) + + +def test_handle_json_env_var(monkeypatch): + # Given an environment variable that is json + monkeypatch.setenv("TEST_VAR", '["hello", "world"]') + + # When we ask for that variable and expect it to be a List + result = handle_json_env_var("TEST_VAR", expected_type=list) + + # Then we should get the list ["hello", "world"] + assert result == ["hello", "world"] + + # Given an environment variable that is not json + monkeypatch.setenv("TEST_VAR", "hello world") + + # When we ask for that variable and expect it to be a List + result = handle_json_env_var("TEST_VAR", expected_type=list) + + # Then we should get an empty list (the default) + assert result == [] + + # Given an environment variable that is json but not the expected type + monkeypatch.setenv("TEST_VAR", '{"hello": "world"}') + + # When we ask for that variable and expect it to be a List + result = handle_json_env_var("TEST_VAR", expected_type=list) + + # Then we should get an empty list (the default) + assert result == [] + + # Unset environment variable + monkeypatch.delenv("TEST_VAR", raising=False)