diff --git a/.github/actions/image-matrix-prep/action.yaml b/.github/actions/image-matrix-prep/action.yaml new file mode 100644 index 000000000000..247e4351875f --- /dev/null +++ b/.github/actions/image-matrix-prep/action.yaml @@ -0,0 +1,38 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +name: Image Matrix Prep +description: Prepares the matrix of images to build +inputs: + skip_images: + description: 'Comma separated list of images to skip' + required: false + default: '' +outputs: + matrix: + description: 'The matrix of images to build' + value: ${{ steps.set-matrix.outputs.matrix }} +runs: + using: "composite" + steps: + - uses: actions/checkout@v3 + - id: set-matrix + run: | + skipImages=",$INPUT_SKIP_IMAGES," + matrix=$(jq --arg skipImages "$skipImages" 'map(. 
| select(",\(."image-name")," | inside($skipImages)|not))' ./.github/workflows/build-workflow-matrix.json) + echo "matrix={\"include\":$(echo $matrix)}" >> $GITHUB_OUTPUT + shell: bash + env: + INPUT_SKIP_IMAGES: ${{ inputs.skip_images }} diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 864c3f659a0c..695d495858b8 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -14,6 +14,7 @@ # # This name is referenced in the release.yaml workflow, if you're changing here - change there name: Build +run-name: Building ${{ inputs.version }} ${{ github.ref_name }} on: push: @@ -43,6 +44,7 @@ on: description: 'Whether to build images from cache or not. Default: true, set to false only if required because that will cause a significant increase in build time' required: true default: 'true' + jobs: matrix_prep: runs-on: ubuntu-latest @@ -51,12 +53,10 @@ jobs: steps: - uses: actions/checkout@v3 - id: set-matrix - run: | - skipImages=",$INPUT_SKIP_IMAGES," - matrix=$(jq --arg skipImages "$skipImages" 'map(. 
| select(",\(."image-name")," | inside($skipImages)|not))' ./.github/workflows/build-workflow-matrix.json) - echo "matrix={\"include\":$(echo $matrix)}" >> $GITHUB_OUTPUT - env: - INPUT_SKIP_IMAGES: ${{ github.event.inputs.skip_images }} + uses: ./.github/actions/image-matrix-prep + with: + skip_images: ${{ github.event.inputs.skip_images }} + build-images: name: Build and push image - ${{ matrix.image-name }} (Python ${{ matrix.python-version }}) runs-on: ubuntu-latest diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index af89c68a2ce2..640fc1d29f63 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -19,6 +19,7 @@ on: branches: - development - '[0-9]+.[0-9]+.x' + - 'feature/**' # Run CI also on push to backport release branches - we sometimes push code there by cherry-picking, meaning it # doesn't go through CI (no PR) @@ -29,6 +30,9 @@ on: - master - '[0-9]+.[0-9]+.x' +env: + NAMESPACE: mlrun-integ-test + jobs: lint: name: Lint code (Python ${{ matrix.python-version }}) @@ -127,6 +131,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + # since github-actions gives us 14G only, and fills it up with some garbage + - name: Freeing up disk space + run: | + "${GITHUB_WORKSPACE}/automation/scripts/github_workflow_free_space.sh" - uses: manusa/actions-setup-minikube@v2.7.2 with: minikube version: "v1.28.0" @@ -140,6 +148,19 @@ jobs: - name: Run GO tests run: | make test-go-integration-dockerized + - name: Output some logs in case of failure + if: ${{ failure() }} + # add set -x to print commands before executing to make logs reading easier + run: | + set -x + minikube ip + minikube logs + minikube kubectl -- --namespace ${NAMESPACE} get events + minikube kubectl -- --namespace ${NAMESPACE} logs -l app.kubernetes.io/component=api,app.kubernetes.io/name=mlrun --tail=-1 + minikube kubectl -- --namespace ${NAMESPACE} get pods + minikube kubectl -- --namespace ${NAMESPACE} get pods -o yaml + minikube kubectl -- 
--namespace ${NAMESPACE} describe pods + set +x migrations-tests: name: Run Dockerized Migrations Tests diff --git a/.github/workflows/periodic-rebuild.yaml b/.github/workflows/periodic-rebuild.yaml new file mode 100644 index 000000000000..300bf1d2f726 --- /dev/null +++ b/.github/workflows/periodic-rebuild.yaml @@ -0,0 +1,53 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +name: Scheduled Re-Build Images + +on: + schedule: + # every night at 2am + - cron: "0 2 * * *" + +jobs: + re-build-images: + # let's not run this on every fork, change to your fork when developing + if: github.repository == 'mlrun/mlrun' || github.event_name == 'workflow_dispatch' + strategy: + fail-fast: false + matrix: + repo: ["mlrun","ui"] + branch: ["development","1.3.x"] + runs-on: ubuntu-latest + steps: + - name: Re-Build MLRun Image + if: matrix.repo == 'mlrun' + uses: convictional/trigger-workflow-and-wait@v1.6.5 + with: + owner: mlrun + repo: mlrun + github_token: ${{ secrets.RELEASE_GITHUB_ACCESS_TOKEN }} + workflow_file_name: build.yaml + ref: ${{ matrix.branch }} + wait_interval: 60 + client_payload: '{"skip_images": "models-gpu,models,base,tests", "build_from_cache": "false"}' + - name: Re-Build UI Image + if: matrix.repo == 'ui' + uses: convictional/trigger-workflow-and-wait@v1.6.5 + with: + owner: mlrun + repo: ui + github_token: ${{ secrets.RELEASE_GITHUB_ACCESS_TOKEN }} + workflow_file_name: build.yaml + ref: ${{ matrix.branch }} + wait_interval: 60 
diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index eb083528e49a..5b84b1b04507 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # + name: Release +run-name: Releasing ${{ inputs.version }} on: workflow_dispatch: @@ -157,3 +159,21 @@ jobs: allowUpdates: true prerelease: ${{ github.event.inputs.pre_release }} body: ${{ steps.resolve-release-notes.outputs.body }} + + + update-tutorials: + name: Bundle tutorials + needs: create-releases + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Create tutorials tar + run: | + tar -cvf mlrun-tutorials.tar docs/tutorial + - name: Add tutorials tar to release + uses: ncipollo/release-action@v1 + with: + allowUpdates: true + tag: v${{ github.event.inputs.version }} + token: ${{ secrets.RELEASE_GITHUB_ACCESS_TOKEN }} + artifacts: mlrun-tutorials.tar diff --git a/.github/workflows/security_scan.yaml b/.github/workflows/security_scan.yaml new file mode 100644 index 000000000000..25a65cd5844e --- /dev/null +++ b/.github/workflows/security_scan.yaml @@ -0,0 +1,153 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# Currently supported running against prebuilt images +name: Security Scan +run-name: Scanning ${{ inputs.tag }} + +on: + workflow_dispatch: + inputs: + tag: + description: 'MLRun image tag to scan (unstable-cache, unstable-cache-13x, 1.3.0-wwwwwwww, 1.3.0 or any other tag)' + required: false + default: 'unstable' + registry: + description: 'MLRun image registry' + required: false + default: 'ghcr.io/' + skip_images: + description: 'Comma separated list of images to skip scanning' + required: false + + # disabling gpu images for now as scanning them takes more disk space than we can afford + # test images are not scanned as they are not used in production + default: 'test,models-gpu' + publish_results: + description: 'Whether to publish results to Github or not (default empty - no publish)' + required: false + default: '' + severity_threshold: + description: 'The minimum severity of vulnerabilities to report ("negligible", "low", "medium", "high" and "critical".)' + required: false + default: 'medium' + only_fixed: + description: 'Whether to scan only fixed vulnerabilities ("true" or "false")' + required: false + default: 'true' + +jobs: + matrix_prep: + name: Prepare image list + outputs: + matrix: ${{ steps.set-matrix.outputs.matrix }} + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - id: set-matrix + uses: ./.github/actions/image-matrix-prep + with: + skip_images: ${{ github.event.inputs.skip_images }} + + build_and_scan_docker_images: + name: Scan ${{ matrix.image-name }} (Python ${{ matrix.python-version }}) + runs-on: ubuntu-latest + needs: matrix_prep + strategy: + fail-fast: false + matrix: ${{ fromJson(needs.matrix_prep.outputs.matrix) }} + steps: + - uses: actions/checkout@v3 + + - name: Cleanup disk + run: | + "${GITHUB_WORKSPACE}/automation/scripts/github_workflow_free_space.sh" + + - name: Resolving image name + id: resolve_image_name + run: | + echo "image_name=$(make pull-${{ matrix.image-name }} | tail -1)" >> $GITHUB_OUTPUT 
+ env: + MLRUN_DOCKER_REGISTRY: ${{ github.event.inputs.registry }} + MLRUN_VERSION: ${{ github.event.inputs.tag }} + MLRUN_PYTHON_VERSION: ${{ matrix.python-version }} + + - name: Define output format + id: output-format + + # this section is duplicated in the other jobs. + # make sure to update all when changed. + run: | + if [[ -n "${{ github.event.inputs.publish_results }}" ]]; then \ + echo "format=sarif" >> $GITHUB_OUTPUT; \ + echo "fail_build=false" >> $GITHUB_OUTPUT; \ + else \ + echo "format=table" >> $GITHUB_OUTPUT; \ + echo "fail_build=true" >> $GITHUB_OUTPUT; \ + fi + + - name: Scan image + uses: anchore/scan-action@v3 + id: scan + with: + image: ${{ steps.resolve_image_name.outputs.image_name }} + only-fixed: ${{ github.event.inputs.only_fixed }} + output-format: ${{ steps.output-format.outputs.format }} + fail-build: ${{ steps.output-format.outputs.fail_build }} + severity-cutoff: ${{ github.event.inputs.severity_threshold }} + + - name: Upload scan results + if: github.event.inputs.publish_results != '' + uses: github/codeql-action/upload-sarif@v2 + with: + sarif_file: ${{ steps.scan.outputs.sarif }} + category: ${{ matrix.image-name }}-${{ matrix.python-version }}) + + scan_fs: + name: Scan file system + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Define output format + id: output-format + + # this section is duplicated in the other jobs. + # make sure to update all when changed. + run: | + if [[ -n "${{ github.event.inputs.publish_results }}" ]]; then \ + echo "format=sarif" >> $GITHUB_OUTPUT; \ + echo "fail_build=false" >> $GITHUB_OUTPUT; \ + else \ + echo "format=table" >> $GITHUB_OUTPUT; \ + echo "fail_build=true" >> $GITHUB_OUTPUT; \ + fi + + - name: Scan fs + uses: anchore/scan-action@v3 + id: scan + with: + path: "." 
+ only-fixed: ${{ github.event.inputs.only_fixed }} + output-format: ${{ steps.output-format.outputs.format }} + fail-build: ${{ steps.output-format.outputs.fail_build }} + severity-cutoff: ${{ github.event.inputs.severity_threshold }} + + - name: Upload scan results + if: github.event.inputs.publish_results != '' + uses: github/codeql-action/upload-sarif@v2 + with: + sarif_file: ${{ steps.scan.outputs.sarif }} + category: "repository" diff --git a/.github/workflows/system-tests-enterprise.yml b/.github/workflows/system-tests-enterprise.yml index 25de189cb53f..814e1ace009b 100644 --- a/.github/workflows/system-tests-enterprise.yml +++ b/.github/workflows/system-tests-enterprise.yml @@ -42,14 +42,6 @@ on: override_iguazio_version: description: 'Override the configured target system iguazio version (leave empty to resolve automatically)' required: false - test_code_from_action: - description: 'Take tested code from action REF rather than upstream (default: true). If running on personal fork you will want to set to false in order to pull images from mlrun ghcr (note that test code will be taken from the action REF anyways)' - required: true - default: 'true' - ui_code_from_action: - description: 'Take ui code from action branch in mlrun/ui (default: false - take from upstream)' - required: true - default: 'false' concurrency: one-at-a-time jobs: @@ -81,6 +73,12 @@ jobs: automation/system_test/cleanup.py \ ${{ secrets.LATEST_SYSTEM_TEST_DATA_CLUSTER_SSH_USERNAME }}@${{ secrets.LATEST_SYSTEM_TEST_DATA_CLUSTER_IP }}:/home/iguazio/cleanup.py + sshpass \ + -p "${{ secrets.LATEST_SYSTEM_TEST_DATA_CLUSTER_SSH_PASSWORD }}" \ + scp \ + automation/system_test/dev_utilities.py \ + ${{ secrets.LATEST_SYSTEM_TEST_DATA_CLUSTER_SSH_USERNAME }}@${{ secrets.LATEST_SYSTEM_TEST_DATA_CLUSTER_IP }}:/home/iguazio/dev_utilities.py + sshpass \ -p "${{ secrets.LATEST_SYSTEM_TEST_DATA_CLUSTER_SSH_PASSWORD }}" \ ssh \ @@ -102,70 +100,112 @@ jobs: name: Prepare System Tests Enterprise runs-on: 
ubuntu-latest + needs: [system-test-cleanup] # let's not run this on every fork, change to your fork when developing if: github.repository == 'mlrun/mlrun' || github.event_name == 'workflow_dispatch' steps: - uses: actions/checkout@v3 + - name: Copy state branch file from remote + run: | + sshpass -p "${{ secrets.LATEST_SYSTEM_TEST_DATA_CLUSTER_SSH_PASSWORD }}" scp -o StrictHostKeyChecking=no ${{ secrets.LATEST_SYSTEM_TEST_DATA_CLUSTER_SSH_USERNAME }}@${{ secrets.LATEST_SYSTEM_TEST_DATA_CLUSTER_IP }}:/tmp/system-tests-branches-list.txt system-tests-branches-list.txt + + - name: Resolve Branch To Run System Tests + id: current-branch + # we store a file named /tmp/system-tests-branches-list.txt which contains a list of branches to run system tests + # on the branches are separated with commas, so each run we pop the first branch in the list and append it to the + # end of the list. + # This mechanism allows us to run on multiple branches without the need to modify the file or secrets each time + # a new branch is added or removed + run: | + # Read branches from local file + branches=$(cat system-tests-branches-list.txt) + echo "branches found in system-tests-branches-list.txt: $branches" + + # Split branches into an array + IFS=',' read -ra branches_array <<< "$branches" + + # Get the first branch in the list to work on + first_branch="${branches_array[0]}" + echo "working on $first_branch" + + # Remove the first branch from the list + branches_array=("${branches_array[@]:1}") + + # Add the first branch at the end of the list + branches_array+=("$first_branch") + + # Join branches back into a string + branches=$(printf ",%s" "${branches_array[@]}") + branches=${branches:1} + + # Output the new list of branches + echo "$branches" + + # Write new branches order to a local file + echo "$branches" | cat > system-tests-branches-list.txt + + # Set output + echo "name=$(echo $first_branch)" >> $GITHUB_OUTPUT + + - name: Override remote file from local resolved branch list 
+ run: | + # Override the remote file with the new list of branches + sshpass -p "${{ secrets.LATEST_SYSTEM_TEST_DATA_CLUSTER_SSH_PASSWORD }}" scp -o StrictHostKeyChecking=no system-tests-branches-list.txt ${{ secrets.LATEST_SYSTEM_TEST_DATA_CLUSTER_SSH_USERNAME }}@${{ secrets.LATEST_SYSTEM_TEST_DATA_CLUSTER_IP }}:/tmp/ + # checking out to base branch and not the target(resolved) branch, to be able to run the changed preparation code + # before merging the changes to upstream. + - name: Checkout base branch + uses: actions/checkout@v3 + - name: Set up python uses: actions/setup-python@v4 with: python-version: 3.9 cache: pip - name: Install automation scripts dependencies and add mlrun to dev packages - run: pip install -r automation/requirements.txt && pip install -e . - - name: Install curl and jq - run: sudo apt-get install curl jq - - name: Extract git branch - id: git_info run: | - echo "branch=$(echo ${GITHUB_REF#refs/heads/})" >> $GITHUB_OUTPUT - - name: Extract git hash from action mlrun version - # by default when running as part of the CI this param doesn't get enriched meaning it will be empty. - # we want the mlrun_hash to be set from the $GITHUB_SHA when running in CI - if: ${{ github.event.inputs.test_code_from_action != 'false' }} - id: git_action_info - run: | - echo "mlrun_hash=$(git rev-parse --short=8 $GITHUB_SHA)" >> $GITHUB_OUTPUT - - name: Extract git hash from action mlrun version - if: ${{ github.event.inputs.ui_code_from_action == 'true' }} - id: git_action_ui_info - run: | - echo "ui_hash=$( \ - cd /tmp && \ - git clone --single-branch --branch ${{ steps.git_info.outputs.branch }} https://github.com/mlrun/ui.git mlrun-ui 2> /dev/null && \ - cd mlrun-ui && \ - git rev-parse --short=8 HEAD && \ - cd .. && \ - rm -rf mlrun-ui)" >> $GITHUB_OUTPUT + pip install -r automation/requirements.txt && pip install -e . 
+ sudo apt-get install curl jq - name: Extract git hashes from upstream and latest version id: git_upstream_info run: | + + # Get the latest commit of mlrun/mlrun (that is older than 1 hour) echo "mlrun_hash=$( \ cd /tmp && \ - git clone --single-branch --branch development https://github.com/mlrun/mlrun.git mlrun-upstream 2> /dev/null && \ + git clone --single-branch --branch ${{ steps.current-branch.outputs.name }} https://github.com/mlrun/mlrun.git mlrun-upstream 2> /dev/null && \ cd mlrun-upstream && \ - git rev-parse --short=8 HEAD && \ + git rev-list --until="1 hour ago" --max-count 1 --abbrev-commit HEAD && \ cd .. && \ rm -rf mlrun-upstream)" >> $GITHUB_OUTPUT + + # Get the latest commit of mlrun/ui (that is older than 1 hour) echo "ui_hash=$( \ cd /tmp && \ - git clone --single-branch --branch development https://github.com/mlrun/ui.git mlrun-ui 2> /dev/null && \ + git clone --single-branch --branch ${{ steps.current-branch.outputs.name }} https://github.com/mlrun/ui.git mlrun-ui 2> /dev/null && \ cd mlrun-ui && \ - git rev-parse --short=8 HEAD && \ + git rev-list --until="1 hour ago" --max-count 1 --abbrev-commit HEAD && \ cd .. && \ rm -rf mlrun-ui)" >> $GITHUB_OUTPUT - echo "unstable_version_prefix=$(cat automation/version/unstable_version_prefix)" >> $GITHUB_OUTPUT + + # Get the tested mlrun version + echo "unstable_version_prefix=$( \ + cd /tmp && \ + git clone --single-branch --branch ${{ steps.current-branch.outputs.name }} https://github.com/mlrun/mlrun.git mlrun-upstream 2> /dev/null && \ + cd mlrun-upstream && \ + cat automation/version/unstable_version_prefix && \ + cd .. 
&& \ + rm -rf mlrun-upstream)" >> $GITHUB_OUTPUT - name: Set computed versions params id: computed_params run: | action_mlrun_hash=${{ steps.git_action_info.outputs.mlrun_hash }} && \ upstream_mlrun_hash=${{ steps.git_upstream_info.outputs.mlrun_hash }} && \ - export mlrun_hash=${action_mlrun_hash:-`echo $upstream_mlrun_hash`} + export mlrun_hash=${upstream_mlrun_hash:-`echo $action_mlrun_hash`} echo "mlrun_hash=$(echo $mlrun_hash)" >> $GITHUB_OUTPUT action_mlrun_ui_hash=${{ steps.git_action_ui_info.outputs.ui_hash }} && \ upstream_mlrun_ui_hash=${{ steps.git_upstream_info.outputs.ui_hash }} && \ - export ui_hash=${action_mlrun_ui_hash:-`echo $upstream_mlrun_ui_hash`} + export ui_hash=${upstream_mlrun_ui_hash:-`echo $action_mlrun_ui_hash`} echo "ui_hash=$(echo $ui_hash)" >> $GITHUB_OUTPUT echo "mlrun_version=$(echo ${{ steps.git_upstream_info.outputs.unstable_version_prefix }}+$mlrun_hash)" >> $GITHUB_OUTPUT echo "mlrun_docker_tag=$(echo ${{ steps.git_upstream_info.outputs.unstable_version_prefix }}-$mlrun_hash)" >> $GITHUB_OUTPUT @@ -186,28 +226,32 @@ jobs: INPUT_OVERRIDE_IGUAZIO_VERSION: ${{ github.event.inputs.override_iguazio_version }} INPUT_CLEAN_RESOURCES_IN_TEARDOWN: ${{ github.event.inputs.clean_resources_in_teardown }} - - name: Prepare System Test env.yaml and MLRun installation from current branch + - name: Prepare System Test Environment and Install MLRun + env: + IP_ADDR_PREFIX: ${{ secrets.IP_ADDR_PREFIX }} timeout-minutes: 50 run: | python automation/system_test/prepare.py run \ - "${{ steps.computed_params.outputs.mlrun_version }}" \ - "${{ secrets.LATEST_SYSTEM_TEST_DATA_CLUSTER_IP }}" \ - "${{ secrets.LATEST_SYSTEM_TEST_DATA_CLUSTER_SSH_USERNAME }}" \ - "${{ secrets.LATEST_SYSTEM_TEST_DATA_CLUSTER_SSH_PASSWORD }}" \ - "${{ secrets.LATEST_SYSTEM_TEST_APP_CLUSTER_SSH_PASSWORD }}" \ - "${{ secrets.SYSTEM_TEST_GITHUB_ACCESS_TOKEN }}" \ - "${{ secrets.LATEST_SYSTEM_TEST_PROVCTL_DOWNLOAD_PATH }}" \ - "${{ 
secrets.LATEST_SYSTEM_TEST_PROVCTL_DOWNLOAD_URL_S3_ACCESS_KEY }}" \ - "${{ secrets.LATEST_SYSTEM_TEST_PROVCTL_DOWNLOAD_URL_S3_KEY_ID }}" \ - "${{ secrets.LATEST_SYSTEM_TEST_MLRUN_DB_PATH }}" \ - "${{ secrets.LATEST_SYSTEM_TEST_WEBAPI_DIRECT_URL }}" \ - "${{ secrets.LATEST_SYSTEM_TEST_FRAMESD_URL }}" \ - "${{ secrets.LATEST_SYSTEM_TEST_USERNAME }}" \ - "${{ secrets.LATEST_SYSTEM_TEST_ACCESS_KEY }}" \ - "${{ steps.computed_params.outputs.iguazio_version }}" \ - "${{ secrets.LATEST_SYSTEM_TEST_SPARK_SERVICE }}" \ - "${{ secrets.LATEST_SYSTEM_TEST_PASSWORD }}" \ - "${{ secrets.LATEST_SYSTEM_TEST_SLACK_WEBHOOK_URL }}" \ + --mlrun-version "${{ steps.computed_params.outputs.mlrun_version }}" \ + --data-cluster-ip "${{ secrets.LATEST_SYSTEM_TEST_DATA_CLUSTER_IP }}" \ + --data-cluster-ssh-username "${{ secrets.LATEST_SYSTEM_TEST_DATA_CLUSTER_SSH_USERNAME }}" \ + --data-cluster-ssh-password "${{ secrets.LATEST_SYSTEM_TEST_DATA_CLUSTER_SSH_PASSWORD }}" \ + --app-cluster-ssh-password "${{ secrets.LATEST_SYSTEM_TEST_APP_CLUSTER_SSH_PASSWORD }}" \ + --github-access-token "${{ secrets.SYSTEM_TEST_GITHUB_ACCESS_TOKEN }}" \ + --provctl-download-url "${{ secrets.LATEST_SYSTEM_TEST_PROVCTL_DOWNLOAD_PATH }}" \ + --provctl-download-s3-access-key "${{ secrets.LATEST_SYSTEM_TEST_PROVCTL_DOWNLOAD_URL_S3_ACCESS_KEY }}" \ + --provctl-download-s3-key-id "${{ secrets.LATEST_SYSTEM_TEST_PROVCTL_DOWNLOAD_URL_S3_KEY_ID }}" \ + --mlrun-dbpath "${{ secrets.LATEST_SYSTEM_TEST_MLRUN_DB_PATH }}" \ + --webapi-direct-url "${{ secrets.LATEST_SYSTEM_TEST_WEBAPI_DIRECT_URL }}" \ + --framesd-url "${{ secrets.LATEST_SYSTEM_TEST_FRAMESD_URL }}" \ + --username "${{ secrets.LATEST_SYSTEM_TEST_USERNAME }}" \ + --access-key "${{ secrets.LATEST_SYSTEM_TEST_ACCESS_KEY }}" \ + --iguazio-version "${{ steps.computed_params.outputs.iguazio_version }}" \ + --spark-service "${{ secrets.LATEST_SYSTEM_TEST_SPARK_SERVICE }}" \ + --slack-webhook-url "${{ secrets.LATEST_SYSTEM_TEST_SLACK_WEBHOOK_URL }}" \ + --mysql-user "${{ 
secrets.LATEST_SYSTEM_TEST_MYSQL_USER }}" \ + --mysql-password "${{ secrets.LATEST_SYSTEM_TEST_MYSQL_PASSWORD }}" \ + --purge-db \ --mlrun-commit "${{ steps.computed_params.outputs.mlrun_hash }}" \ --override-image-registry "${{ steps.computed_params.outputs.mlrun_docker_registry }}" \ --override-image-repo ${{ steps.computed_params.outputs.mlrun_docker_repo }} \ @@ -217,12 +261,13 @@ jobs: outputs: mlrunVersion: ${{ steps.computed_params.outputs.mlrun_version }} + mlrunBranch: ${{ steps.current-branch.outputs.name }} mlrunSystemTestsCleanResources: ${{ steps.computed_params.outputs.mlrun_system_tests_clean_resources }} run-system-tests-enterprise-ci: # When increasing the timeout make sure it's not larger than the schedule cron interval timeout-minutes: 360 - name: Run System Tests Enterprise + name: Test ${{ matrix.test_component }} [${{ needs.prepare-system-tests-enterprise-ci.outputs.mlrunBranch }}] # requires prepare to finish before starting needs: [prepare-system-tests-enterprise-ci] runs-on: ubuntu-latest @@ -234,8 +279,12 @@ jobs: matrix: test_component: [api,runtimes,projects,model_monitoring,examples,backwards_compatibility,feature_store] steps: - # checking out to the commit hash that the preparation step executed on - uses: actions/checkout@v3 + # checking out to the resolved branch to run system tests on, as now we run the actual tests, we don't want to run + # the system tests of the branch that triggered the system tests as it might be in a different version + # than the mlrun version we deployed on the previous job (can have features that the resolved branch doesn't have) + with: + ref: ${{ needs.prepare-system-tests-enterprise-ci.outputs.mlrunBranch }} - name: Set up python uses: actions/setup-python@v4 with: @@ -249,17 +298,19 @@ jobs: timeout-minutes: 5 run: | python automation/system_test/prepare.py env \ - "${{ secrets.LATEST_SYSTEM_TEST_MLRUN_DB_PATH }}" \ - "${{ secrets.LATEST_SYSTEM_TEST_WEBAPI_DIRECT_URL }}" \ - "${{ 
secrets.LATEST_SYSTEM_TEST_FRAMESD_URL }}" \ - "${{ secrets.LATEST_SYSTEM_TEST_USERNAME }}" \ - "${{ secrets.LATEST_SYSTEM_TEST_ACCESS_KEY }}" \ - "${{ secrets.LATEST_SYSTEM_TEST_SPARK_SERVICE }}" \ - "${{ secrets.LATEST_SYSTEM_TEST_PASSWORD }}" \ - "${{ secrets.LATEST_SYSTEM_TEST_SLACK_WEBHOOK_URL }}" + --mlrun-dbpath "${{ secrets.LATEST_SYSTEM_TEST_MLRUN_DB_PATH }}" \ + --webapi-direct-url "${{ secrets.LATEST_SYSTEM_TEST_WEBAPI_DIRECT_URL }}" \ + --framesd-url "${{ secrets.LATEST_SYSTEM_TEST_FRAMESD_URL }}" \ + --username "${{ secrets.LATEST_SYSTEM_TEST_USERNAME }}" \ + --access-key "${{ secrets.LATEST_SYSTEM_TEST_ACCESS_KEY }}" \ + --spark-service "${{ secrets.LATEST_SYSTEM_TEST_SPARK_SERVICE }}" \ + --slack-webhook-url "${{ secrets.LATEST_SYSTEM_TEST_SLACK_WEBHOOK_URL }}" \ + --branch "${{ needs.prepare-system-tests-enterprise-ci.outputs.mlrunBranch }}" \ + --github-access-token "${{ secrets.SYSTEM_TEST_GITHUB_ACCESS_TOKEN }}" - name: Run System Tests run: | MLRUN_SYSTEM_TESTS_CLEAN_RESOURCES="${{ needs.prepare-system-tests-enterprise-ci.outputs.mlrunSystemTestsCleanResources }}" \ MLRUN_VERSION="${{ needs.prepare-system-tests-enterprise-ci.outputs.mlrunVersion }}" \ MLRUN_SYSTEM_TESTS_COMPONENT="${{ matrix.test_component }}" \ + MLRUN_SYSTEM_TESTS_BRANCH="${{ needs.prepare-system-tests-enterprise-ci.outputs.mlrunBranch }}" \ make test-system-dockerized diff --git a/.github/workflows/system-tests-opensource.yml b/.github/workflows/system-tests-opensource.yml index 85f018739be1..c7e721ef458e 100644 --- a/.github/workflows/system-tests-opensource.yml +++ b/.github/workflows/system-tests-opensource.yml @@ -35,14 +35,6 @@ on: description: 'Docker repo to pull images from (default: mlrun)' required: true default: 'mlrun' - test_code_from_action: - description: 'Take tested code from action REF (default: false - take from upstream) (note that test code will be taken from the action REF anyways)' - required: true - default: 'false' - ui_code_from_action: - 
description: 'Take ui code from action branch in mlrun/ui (default: false - take from upstream)' - required: true - default: 'false' clean_resources_in_teardown: description: 'Clean resources created by test (like project) in each test teardown (default: true - perform clean)' required: true @@ -77,33 +69,20 @@ jobs: cache: pip - name: Install automation scripts dependencies and add mlrun to dev packages run: | - pip install -r automation/requirements.txt -r dockerfiles/test-system/requirements.txt \ - -r dockerfiles/mlrun-api/requirements.txt -r dev-requirements.txt \ - -r extras-requirements.txt && pip install -e . - - # TODO: How can we avoid these duplicate lines from the enterprise system tests, up until line 120. - - name: Install curl and jq - run: sudo apt-get install curl jq + pip install \ + -r automation/requirements.txt \ + -r dockerfiles/test-system/requirements.txt \ + -r dockerfiles/mlrun-api/requirements.txt \ + -r dev-requirements.txt \ + -r extras-requirements.txt \ + && pip install -e . + sudo apt-get install curl jq + + # TODO: How can we avoid these duplicate lines from the enterprise system tests - name: Extract git branch id: git_info run: | echo "branch=$(echo ${GITHUB_REF#refs/heads/})" >> $GITHUB_OUTPUT - - name: Extract git hash from action mlrun version - if: ${{ github.event.inputs.test_code_from_action == 'true' }} - id: git_action_info - run: | - echo "mlrun_hash=$(git rev-parse --short=8 $GITHUB_SHA)" >> $GITHUB_OUTPUT - - name: Extract UI git hash from action mlrun version - if: ${{ github.event.inputs.ui_code_from_action == 'true' }} - id: git_action_ui_info - run: | - echo "ui_hash=$( \ - cd /tmp && \ - git clone --single-branch --branch ${{ steps.git_info.outputs.branch }} https://github.com/mlrun/ui.git mlrun-ui 2> /dev/null && \ - cd mlrun-ui && \ - git rev-parse --short=8 HEAD && \ - cd .. 
&& \ - rm -rf mlrun-ui)" >> $GITHUB_OUTPUT - name: Extract git hashes from upstream and latest version id: git_upstream_info run: | @@ -111,14 +90,14 @@ jobs: cd /tmp && \ git clone --single-branch --branch development https://github.com/mlrun/mlrun.git mlrun-upstream 2> /dev/null && \ cd mlrun-upstream && \ - git rev-parse --short=8 HEAD && \ + git rev-list --until="1 hour ago" --max-count 1 --abbrev-commit HEAD && \ cd .. && \ rm -rf mlrun-upstream)" >> $GITHUB_OUTPUT echo "ui_hash=$( \ cd /tmp && \ git clone --single-branch --branch development https://github.com/mlrun/ui.git mlrun-ui 2> /dev/null && \ cd mlrun-ui && \ - git rev-parse --short=8 HEAD && \ + git rev-list --until="1 hour ago" --max-count 1 --abbrev-commit HEAD && \ cd .. && \ rm -rf mlrun-ui)" >> $GITHUB_OUTPUT echo "unstable_version_prefix=$(cat automation/version/unstable_version_prefix)" >> $GITHUB_OUTPUT @@ -164,51 +143,50 @@ jobs: # but this seems to work start args: '--addons=registry --insecure-registry="192.168.49.2:5000"' - - name: Get mlrun ce charts and create namespace - run: | - helm repo add mlrun-ce https://mlrun.github.io/ce - helm repo update - minikube kubectl -- create namespace ${NAMESPACE} - - name: Install MLRun CE helm chart run: | # TODO: There are a couple of modifications to the helm chart that we are doing right now: # 1. The grafana prometheus stack is disabled as there are currently no system tests checking its # functionality. Once the model monitoring feature is complete and we have system tests for it, we - # can enable it. (flags: --set kube-prometheus-stack.enabled=false) + # can enable it. # 2. The mlrun DB is set as the old SQLite db. There is a bug in github workers when trying to run a mysql # server pod in minikube installed on the worker, the mysql pod crashes. 
There isn't much information # about this issue online as this isn't how github expect you to use mysql in workflows - the worker # has a mysql server installed directly on it and should be enabled and used as the DB. So we might # want in the future to use that instead, unless the mysql will be able to come up without crashing. - # (flags: --set mlrun.httpDB.dbType="sqlite" --set mlrun.httpDB.dirPath="/mlrun/db" - # --set mlrun.httpDB.dsn="sqlite:////mlrun/db/mlrun.db?check_same_thread=false" - # --set mlrun.httpDB.oldDsn="") - helm --namespace ${NAMESPACE} \ - install mlrun-ce \ - --debug \ - --wait \ - --timeout 600s \ - --set kube-prometheus-stack.enabled=false \ - --set mlrun.httpDB.dbType="sqlite" \ - --set mlrun.httpDB.dirPath="/mlrun/db" \ - --set mlrun.httpDB.dsn="sqlite:////mlrun/db/mlrun.db?check_same_thread=false" \ - --set mlrun.httpDB.oldDsn="" \ - --set global.registry.url=$(minikube ip):5000 \ - --set global.registry.secretName="" \ - --set global.externalHostAddress=$(minikube ip) \ - --set nuclio.dashboard.externalIPAddresses[0]=$(minikube ip) \ - --set mlrun.api.image.repository=${{ steps.computed_params.outputs.mlrun_docker_registry }}${{ steps.computed_params.outputs.mlrun_docker_repo }}/mlrun-api \ - --set mlrun.api.image.tag=${{ steps.computed_params.outputs.mlrun_docker_tag }} \ - --set mlrun.ui.image.repository=ghcr.io/mlrun/mlrun-ui \ - --set mlrun.ui.image.tag=${{ steps.computed_params.outputs.mlrun_ui_version }} \ - --set mlrun.api.extraEnvKeyValue.MLRUN_HTTPDB__BUILDER__MLRUN_VERSION_SPECIFIER="mlrun[complete] @ git+https://github.com/mlrun/mlrun@${{ steps.computed_params.outputs.mlrun_hash }}" \ + # + # TODO: Align the mlrun config env vars with the ones in the prepare.py script to avoid further inconsistencies. 
+ python automation/deployment/ce.py deploy \ + --verbose \ + --minikube \ + --namespace=${NAMESPACE} \ + --registry-secret-name="" \ + --disable-prometheus-stack \ + --sqlite /mlrun/db/mlrun.db \ + --override-mlrun-api-image="${{ steps.computed_params.outputs.mlrun_docker_registry }}${{ steps.computed_params.outputs.mlrun_docker_repo }}/mlrun-api:${{ steps.computed_params.outputs.mlrun_docker_tag }}" \ + --override-mlrun-ui-image="ghcr.io/mlrun/mlrun-ui:${{ steps.computed_params.outputs.mlrun_ui_version }}" \ + --set 'mlrun.api.extraEnvKeyValue.MLRUN_HTTPDB__BUILDER__MLRUN_VERSION_SPECIFIER="mlrun[complete] @ git+https://github.com/mlrun/mlrun@${{ steps.computed_params.outputs.mlrun_hash }}"' \ --set mlrun.api.extraEnvKeyValue.MLRUN_IMAGES_REGISTRY="${{ steps.computed_params.outputs.mlrun_docker_registry }}" \ - mlrun-ce/mlrun-ce + --set mlrun.api.extraEnvKeyValue.MLRUN_LOG_LEVEL="DEBUG" \ + --set 'mlrun.api.extraEnvKeyValue.MLRUN_HTTPDB__SCHEDULING__MIN_ALLOWED_INTERVAL="0 seconds"' \ + --set mlrun.api.extraEnvKeyValue.MLRUN_MODEL_ENDPOINT_MONITORING__PARQUET_BATCHING_MAX_EVENTS="100" - name: Prepare system tests env run: | - echo "MLRUN_DBPATH: http://$(minikube ip):${MLRUN_API_NODE_PORT}" > tests/system/env.yml + python automation/system_test/prepare.py env \ + --mlrun-dbpath "http://$(minikube ip):${MLRUN_API_NODE_PORT}" \ + --github-access-token "${{ secrets.SYSTEM_TEST_GITHUB_ACCESS_TOKEN }}" + + # Enable tmate debugging of manually-triggered workflows if the input option was provided + - name: Setup tmate session + uses: mxschmitt/action-tmate@v3 + if: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.debug_enabled == 'true' }} + with: + + # run in detach mode to allow the workflow to continue running while session is active + # this will wait up to 10 minutes AFTER the entire job is done. Once user connects to the session, + # it will wait until the user disconnects before finishing up the job. 
+ detached: true - name: Run system tests timeout-minutes: 180 @@ -235,8 +213,3 @@ jobs: minikube kubectl -- --namespace ${NAMESPACE} get pvc minikube kubectl -- --namespace ${NAMESPACE} get pv set +x - - # Enable tmate debugging of manually-triggered workflows if the input option was provided - - name: Setup tmate session - uses: mxschmitt/action-tmate@v3 - if: ${{ github.event_name == 'workflow_dispatch' && github.event.inputs.debug_enabled == 'true' }} diff --git a/.gitignore b/.gitignore index d51b023d0e7d..0aa76c87fc68 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,9 @@ mlrun.egg-info/ model.txt result*.html tests/test_results/ +tests/temp* +tests/*.pkl +tests/project.yaml *venv* mlrun/utils/version/version.json mlrun/api/migrations_sqlite/mlrun.db @@ -28,7 +31,7 @@ tests/system/env.yml # pyenv file for working with several python versions .python-version *.bak -docs/CONTRIBUTING.md +docs/contributing.md mlrun/api/proto/*pb2*.py docs/tutorial/colab/01-mlrun-basics-colab.ipynb diff --git a/.importlinter b/.importlinter new file mode 100644 index 000000000000..275a80e20e35 --- /dev/null +++ b/.importlinter @@ -0,0 +1,32 @@ +[importlinter] +root_package=mlrun +include_external_packages=True + + +[importlinter:contract:1] +name=common modules shouldn't import other mlrun utilities +type=forbidden +source_modules= + mlrun.common + +forbidden_modules= + mlrun.api + mlrun.artifacts + mlrun.data_types + mlrun.datastore + mlrun.db + mlrun.feature_store + mlrun.frameworks + mlrun.mlutils + mlrun.model_monitoring + mlrun.platforms + mlrun.projects + mlrun.runtimes + mlrun.serving + mlrun.utils + mlrun.builder + mlrun.config + mlrun.errors + mlrun.lists + mlrun.model + mlrun.run diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e94d21dd3ff8..90951e5d0cff 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,6 +2,8 @@ ## Creating a development environment +If you are working with an ARM64 machine, please see [Developing with ARM64 
machines](#developing-with-arm64-machines). + We recommend using [pyenv](https://github.com/pyenv/pyenv#installation) to manage your python versions. Once you have pyenv installed, you can create a new environment by running: @@ -40,6 +42,46 @@ make install-requirements pip install -e '.[complete]' ``` +## Developing with ARM64 machines + +Some mlrun dependencies are not yet available for ARM64 machines via pypi, so we need to work with conda to get the packages compiled for ARM64 platform. +Install Anaconda from [here](https://docs.anaconda.com/free/anaconda/install/index.html) and then follow the steps below: + +Fork, clone and cd into the MLRun repository directory +```shell script +git clone git@github.com:/mlrun.git +cd mlrun +``` + +Create a conda environment and activate it +```shell script +conda create -n mlrun python=3.9 +conda activate mlrun +``` + +Then, install the dependencies +```shell script +make install-conda-requirements +``` + +Run some unit tests to make sure everything works: +```shell script +python -m pytest ./tests/projects +``` + +If you encounter any error with 'charset_normalizer' for example: +```shell script +AttributeError: partially initialized module 'charset_normalizer' has no attribute 'md__mypyc' (most likely due to a circular import) +``` +Run: +```shell script +pip install --force-reinstall charset-normalizer +``` +Finally, install mlrun +```shell script +pip install -e '.[complete]' +``` + ## Formatting We use [black](https://github.com/psf/black) as our formatter. 
diff --git a/Makefile b/Makefile index 9c0a97cf0598..f2d99997458a 100644 --- a/Makefile +++ b/Makefile @@ -108,6 +108,10 @@ install-requirements: ## Install all requirements needed for development -r dockerfiles/mlrun-api/requirements.txt \ -r docs/requirements.txt +.PHONY: install-conda-requirements +install-conda-requirements: install-requirements ## Install all requirements needed for development with specific conda packages for arm64 + conda install --yes --file conda-arm64-requirements.txt + .PHONY: install-complete-requirements install-complete-requirements: ## Install all requirements needed for development and testing python -m pip install --upgrade $(MLRUN_PIP_NO_CACHE_FLAG) pip~=$(MLRUN_PIP_VERSION) @@ -222,6 +226,10 @@ push-mlrun: mlrun ## Push mlrun docker image docker push $(MLRUN_IMAGE_NAME_TAGGED) $(MLRUN_CACHE_IMAGE_PUSH_COMMAND) +.PHONY: pull-mlrun +pull-mlrun: ## Pull mlrun docker image + docker pull $(MLRUN_IMAGE_NAME_TAGGED) + MLRUN_BASE_IMAGE_NAME := $(MLRUN_DOCKER_IMAGE_PREFIX)/$(MLRUN_ML_DOCKER_IMAGE_NAME_PREFIX)base MLRUN_BASE_CACHE_IMAGE_NAME := $(MLRUN_CACHE_DOCKER_IMAGE_PREFIX)/$(MLRUN_ML_DOCKER_IMAGE_NAME_PREFIX)base @@ -257,6 +265,9 @@ push-base: base ## Push base docker image docker push $(MLRUN_BASE_IMAGE_NAME_TAGGED) $(MLRUN_BASE_CACHE_IMAGE_PUSH_COMMAND) +.PHONY: pull-base +pull-base: ## Pull base docker image + docker pull $(MLRUN_BASE_IMAGE_NAME_TAGGED) MLRUN_MODELS_IMAGE_NAME := $(MLRUN_DOCKER_IMAGE_PREFIX)/$(MLRUN_ML_DOCKER_IMAGE_NAME_PREFIX)models MLRUN_MODELS_CACHE_IMAGE_NAME := $(MLRUN_CACHE_DOCKER_IMAGE_PREFIX)/$(MLRUN_ML_DOCKER_IMAGE_NAME_PREFIX)models @@ -296,6 +307,10 @@ push-models: models ## Push models docker image docker push $(MLRUN_MODELS_IMAGE_NAME_TAGGED) $(MLRUN_MODELS_CACHE_IMAGE_PUSH_COMMAND) +.PHONY: pull-models +pull-models: ## Pull models docker image + docker pull $(MLRUN_MODELS_IMAGE_NAME_TAGGED) + MLRUN_MODELS_GPU_IMAGE_NAME := $(MLRUN_DOCKER_IMAGE_PREFIX)/$(MLRUN_ML_DOCKER_IMAGE_NAME_PREFIX)models-gpu 
MLRUN_MODELS_GPU_CACHE_IMAGE_NAME := $(MLRUN_CACHE_DOCKER_IMAGE_PREFIX)/$(MLRUN_ML_DOCKER_IMAGE_NAME_PREFIX)models-gpu @@ -326,6 +341,10 @@ push-models-gpu: models-gpu ## Push models gpu docker image docker push $(MLRUN_MODELS_GPU_IMAGE_NAME_TAGGED) $(MLRUN_MODELS_GPU_CACHE_IMAGE_PUSH_COMMAND) +.PHONY: pull-models-gpu +pull-models-gpu: ## Pull models gpu docker image + docker pull $(MLRUN_MODELS_GPU_IMAGE_NAME_TAGGED) + .PHONY: prebake-models-gpu prebake-models-gpu: ## Build prebake models GPU docker image docker build \ @@ -370,25 +389,36 @@ jupyter: update-version-file ## Build mlrun jupyter docker image push-jupyter: jupyter ## Push mlrun jupyter docker image docker push $(MLRUN_JUPYTER_IMAGE_NAME) +.PHONY: pull-jupyter +pull-jupyter: ## Pull mlrun jupyter docker image + docker pull $(MLRUN_JUPYTER_IMAGE_NAME) + .PHONY: log-collector log-collector: update-version-file - cd go && \ - MLRUN_VERSION=$(MLRUN_VERSION) \ + @MLRUN_VERSION=$(MLRUN_VERSION) \ MLRUN_DOCKER_REGISTRY=$(MLRUN_DOCKER_REGISTRY) \ MLRUN_DOCKER_REPO=$(MLRUN_DOCKER_REPO) \ MLRUN_DOCKER_TAG=$(MLRUN_DOCKER_TAG) \ MLRUN_DOCKER_IMAGE_PREFIX=$(MLRUN_DOCKER_IMAGE_PREFIX) \ - make log-collector + make --no-print-directory -C $(shell pwd)/go log-collector .PHONY: push-log-collector push-log-collector: log-collector - cd go && \ - MLRUN_VERSION=$(MLRUN_VERSION) \ + @MLRUN_VERSION=$(MLRUN_VERSION) \ MLRUN_DOCKER_REGISTRY=$(MLRUN_DOCKER_REGISTRY) \ MLRUN_DOCKER_REPO=$(MLRUN_DOCKER_REPO) \ MLRUN_DOCKER_TAG=$(MLRUN_DOCKER_TAG) \ MLRUN_DOCKER_IMAGE_PREFIX=$(MLRUN_DOCKER_IMAGE_PREFIX) \ - make push-log-collector + make --no-print-directory -C $(shell pwd)/go push-log-collector + +.PHONY: pull-log-collector +pull-log-collector: + @MLRUN_VERSION=$(MLRUN_VERSION) \ + MLRUN_DOCKER_REGISTRY=$(MLRUN_DOCKER_REGISTRY) \ + MLRUN_DOCKER_REPO=$(MLRUN_DOCKER_REPO) \ + MLRUN_DOCKER_TAG=$(MLRUN_DOCKER_TAG) \ + MLRUN_DOCKER_IMAGE_PREFIX=$(MLRUN_DOCKER_IMAGE_PREFIX) \ + make --no-print-directory -C $(shell pwd)/go 
pull-log-collector .PHONY: compile-schemas @@ -425,6 +455,9 @@ push-api: api ## Push api docker image docker push $(MLRUN_API_IMAGE_NAME_TAGGED) $(MLRUN_API_CACHE_IMAGE_PUSH_COMMAND) +.PHONY: pull-api +pull-api: ## Pull api docker image + docker pull $(MLRUN_API_IMAGE_NAME_TAGGED) MLRUN_TEST_IMAGE_NAME := $(MLRUN_DOCKER_IMAGE_PREFIX)/test MLRUN_TEST_CACHE_IMAGE_NAME := $(MLRUN_CACHE_DOCKER_IMAGE_PREFIX)/test @@ -498,7 +531,6 @@ test: clean ## Run mlrun tests --durations=100 \ --ignore=tests/integration \ --ignore=tests/system \ - --ignore=tests/test_notebooks.py \ --ignore=tests/rundb/test_httpdb.py \ -rf \ tests @@ -522,7 +554,6 @@ test-integration: clean ## Run mlrun integration tests --durations=100 \ -rf \ tests/integration \ - tests/test_notebooks.py \ tests/rundb/test_httpdb.py .PHONY: test-migrations-dockerized @@ -611,6 +642,7 @@ run-api: api ## Run mlrun api (dockerized) --publish 8080 \ --add-host host.docker.internal:host-gateway \ --env MLRUN_HTTPDB__DSN=$(MLRUN_HTTPDB__DSN) \ + --env MLRUN_LOG_LEVEL=$(MLRUN_LOG_LEVEL) \ $(MLRUN_API_IMAGE_NAME_TAGGED) .PHONY: run-test-db @@ -650,6 +682,11 @@ fmt: ## Format the code (using black and isort) python -m black . python -m isort . +.PHONY: lint-imports +lint-imports: ## making sure imports dependencies are aligned + @echo "Running import linter" + lint-imports + .PHONY: lint lint: flake8 fmt-check ## Run lint on the code diff --git a/automation/deployment/README.md b/automation/deployment/README.md new file mode 100644 index 000000000000..46ebe8f4db2b --- /dev/null +++ b/automation/deployment/README.md @@ -0,0 +1,102 @@ +# MLRun Community Edition Deployer + +CLI tool for deploying MLRun Community Edition. +The CLI supports 3 commands: +- `deploy`: Deploys (or upgrades) an MLRun Community Edition Stack. +- `delete`: Uninstalls the CE and cleans up dangling resources. +- `patch-minikube-images`: If using custom images and running from Minikube, this command will patch the images to the Minikube env. 
+ +## Command Usage: + +### Deploy: +To deploy the CE the minimum needed is the registry url and registry credentials. You can run: +``` +$ python automation/deployment/ce.py deploy \ + --registry-url \ + --registry-username \ + --registry-password +``` +This will deploy the CE with the default configuration. + +Instead of passing the registry credentials as command line arguments, you can create a secret in the cluster and pass the secret name: +``` +$ python automation/deployment/ce.py deploy \ + --registry-url \ + --registry-secret-name +``` + +#### Extra Configurations: + +You can override the mlrun version and chart version by using the flags `--mlrun-version` and `--chart-version` respectively. + +To disable the components that are installed by default, you can use the following flags: +- `--disable-pipelines`: Disable the installation of Kubeflow Pipelines. +- `--disable-prometheus-stack`: Disable the installation of the Prometheus stack. +- `--disable-spark-operator`: Disable the installation of the Spark operator. + +To override the images used by the CE, you can use the following flags: +- `--override-jupyter-image`: Override the jupyter image. Format: `:` +- `--override-mlrun-api-image`: Override the mlrun-api image. Format: `:` +- `--override-mlrun-ui-image`: Override the mlrun-ui image. Format: `:` + +To run mlrun with sqlite instead of MySQL, you can use the `--sqlite` flag. The value should be the path to the sqlite file to use. + +To set custom values for the mlrun chart, you can use the `--set` flag. The format is `=`. For example: +``` +$ python automation/deployment/ce.py deploy \ + --registry-url \ + --registry-username \ + --registry-password \ + --set mlrun.db.persistence.size="1Gi" \ + --set mlrun.api.persistence.size="1Gi" +``` + +To install the CE in a different namespace, you can use the `--namespace` flag. + +To install the CE in minikube, you can use the `--minikube` flag. 
+ + +### Upgrade +To upgrade the CE, you can use the same command as deploy with the flag `--upgrade`. +The CLI will detect that the CE is already installed and will perform an upgrade. The flag will instruct helm to reuse values from the previous deployment. + +### Delete: +To simply uninstall the CE deployment, you can run: +``` +$ python automation/deployment/ce.py delete +``` + +To delete the CE deployment and clean up remaining volumes, you can run: +``` +$ python automation/deployment/ce.py delete --cleanup-volumes +``` + +To cleanup the entire namespace, you can run: +``` +$ python automation/deployment/ce.py delete --cleanup-namespace +``` + +If you already uninstalled, and only want to run cleanup, you can use the `--skip-uninstall` flag. + + +### Patch Minikube Images: +Patch MLRun Community Edition Deployment images to minikube. Useful if overriding images and running in minikube. +If you have some custom images, before deploying the CE, run: +``` +$ python automation/deployment/ce.py patch-minikube-images \ + --mlrun-api-image \ + --mlrun-ui-image \ + --jupyter-image +``` + +Then you can deploy the CE with: +``` +$ python automation/deployment/ce.py deploy \ + --registry-url \ + --registry-username \ + --registry-password \ + --minikube \ + --override-mlrun-api-image \ + --override-mlrun-ui-image \ + --override-jupyter-image +``` diff --git a/automation/deployment/__init__.py b/automation/deployment/__init__.py new file mode 100644 index 000000000000..7f557697af77 --- /dev/null +++ b/automation/deployment/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 MLRun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/automation/deployment/ce.py b/automation/deployment/ce.py new file mode 100644 index 000000000000..8d39597d72d1 --- /dev/null +++ b/automation/deployment/ce.py @@ -0,0 +1,329 @@ +# Copyright 2023 MLRun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import typing + +import click +from deployer import CommunityEditionDeployer + +common_options = [ + click.option( + "-v", + "--verbose", + is_flag=True, + help="Enable debug logging", + ), + click.option( + "-f", + "--log-file", + help="Path to log file. If not specified, will log only to stdout", + ), + click.option( + "--remote", + help="Remote host to deploy to. If not specified, will deploy to the local host", + ), + click.option( + "--remote-ssh-username", + help="Username to use when connecting to the remote host via SSH. " + "If not specified, will use MLRUN_REMOTE_SSH_USERNAME environment variable", + ), + click.option( + "--remote-ssh-password", + help="Password to use when connecting to the remote host via SSH. 
" + "If not specified, will use MLRUN_REMOTE_SSH_PASSWORD environment variable", + ), +] + +common_deployment_options = [ + click.option( + "-n", + "--namespace", + default="mlrun", + help="Namespace to install the platform in. Defaults to 'mlrun'", + ), + click.option( + "--registry-secret-name", + help="Name of the secret containing the credentials for the container registry to use for storing images", + ), + click.option( + "--sqlite", + help="Path to sqlite file to use as the mlrun database. If not supplied, will use MySQL deployment", + ), +] + + +def add_options(options): + def _add_options(func): + for option in reversed(options): + func = option(func) + return func + + return _add_options + + +def order_click_options(func): + func.__click_params__ = list( + reversed(sorted(func.__click_params__, key=lambda option: option.name)) + ) + return func + + +@click.group(help="MLRun Community Edition Deployment CLI Tool") +def cli(): + pass + + +@cli.command(help="Deploy (or upgrade) MLRun Community Edition") +@order_click_options +@click.option( + "-mv", + "--mlrun-version", + help="Version of mlrun to install. If not specified, will install the latest version", +) +@click.option( + "-cv", + "--chart-version", + help="Version of the mlrun chart to install. If not specified, will install the latest version", +) +@click.option( + "--registry-url", + help="URL of the container registry to use for storing images", +) +@click.option( + "--registry-username", + help="Username of the container registry to use for storing images", +) +@click.option( + "--registry-password", + help="Password of the container registry to use for storing images", +) +@click.option( + "--override-mlrun-api-image", + help="Override the mlrun-api image. Format: :", +) +@click.option( + "--override-mlrun-ui-image", + help="Override the mlrun-ui image. Format: :", +) +@click.option( + "--override-jupyter-image", + help="Override the jupyter image. 
Format: :", +) +@click.option( + "--disable-pipelines", + is_flag=True, + help="Disable the installation of Kubeflow Pipelines", +) +@click.option( + "--disable-prometheus-stack", + is_flag=True, + help="Disable the installation of the Prometheus stack", +) +@click.option( + "--disable-spark-operator", + is_flag=True, + help="Disable the installation of the Spark operator", +) +@click.option( + "--devel", + is_flag=True, + help="Get the latest RC version of the mlrun chart. (Only works if --chart-version is not specified)", +) +@click.option( + "-m", + "--minikube", + is_flag=True, + help="Install the mlrun chart in local minikube", +) +@click.option( + "--set", + "set_", + help="Set custom values for the mlrun chart. Format: =", + multiple=True, +) +@click.option( + "--upgrade", + is_flag=True, + help="Upgrade the existing mlrun installation", +) +@click.option( + "--skip-registry-validation", + is_flag=True, + help="Skip validation of the registry URL", +) +@add_options(common_options) +@add_options(common_deployment_options) +def deploy( + verbose: bool = False, + log_file: str = None, + namespace: str = "mlrun", + remote: str = None, + remote_ssh_username: str = None, + remote_ssh_password: str = None, + mlrun_version: str = None, + chart_version: str = None, + registry_url: str = None, + registry_secret_name: str = None, + registry_username: str = None, + registry_password: str = None, + override_mlrun_api_image: str = None, + override_mlrun_ui_image: str = None, + override_jupyter_image: str = None, + disable_pipelines: bool = False, + disable_prometheus_stack: bool = False, + disable_spark_operator: bool = False, + skip_registry_validation: bool = False, + sqlite: str = None, + devel: bool = False, + minikube: bool = False, + upgrade: bool = False, + set_: typing.List[str] = None, +): + deployer = CommunityEditionDeployer( + namespace=namespace, + log_level="debug" if verbose else "info", + log_file=log_file, + remote=remote, + 
remote_ssh_username=remote_ssh_username, + remote_ssh_password=remote_ssh_password, + ) + deployer.deploy( + registry_url=registry_url, + registry_username=registry_username, + registry_password=registry_password, + registry_secret_name=registry_secret_name, + mlrun_version=mlrun_version, + chart_version=chart_version, + override_mlrun_api_image=override_mlrun_api_image, + override_mlrun_ui_image=override_mlrun_ui_image, + override_jupyter_image=override_jupyter_image, + disable_pipelines=disable_pipelines, + disable_prometheus_stack=disable_prometheus_stack, + disable_spark_operator=disable_spark_operator, + skip_registry_validation=skip_registry_validation, + devel=devel, + minikube=minikube, + sqlite=sqlite, + upgrade=upgrade, + custom_values=set_, + ) + + +@cli.command(help="Uninstall MLRun Community Edition Deployment") +@order_click_options +@click.option( + "--skip-uninstall", + is_flag=True, + help="Skip uninstalling the Helm chart. Useful if already uninstalled and you want to perform cleanup only", +) +@click.option( + "--skip-cleanup-registry-secret", + is_flag=True, + help="Skip deleting the registry secret created during installation", +) +@click.option( + "--cleanup-volumes", + is_flag=True, + help="Delete the PVCs created during installation. WARNING: This will result in data loss!", +) +@click.option( + "--cleanup-namespace", + is_flag=True, + help="Delete the namespace created during installation. This overrides the other cleanup options. 
" + "WARNING: This will result in data loss!", +) +@add_options(common_options) +@add_options(common_deployment_options) +def delete( + verbose: bool = False, + log_file: str = None, + namespace: str = "mlrun", + remote: str = None, + remote_ssh_username: str = None, + remote_ssh_password: str = None, + registry_secret_name: str = None, + skip_uninstall: bool = False, + skip_cleanup_registry_secret: bool = False, + cleanup_volumes: bool = False, + cleanup_namespace: bool = False, + sqlite: str = None, +): + deployer = CommunityEditionDeployer( + namespace=namespace, + log_level="debug" if verbose else "info", + log_file=log_file, + remote=remote, + remote_ssh_username=remote_ssh_username, + remote_ssh_password=remote_ssh_password, + ) + deployer.delete( + skip_uninstall=skip_uninstall, + sqlite=sqlite, + cleanup_registry_secret=not skip_cleanup_registry_secret, + cleanup_volumes=cleanup_volumes, + cleanup_namespace=cleanup_namespace, + registry_secret_name=registry_secret_name, + ) + + +@cli.command( + help="Patch MLRun Community Edition Deployment images to minikube. " + "Useful if overriding images and running in minikube" +) +@order_click_options +@click.option( + "--mlrun-api-image", + help="Override the mlrun-api image. Format: :", +) +@click.option( + "--mlrun-ui-image", + help="Override the mlrun-ui image. Format: :", +) +@click.option( + "--jupyter-image", + help="Override the jupyter image. 
Format: :", +) +@add_options(common_options) +def patch_minikube_images( + remote: str = None, + remote_ssh_username: str = None, + remote_ssh_password: str = None, + verbose: bool = False, + log_file: str = None, + mlrun_api_image: str = None, + mlrun_ui_image: str = None, + jupyter_image: str = None, +): + deployer = CommunityEditionDeployer( + namespace="", + log_level="debug" if verbose else "info", + log_file=log_file, + remote=remote, + remote_ssh_username=remote_ssh_username, + remote_ssh_password=remote_ssh_password, + ) + deployer.patch_minikube_images( + mlrun_api_image=mlrun_api_image, + mlrun_ui_image=mlrun_ui_image, + jupyter_image=jupyter_image, + ) + + +if __name__ == "__main__": + try: + cli() + except Exception as exc: + print("Unexpected error:", exc) + sys.exit(1) diff --git a/automation/deployment/deployer.py b/automation/deployment/deployer.py new file mode 100644 index 000000000000..73fd95e34c37 --- /dev/null +++ b/automation/deployment/deployer.py @@ -0,0 +1,802 @@ +# Copyright 2023 MLRun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +import os.path +import platform +import subprocess +import sys +import typing + +import paramiko +import requests + + +class Constants: + helm_repo_name = "mlrun-ce" + helm_release_name = "mlrun-ce" + helm_chart_name = f"{helm_repo_name}/{helm_release_name}" + helm_repo_url = "https://mlrun.github.io/ce" + default_registry_secret_name = "registry-credentials" + mlrun_image_values = ["mlrun.api", "mlrun.ui", "jupyterNotebook"] + disableable_deployments = ["pipelines", "kube-prometheus-stack", "spark-operator"] + minikube_registry_port = 5000 + log_format = "> %(asctime)s [%(levelname)s] %(message)s" + + +class CommunityEditionDeployer: + """ + Deployer for MLRun Community Edition (CE) stack. + """ + + def __init__( + self, + namespace: str, + log_level: str = "info", + log_file: str = None, + remote: str = None, + remote_ssh_username: str = None, + remote_ssh_password: str = None, + ) -> None: + self._debug = log_level == "debug" + self._log_file_handler = None + logging.basicConfig(format="> %(asctime)s [%(levelname)s] %(message)s") + self._logger = logging.getLogger("automation") + self._logger.setLevel(log_level.upper()) + + if log_file: + self._log_file_handler = open(log_file, "a") + # using StreamHandler instead of FileHandler (which opens a file descriptor) so the same file descriptor + # can be used for command stdout as well as the logs. 
+ handler = logging.StreamHandler(self._log_file_handler) + handler.setFormatter(logging.Formatter(Constants.log_format)) + self._logger.addHandler(handler) + + self._namespace = namespace + self._remote = remote + self._remote_ssh_username = remote_ssh_username or os.environ.get( + "MLRUN_REMOTE_SSH_USERNAME" + ) + self._remote_ssh_password = remote_ssh_password or os.environ.get( + "MLRUN_REMOTE_SSH_PASSWORD" + ) + self._ssh_client = None + if self._remote: + self.connect_to_remote() + + def connect_to_remote(self): + self._log("info", "Connecting to remote machine", remote=self._remote) + self._ssh_client = paramiko.SSHClient() + self._ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy) + self._ssh_client.connect( + self._remote, + username=self._remote_ssh_username, + password=self._remote_ssh_password, + ) + + def deploy( + self, + registry_url: str, + registry_username: str = None, + registry_password: str = None, + registry_secret_name: str = None, + chart_version: str = None, + mlrun_version: str = None, + override_mlrun_api_image: str = None, + override_mlrun_ui_image: str = None, + override_jupyter_image: str = None, + disable_pipelines: bool = False, + disable_prometheus_stack: bool = False, + disable_spark_operator: bool = False, + skip_registry_validation: bool = False, + devel: bool = False, + minikube: bool = False, + sqlite: str = None, + upgrade: bool = False, + custom_values: typing.List[str] = None, + ) -> None: + """ + Deploy MLRun CE stack. 
+ :param registry_url: URL of the container registry to use for storing images + :param registry_username: Username for the container registry (not required if registry_secret_name is provided) + :param registry_password: Password for the container registry (not required if registry_secret_name is provided) + :param registry_secret_name: Name of the secret containing the credentials for the container registry + :param chart_version: Version of the helm chart to deploy (defaults to the latest stable version) + :param mlrun_version: Version of MLRun to deploy (defaults to the latest stable version) + :param override_mlrun_api_image: Override the default MLRun API image + :param override_mlrun_ui_image: Override the default MLRun UI image + :param override_jupyter_image: Override the default Jupyter image + :param disable_pipelines: Disable the deployment of the pipelines component + :param disable_prometheus_stack: Disable the deployment of the Prometheus stack component + :param disable_spark_operator: Disable the deployment of the Spark operator component + :param skip_registry_validation: Skip the validation of the registry URL + :param devel: Deploy the development version of the helm chart + :param minikube: Deploy the helm chart with minikube configuration + :param sqlite: Path to sqlite file to use as the mlrun database. 
If not supplied, will use MySQL deployment + :param upgrade: Upgrade an existing MLRun CE deployment + :param custom_values: List of custom values to pass to the helm chart + """ + self._prepare_prerequisites( + registry_url, + registry_username, + registry_password, + registry_secret_name, + skip_registry_validation, + minikube, + ) + helm_arguments = self._generate_helm_install_arguments( + registry_url, + registry_secret_name, + chart_version, + mlrun_version, + override_mlrun_api_image, + override_mlrun_ui_image, + override_jupyter_image, + disable_pipelines, + disable_prometheus_stack, + disable_spark_operator, + devel, + minikube, + sqlite, + upgrade, + custom_values, + ) + + self._log( + "info", + "Installing helm chart with arguments", + helm_arguments=helm_arguments, + ) + stdout, stderr, exit_status = self._run_command("helm", helm_arguments) + if exit_status != 0: + self._log( + "error", + "Failed to install helm chart", + stderr=stderr, + exit_status=exit_status, + ) + raise RuntimeError("Failed to install helm chart") + + self._teardown() + + def delete( + self, + skip_uninstall: bool = False, + sqlite: str = None, + cleanup_registry_secret: bool = True, + cleanup_volumes: bool = False, + cleanup_namespace: bool = False, + registry_secret_name: str = Constants.default_registry_secret_name, + ) -> None: + """ + Delete MLRun CE stack. + :param skip_uninstall: Skip the uninstallation of the helm chart + :param sqlite: Path to sqlite file to delete (if needed). 
+ :param cleanup_registry_secret: Delete the registry secret + :param cleanup_volumes: Delete the MLRun volumes + :param cleanup_namespace: Delete the entire namespace + :param registry_secret_name: Name of the registry secret to delete + """ + if cleanup_namespace: + self._log( + "warning", "Cleaning up entire namespace", namespace=self._namespace + ) + self._run_command("kubectl", ["delete", "namespace", self._namespace]) + return + + if not skip_uninstall: + self._log( + "info", "Cleaning up helm release", release=Constants.helm_release_name + ) + self._run_command( + "helm", + [ + "--namespace", + self._namespace, + "uninstall", + Constants.helm_release_name, + ], + ) + + if cleanup_volumes: + self._log("warning", "Cleaning up mlrun volumes") + self._run_command( + "kubectl", + [ + "--namespace", + self._namespace, + "delete", + "pvc", + "-l", + f"app.kubernetes.io/name={Constants.helm_release_name}", + ], + ) + + if cleanup_registry_secret: + self._log( + "warning", + "Cleaning up registry secret", + secret_name=registry_secret_name, + ) + self._run_command( + "kubectl", + [ + "--namespace", + self._namespace, + "delete", + "secret", + registry_secret_name, + ], + ) + + if sqlite: + os.remove(sqlite) + + self._teardown() + + def patch_minikube_images( + self, + mlrun_api_image: str = None, + mlrun_ui_image: str = None, + jupyter_image: str = None, + ) -> None: + """ + Patch the MLRun CE stack images in minikube. + :param mlrun_api_image: MLRun API image to use + :param mlrun_ui_image: MLRun UI image to use + :param jupyter_image: Jupyter image to use + """ + for image in [mlrun_api_image, mlrun_ui_image, jupyter_image]: + if image: + self._run_command("minikube", ["load", image]) + + self._teardown() + + def _teardown(self): + """ + Teardown the CLI tool. + Close the log file handler if exists. 
+ """ + if self._log_file_handler: + self._log_file_handler.close() + + def _prepare_prerequisites( + self, + registry_url: str, + registry_username: str = None, + registry_password: str = None, + registry_secret_name: str = None, + skip_registry_validation: bool = False, + minikube: bool = False, + ) -> None: + """ + Prepare the prerequisites for the MLRun CE stack deployment. + Creates namespace, adds helm repository, creates registry secret if needed. + :param registry_url: URL of the registry to use + :param registry_username: Username of the registry to use (not required if registry_secret_name is provided) + :param registry_password: Password of the registry to use (not required if registry_secret_name is provided) + :param registry_secret_name: Name of the registry secret to use + :param skip_registry_validation: Skip the validation of the registry URL + :param minikube: Whether to deploy on minikube + """ + self._log("info", "Preparing prerequisites") + skip_registry_validation = skip_registry_validation or ( + registry_url is None and minikube + ) + if not skip_registry_validation: + self._validate_registry_url(registry_url) + + self._log("info", "Creating namespace", namespace=self._namespace) + self._run_command("kubectl", ["create", "namespace", self._namespace]) + + self._log("debug", "Adding helm repo") + self._run_command( + "helm", ["repo", "add", Constants.helm_repo_name, Constants.helm_repo_url] + ) + + self._log("debug", "Updating helm repo") + self._run_command("helm", ["repo", "update"]) + + if registry_username and registry_password: + self._create_registry_credentials_secret( + registry_url, registry_username, registry_password + ) + elif registry_secret_name is not None: + self._log( + "warning", + "Using existing registry secret", + secret_name=registry_secret_name, + ) + else: + raise ValueError( + "Either registry credentials or registry secret name must be provided" + ) + + def _generate_helm_install_arguments( + self, + registry_url: 
str = None, + registry_secret_name: str = None, + chart_version: str = None, + mlrun_version: str = None, + override_mlrun_api_image: str = None, + override_mlrun_ui_image: str = None, + override_jupyter_image: str = None, + disable_pipelines: bool = False, + disable_prometheus_stack: bool = False, + disable_spark_operator: bool = False, + devel: bool = False, + minikube: bool = False, + sqlite: str = None, + upgrade: bool = False, + custom_values: typing.List[str] = None, + ) -> typing.List[str]: + """ + Generate the helm install arguments. + :param registry_url: URL of the registry to use + :param registry_secret_name: Name of the registry secret to use + :param chart_version: Version of the chart to use + :param mlrun_version: Version of MLRun to use + :param override_mlrun_api_image: Override MLRun API image to use + :param override_mlrun_ui_image: Override MLRun UI image to use + :param override_jupyter_image: Override Jupyter image to use + :param disable_pipelines: Disable pipelines + :param disable_prometheus_stack: Disable Prometheus stack + :param disable_spark_operator: Disable Spark operator + :param devel: Use development chart + :param minikube: Use minikube + :param sqlite: Path to sqlite file to use as the mlrun database. 
If not supplied, will use MySQL deployment + :param upgrade: Upgrade an existing MLRun CE deployment + :param custom_values: List of custom values to use + :return: List of helm install arguments + """ + helm_arguments = [ + "--namespace", + self._namespace, + "upgrade", + Constants.helm_release_name, + Constants.helm_chart_name, + "--install", + "--wait", + "--timeout", + "960s", + ] + + if self._debug: + helm_arguments.append("--debug") + + if upgrade: + helm_arguments.append("--reuse-values") + + for helm_key, helm_value in self._generate_helm_values( + registry_url, + registry_secret_name, + mlrun_version, + override_mlrun_api_image, + override_mlrun_ui_image, + override_jupyter_image, + disable_pipelines, + disable_prometheus_stack, + disable_spark_operator, + sqlite, + minikube, + ).items(): + helm_arguments.extend( + [ + "--set", + f"{helm_key}={helm_value}", + ] + ) + + for value in custom_values: + helm_arguments.extend( + [ + "--set", + value, + ] + ) + + if chart_version: + self._log( + "warning", + "Installing specific chart version", + chart_version=chart_version, + ) + helm_arguments.extend( + [ + "--version", + chart_version, + ] + ) + + if devel: + self._log("warning", "Installing development chart version") + helm_arguments.append("--devel") + + return helm_arguments + + def _generate_helm_values( + self, + registry_url: str, + registry_secret_name: str = None, + mlrun_version: str = None, + override_mlrun_api_image: str = None, + override_mlrun_ui_image: str = None, + override_jupyter_image: str = None, + disable_pipelines: bool = False, + disable_prometheus_stack: bool = False, + disable_spark_operator: bool = False, + sqlite: str = None, + minikube: bool = False, + ) -> typing.Dict[str, str]: + """ + Generate the helm values. 
+ :param registry_url: URL of the registry to use + :param registry_secret_name: Name of the registry secret to use + :param mlrun_version: Version of MLRun to use + :param override_mlrun_api_image: Override MLRun API image to use + :param override_mlrun_ui_image: Override MLRun UI image to use + :param override_jupyter_image: Override Jupyter image to use + :param disable_pipelines: Disable pipelines + :param disable_prometheus_stack: Disable Prometheus stack + :param disable_spark_operator: Disable Spark operator + :param sqlite: Path to sqlite file to use as the mlrun database. If not supplied, will use MySQL deployment + :param minikube: Use minikube + :return: Dictionary of helm values + """ + host_ip = self._get_minikube_ip() if minikube else self._get_host_ip() + if not registry_url and minikube: + registry_url = f"{host_ip}:{Constants.minikube_registry_port}" + + helm_values = { + "global.registry.url": registry_url, + "global.registry.secretName": f'"{registry_secret_name}"' # adding quotes in case of empty string + if registry_secret_name is not None + else Constants.default_registry_secret_name, + "global.externalHostAddress": host_ip, + "nuclio.dashboard.externalIPAddresses[0]": host_ip, + } + + if mlrun_version: + self._set_mlrun_version_in_helm_values(helm_values, mlrun_version) + + for value, overriden_image in zip( + Constants.mlrun_image_values, + [ + override_mlrun_api_image, + override_mlrun_ui_image, + override_jupyter_image, + ], + ): + if overriden_image: + self._override_image_in_helm_values(helm_values, value, overriden_image) + + for deployment, disabled in zip( + Constants.disableable_deployments, + [ + disable_pipelines, + disable_prometheus_stack, + disable_spark_operator, + ], + ): + if disabled: + self._disable_deployment_in_helm_values(helm_values, deployment) + + if sqlite: + dir_path = os.path.dirname(sqlite) + helm_values.update( + { + "mlrun.httpDB.dbType": "sqlite", + "mlrun.httpDB.dirPath": dir_path, + "mlrun.httpDB.dsn": 
f"sqlite:///{sqlite}?check_same_thread=false", + "mlrun.httpDB.oldDsn": '""', + } + ) + + # TODO: We need to fix the pipelines metadata grpc server to work on arm + if self._check_platform_architecture() == "arm": + self._log( + "warning", + "Kubeflow Pipelines is not supported on ARM architecture. Disabling KFP installation.", + ) + self._disable_deployment_in_helm_values(helm_values, "pipelines") + + self._log( + "debug", + "Generated helm values", + helm_values=helm_values, + ) + + return helm_values + + def _create_registry_credentials_secret( + self, + registry_url: str, + registry_username: str, + registry_password: str, + registry_secret_name: str = None, + ) -> None: + """ + Create a registry credentials secret. + :param registry_url: URL of the registry to use + :param registry_username: Username of the registry to use + :param registry_password: Password of the registry to use + :param registry_secret_name: Name of the registry secret to use + """ + registry_secret_name = ( + registry_secret_name + if registry_secret_name is not None + else Constants.default_registry_secret_name + ) + self._log( + "debug", + "Creating registry credentials secret", + secret_name=registry_secret_name, + ) + self._run_command( + "kubectl", + [ + "--namespace", + self._namespace, + "create", + "secret", + "docker-registry", + registry_secret_name, + f"--docker-server={registry_url}", + f"--docker-username={registry_username}", + f"--docker-password={registry_password}", + ], + ) + + def _check_platform_architecture(self) -> str: + """ + Check the platform architecture. If running on macOS, check if Rosetta is enabled. + Used for kubeflow pipelines which is not supported on ARM architecture (specifically the metadata grpc server). 
+ :return: Platform architecture + """ + if self._remote: + self._log( + "warning", + "Cannot check platform architecture on remote machine, assuming x86", + ) + return "x86" + + if platform.system() == "Darwin": + translated, _, exit_status = self._run_command( + "sysctl", + ["-n", "sysctl.proc_translated"], + live=False, + ) + is_rosetta = translated.strip() == b"1" and exit_status == 0 + + if is_rosetta: + return "arm" + + return platform.processor() + + def _get_host_ip(self) -> str: + """ + Get the host machine IP. + :return: Host IP + """ + if platform.system() == "Darwin": + return ( + self._run_command("ipconfig", ["getifaddr", "en0"], live=False)[0] + .strip() + .decode("utf-8") + ) + elif platform.system() == "Linux": + return ( + self._run_command("hostname", ["-I"], live=False)[0] + .split()[0] + .strip() + .decode("utf-8") + ) + else: + raise NotImplementedError( + f"Platform {platform.system()} is not supported for this action" + ) + + def _get_minikube_ip(self) -> str: + """ + Get the minikube IP. + :return: Minikube IP + """ + return ( + self._run_command("minikube", ["ip"], live=False)[0].strip().decode("utf-8") + ) + + def _validate_registry_url(self, registry_url): + """ + Validate the registry url. Send simple GET request to the registry url. + :param registry_url: URL of the registry to use + """ + if not registry_url: + raise ValueError("Registry url is required") + try: + response = requests.get(registry_url) + response.raise_for_status() + except Exception as exc: + self._log("error", "Failed to validate registry url", exc=exc) + raise exc + + def _set_mlrun_version_in_helm_values( + self, helm_values: typing.Dict[str, str], mlrun_version: str + ) -> None: + """ + Set the mlrun version in all the image tags in the helm values. 
+ :param helm_values: Helm values to update + :param mlrun_version: MLRun version to use + """ + self._log( + "warning", "Installing specific mlrun version", mlrun_version=mlrun_version + ) + for image in Constants.mlrun_image_values: + helm_values[f"{image}.image.tag"] = mlrun_version + + def _override_image_in_helm_values( + self, + helm_values: typing.Dict[str, str], + image_helm_value: str, + overriden_image: str, + ) -> None: + """ + Override an image in the helm values. + :param helm_values: Helm values to update + :param image_helm_value: Helm value of the image to override + :param overriden_image: Image with which to override + """ + ( + overriden_image_repo, + overriden_image_tag, + ) = overriden_image.split(":") + self._log( + "warning", + "Overriding image", + image=image_helm_value, + overriden_image=overriden_image, + ) + helm_values[f"{image_helm_value}.image.repository"] = overriden_image_repo + helm_values[f"{image_helm_value}.image.tag"] = overriden_image_tag + + def _disable_deployment_in_helm_values( + self, helm_values: typing.Dict[str, str], deployment: str + ) -> None: + """ + Disable a deployment in the helm values. 
+ :param helm_values: Helm values to update + :param deployment: Deployment to disable + """ + self._log("warning", "Disabling deployment", deployment=deployment) + helm_values[f"{deployment}.enabled"] = "false" + + def _run_command( + self, + command: str, + args: list = None, + workdir: str = None, + stdin: str = None, + live: bool = True, + ) -> (str, str, int): + if self._remote: + return run_command_remotely( + self._ssh_client, + command=command, + args=args, + workdir=workdir, + stdin=stdin, + live=live, + log_file_handler=self._log_file_handler, + ) + else: + return run_command( + command=command, + args=args, + workdir=workdir, + stdin=stdin, + live=live, + log_file_handler=self._log_file_handler, + ) + + def _log(self, level: str, message: str, **kwargs: typing.Any) -> None: + more = f": {kwargs}" if kwargs else "" + self._logger.log(logging.getLevelName(level.upper()), f"{message}{more}") + + +def run_command( + command: str, + args: list = None, + workdir: str = None, + stdin: str = None, + live: bool = True, + log_file_handler: typing.IO[str] = None, +) -> (str, str, int): + if workdir: + command = f"cd {workdir}; " + command + if args: + command += " " + " ".join(args) + + process = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + stdin=subprocess.PIPE, + shell=True, + ) + + if stdin: + process.stdin.write(bytes(stdin, "ascii")) + process.stdin.close() + + stdout = _handle_command_stdout(process.stdout, log_file_handler, live) + stderr = process.stderr.read() + exit_status = process.wait() + + return stdout, stderr, exit_status + + +def run_command_remotely( + ssh_client: paramiko.SSHClient, + command: str, + args: list = None, + workdir: str = None, + stdin: str = None, + live: bool = True, + log_file_handler: typing.IO[str] = None, +) -> (str, str, int): + if workdir: + command = f"cd {workdir}; " + command + if args: + command += " " + " ".join(args) + + stdin_stream, stdout_stream, stderr_stream = 
ssh_client.exec_command(command) + + if stdin: + stdin_stream.write(stdin) + stdin_stream.close() + + stdout = _handle_command_stdout(stdout_stream, log_file_handler, live, remote=True) + stderr = stderr_stream.read() + exit_status = stdout_stream.channel.recv_exit_status() + + return stdout, stderr, exit_status + + +def _handle_command_stdout( + stdout_stream: typing.Union[typing.IO[bytes], paramiko.channel.ChannelFile], + log_file_handler: typing.IO[str] = None, + live: bool = True, + remote: bool = False, +) -> str: + def _maybe_decode(text: typing.Union[str, bytes]) -> str: + if isinstance(text, bytes): + return text.decode(sys.stdout.encoding) + return text + + def _write_to_log_file(text: bytes): + if log_file_handler: + log_file_handler.write(_maybe_decode(text)) + + stdout = "" + if live: + for line in iter(stdout_stream.readline, b""): + # remote stream never ends, so we need to break when there's no more data + if remote and not line: + break + stdout += str(line) + sys.stdout.write(_maybe_decode(line)) + _write_to_log_file(line) + else: + stdout = stdout_stream.read() + _write_to_log_file(stdout) + + return stdout diff --git a/automation/package_test/test.py b/automation/package_test/test.py index 6faa4df52143..7b8c933b50e7 100644 --- a/automation/package_test/test.py +++ b/automation/package_test/test.py @@ -166,7 +166,8 @@ def _test_requirements_vulnerabilities(self, extra): raise_on_error=False, ) if code != 0: - vulnerabilities = json.loads(stdout) + full_report = json.loads(stdout) + vulnerabilities = full_report["vulnerabilities"] if vulnerabilities: self._logger.debug( "Found requirements vulnerabilities", @@ -213,11 +214,15 @@ def _test_requirements_vulnerabilities(self, extra): filtered_vulnerabilities = [] for vulnerability in vulnerabilities: - if vulnerability[0] in ignored_vulnerabilities: - ignored_vulnerability = ignored_vulnerabilities[vulnerability[0]] + if vulnerability["package_name"] in ignored_vulnerabilities: + ignored_vulnerability 
= ignored_vulnerabilities[ + vulnerability["package_name"] + ] ignore_vulnerability = False for ignored_pattern in ignored_vulnerability: - if re.search(ignored_pattern["pattern"], vulnerability[3]): + if re.search( + ignored_pattern["pattern"], vulnerability["advisory"] + ): self._logger.debug( "Ignoring vulnerability", vulnerability=vulnerability, @@ -232,7 +237,6 @@ def _test_requirements_vulnerabilities(self, extra): message = "Found vulnerable requirements that can not be ignored" logger.warning( message, - vulnerabilities=vulnerabilities, filtered_vulnerabilities=filtered_vulnerabilities, ignored_vulnerabilities=ignored_vulnerabilities, ) diff --git a/automation/requirements.txt b/automation/requirements.txt index 1974693576bc..509738800659 100644 --- a/automation/requirements.txt +++ b/automation/requirements.txt @@ -2,4 +2,4 @@ click~=8.0.0 paramiko~=2.12 semver~=2.13 requests~=2.22 -boto3~=1.9, <1.17.107 +boto3~=1.24.59 diff --git a/automation/scripts/github_workflow_free_space.sh b/automation/scripts/github_workflow_free_space.sh index d5783d86992c..afd9c3fa9e54 100755 --- a/automation/scripts/github_workflow_free_space.sh +++ b/automation/scripts/github_workflow_free_space.sh @@ -49,7 +49,9 @@ sudo rm --recursive --force \ "$AGENT_TOOLSDIRECTORY" # clean unneeded docker images -docker system prune --all --force +if [ -z "$KEEP_DOCKER_IMAGES" ]; then + docker system prune --all --force +fi # post cleanup print_free_space diff --git a/automation/system_test/cleanup.py b/automation/system_test/cleanup.py index eec04d0b0ddb..252a28f2d724 100644 --- a/automation/system_test/cleanup.py +++ b/automation/system_test/cleanup.py @@ -38,6 +38,22 @@ def main(): def docker_images(registry_url: str, registry_container_name: str, images: str): images = images.split(",") loop = asyncio.get_event_loop() + try: + click.echo("Removing images from datanode docker") + _remove_image_from_datanode_docker() + except Exception as exc: + click.echo( + f"Unable to remove images 
from datanode docker: {exc}, continuing anyway" + ) + + try: + click.echo("Removing dangling images from datanode docker") + _remove_dangling_images_from_datanode_docker() + except Exception as exc: + click.echo( + f"Unable to remove dangling images from datanode docker: {exc}, continuing anyway" + ) + try: _run_registry_garbage_collection(registry_container_name) except Exception as exc: @@ -81,6 +97,39 @@ async def _collect_image_tags( return tags +def _remove_image_from_datanode_docker(): + """Remove image from datanode docker""" + formatted_docker_images = subprocess.Popen( + ["docker", "images", "--format", "'{{.Repository }}:{{.Tag}}'"], + stdout=subprocess.PIPE, + ) + grep = subprocess.Popen( + ["grep", "mlrun"], + stdin=formatted_docker_images.stdout, + stdout=subprocess.PIPE, + ) + subprocess.run( + ["xargs", "--no-run-if-empty", "docker", "rmi", "-f"], + stdin=grep.stdout, + ) + formatted_docker_images.stdout.close() + grep.stdout.close() + + +def _remove_dangling_images_from_datanode_docker(): + """Remove dangling images from datanode docker""" + + dangling_docker_images = subprocess.Popen( + ["docker", "images", "--quiet", "--filter", "dangling=true"], + stdout=subprocess.PIPE, + ) + subprocess.run( + ["xargs", "--no-run-if-empty", "docker", "rmi", "-f"], + stdin=dangling_docker_images.stdout, + ) + dangling_docker_images.stdout.close() + + async def _delete_image_tags( registry: str, tags: typing.Dict[str, typing.List[str]] ) -> None: diff --git a/automation/system_test/dev_utilities.py b/automation/system_test/dev_utilities.py new file mode 100644 index 000000000000..681acc74f40e --- /dev/null +++ b/automation/system_test/dev_utilities.py @@ -0,0 +1,375 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import base64 +import subprocess + +import click + + +def run_click_command(command, **kwargs): + """ + Runs a click command with the specified arguments. + :param command: The click command to run. + :param kwargs: Keyword arguments to pass to the click command. + """ + # create a Click context object + ctx = click.Context(command) + # invoke the Click command with the desired arguments + ctx.invoke(command, **kwargs) + + +def get_installed_releases(namespace): + cmd = ["helm", "ls", "-n", namespace, "--deployed", "--short"] + output = subprocess.check_output(cmd).decode("utf-8") + release_names = output.strip().split("\n") + return release_names + + +def run_command(cmd): + """ + Runs a shell command and returns its output and exit status. 
+ """ + result = subprocess.run( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True + ) + return result.stdout.decode("utf-8"), result.returncode + + +def create_ingress_resource(domain_name, ipadd): + # Replace the placeholder string with the actual domain name + yaml_manifest = """ + apiVersion: networking.k8s.io/v1 + kind: Ingress + metadata: + annotations: + nginx.ingress.kubernetes.io/auth-cache-duration: 200 202 5m, 401 30s + nginx.ingress.kubernetes.io/auth-cache-key: $host$http_x_remote_user$http_cookie$http_authorization + nginx.ingress.kubernetes.io/proxy-body-size: "0" + nginx.ingress.kubernetes.io/whitelist-source-range: "{}" + nginx.ingress.kubernetes.io/service-upstream: "true" + nginx.ingress.kubernetes.io/ssl-redirect: "false" + labels: + release: redisinsight + name: redisinsight + namespace: devtools + spec: + ingressClassName: nginx + rules: + - host: {} + http: + paths: + - backend: + service: + name: redisinsight + port: + number: 80 + path: / + pathType: ImplementationSpecific + tls: + - hosts: + - {} + secretName: ingress-tls + """.format( + ipadd, domain_name, domain_name + ) + subprocess.run( + ["kubectl", "apply", "-f", "-"], input=yaml_manifest.encode(), check=True + ) + + +def get_ingress_controller_version(): + # Run the kubectl command and capture its output + kubectl_cmd = "kubectl" + namespace = "default-tenant" + grep_cmd = "grep shell.default-tenant" + awk_cmd1 = "awk '{print $3}'" + awk_cmd2 = "awk -F shell.default-tenant '{print $2}'" + cmd = f"{kubectl_cmd} get ingress -n {namespace} | {grep_cmd} | {awk_cmd1} | {awk_cmd2}" + result = subprocess.run( + cmd, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + return result.stdout.decode("utf-8").strip() + + +def get_svc_password(namespace, service_name, key): + cmd = f'kubectl get secret --namespace {namespace} {service_name} -o jsonpath="{{.data.{key}}}" | base64 --decode' + result = subprocess.run( + cmd, shell=True, check=True, 
stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + return result.stdout.decode("utf-8").strip() + + +def print_svc_info(svc_host, svc_port, svc_username, svc_password, nodeport): + print(f"Service is running at {svc_host}:{svc_port}") + print(f"Service username: {svc_username}") + print(f"Service password: {svc_password}") + print(f"service nodeport: {nodeport}") + + +def check_redis_installation(): + cmd = "helm ls -A | grep -w redis | awk '{print $1}' | wc -l" + result = subprocess.check_output(cmd, shell=True) + return result.decode("utf-8").strip() + + +def add_repos(): + repos = {"bitnami": "https://charts.bitnami.com/bitnami"} + for repo, url in repos.items(): + cmd = f"helm repo add {repo} {url}" + subprocess.run(cmd.split(), check=True) + + +def install_redisinsight(ipadd): + print(check_redis_installation) + if check_redis_installation() == "1": + subprocess.run(["rm", "-rf", "redisinsight-chart-0.1.0.tgz*"]) + chart_url = "https://docs.redis.com/latest/pkgs/redisinsight-chart-0.1.0.tgz" + chart_file = "redisinsight-chart-0.1.0.tgz" + subprocess.run(["wget", chart_url]) + # get redis password + redis_password = subprocess.check_output( + [ + "kubectl", + "get", + "secret", + "--namespace", + "devtools", + "redis", + "-o", + 'jsonpath="{.data.redis-password}"', + ], + encoding="utf-8", + ).strip('"\n') + redis_password = base64.b64decode(redis_password).decode("utf-8") + cmd = [ + "helm", + "install", + "redisinsight", + chart_file, + "--set", + "redis.url=redis-master", + "--set", + "master.service.nodePort=6379", + "--set", + f"auth.password={redis_password}", + "--set", + "fullnameOverride=redisinsight", + "--namespace", + "devtools", + ] + subprocess.run(cmd.split(), check=True) + # run patch cmd + fqdn = get_ingress_controller_version() + full_domain = "redisinsight" + fqdn + create_ingress_resource(full_domain, ipadd) + deployment_name = "redisinsight" + container_name = "redisinsight-chart" + env_name = "RITRUSTEDORIGINS" + full_domain = 
full_domain + pfull_domain = "https://" + full_domain + patch_command = ( + f'kubectl patch deployment -n devtools {deployment_name} -p \'{{"spec":{{"template":{{"spec":{{' + f'"containers":[{{"name":"{container_name}","env":[{{"name":"{env_name}","value":"' + f"{pfull_domain}\"}}]}}]}}}}}}}}'" + ) + subprocess.run(patch_command, shell=True) + clean_command = "rm -rf redisinsight-chart-0.1.0.tgz*" + subprocess.run(clean_command, shell=True) + else: + print("redis is not install, please install redis first") + exit() + + +@click.command() +@click.option("--redis", is_flag=True, help="Install Redis") +@click.option("--kafka", is_flag=True, help="Install Kafka") +@click.option("--mysql", is_flag=True, help="Install MySQL") +@click.option("--redisinsight", is_flag=True, help="Install Redis GUI") +@click.option("--ipadd", default="localhost", help="IP address as string") +def install(redis, kafka, mysql, redisinsight, ipadd): + # Check if the local-path storage class exists + output, exit_code = run_command( + "kubectl get storageclass local-path >/dev/null 2>&1" + ) + if exit_code != 0: + # Install the local-path provisioner + cmd = ( + "kubectl apply -f https://raw.githubusercontent.com/rancher/local-path-provisioner/v0.0.24/deploy/local" + "-path-storage.yaml" + ) + output, exit_code = run_command(cmd) + if exit_code == 0: + # Set the local-path storage class as the default + cmd = ( + 'kubectl patch storageclass local-path -p \'{"metadata": {"annotations":{' + '"storageclass.kubernetes.io/is-default-class":"true"}}}\'' + ) + output, exit_code = run_command(cmd) + if exit_code == 0: + print( + "local-path storage class has been installed and set as the default." 
+ ) + else: + print(f"Error setting local-path storage class as default: {output}") + else: + print(f"Error installing local-path storage class: {output}") + else: + print("local-path storage class already exists.") + services = { + "redis": { + "chart": "bitnami/redis", + "set_values": "--set master.service.nodePorts.redis=31001", + }, + "kafka": { + "chart": "bitnami/kafka", + "set_values": "--set service.nodePorts.client=31002", + }, + "mysql": { + "chart": "bitnami/mysql", + "set_values": "--set primary.service.nodePorts.mysql=31003", + }, + } + namespace = "devtools" + # Add Helm repos + add_repos() + # Check if the namespace exists, if not create it + check_namespace_cmd = f"kubectl get namespace {namespace}" + try: + subprocess.run(check_namespace_cmd.split(), check=True) + except subprocess.CalledProcessError: + create_namespace_cmd = f"kubectl create namespace {namespace}" + subprocess.run(create_namespace_cmd.split(), check=True) + for service, data in services.items(): + if locals().get(service): + chart = data["chart"] + set_values = data["set_values"] + cmd = f"helm install {service} {chart} {set_values} --namespace {namespace}" + print(cmd) + subprocess.run(cmd.split(), check=True) + if redisinsight: + install_redisinsight(ipadd) + + +@click.command() +@click.option("--redis", is_flag=True, help="Uninstall Redis") +@click.option("--kafka", is_flag=True, help="Uninstall Kafka") +@click.option("--mysql", is_flag=True, help="Uninstall MySQL") +@click.option("--redisinsight", is_flag=True, help="Uninstall Redis GUI") +def uninstall(redis, kafka, mysql, redisinsight): + services = ["redis", "kafka", "mysql", "redisinsight"] + namespace = "devtools" + try: + if redisinsight: + cmd = "kubectl delete ingress -n devtools redisinsight" + subprocess.run(cmd.split(), check=True) + except Exception as e: + print(e) + try: + for service in services: + if locals().get(service): + cmd = f"helm uninstall {service} --namespace {namespace}" + subprocess.run(cmd.split(), 
check=True) + except Exception as e: + print(e) + try: + print("namespace deleteted") + delns = "kubectl delete namespace devtools" + subprocess.run(cmd.split(), check=True) + except Exception as e: # !!! + print(e) + pass + # code to handle any exception + + +@click.command() +def list_services(): + namespace = "devtools" + # for service in services: + cmd = f"helm ls --namespace {namespace} " + subprocess.run(cmd.split(), check=True) + + +def list_services_h(): + namespace = "devtools" + return get_installed_releases(namespace) + + +@click.command() +@click.option("--redis", is_flag=True, help="Install Redis") +@click.option("--kafka", is_flag=True, help="Install Kafka") +@click.option("--mysql", is_flag=True, help="Install MySQL") +@click.option("--redisinsight", is_flag=True, help="Install Redis GUI") +def status(redis, kafka, mysql, redisinsight): + namespace = "devtools" + if redis: + svc_password = get_svc_password(namespace, "redis", "redis-password") + print_svc_info( + "redis-master-0.redis-headless.devtools.svc.cluster.local", + 6379, + "default", + svc_password, + "-------", + ) + if kafka: + print_svc_info("kafka", 9092, "-------", "-------", "-------") + if mysql: + svc_password = get_svc_password(namespace, "mysql", "mysql-root-password") + print_svc_info("mysql", 3306, "root", svc_password, "-------") + if redisinsight: + fqdn = get_ingress_controller_version() + full_domain = "https://redisinsight" + fqdn + print_svc_info("", " " + full_domain, "-------", "-------", "-------") + + +def status_h(svc): + namespace = "devtools" + if svc == "redis": + svc_password = get_svc_password(namespace, "redis", "redis-password") + dict = { + "app_url": "redis-master-0.redis-headless.devtools.svc.cluster.local:6379", + "username": "default", + "password": svc_password, + } + return dict + if svc == "kafka": + dict = {"app_url": "kafka-0.kafka-headless.devtools.svc.cluster.local:9092"} + return dict + if svc == "mysql": + svc_password = 
get_svc_password(namespace, "mysql", "mysql-root-password") + dict = { + "app_url": "mysql-0.mysql.devtools.svc.cluster.local:3306", + "username": "root", + "password": svc_password, + } + return dict + if svc == "redisinsight": + fqdn = get_ingress_controller_version() + full_domain = "https://redisinsight" + fqdn + dict = {"app_url": full_domain} + return dict + + +@click.group() +def cli(): + pass + + +cli.add_command(install) +cli.add_command(uninstall) +cli.add_command(list_services) +cli.add_command(status) + +if __name__ == "__main__": + cli() diff --git a/automation/system_test/prepare.py b/automation/system_test/prepare.py index c2db2d62d42e..7d2e164bd291 100644 --- a/automation/system_test/prepare.py +++ b/automation/system_test/prepare.py @@ -15,11 +15,14 @@ import datetime import logging +import os import pathlib +import shutil import subprocess import sys import tempfile import time +import typing import urllib.parse import boto3 @@ -27,8 +30,10 @@ import paramiko import yaml +# TODO: remove and use local logger import mlrun.utils +project_dir = pathlib.Path(__file__).resolve().parent.parent.parent logger = mlrun.utils.create_logger(level="debug", name="automation") logging.getLogger("paramiko").setLevel(logging.DEBUG) @@ -41,7 +46,10 @@ class Constants: igz_version_file = homedir / "igz" / "version.txt" mlrun_code_path = workdir / "mlrun" provctl_path = workdir / "provctl" - system_tests_env_yaml = pathlib.Path("tests") / "system" / "env.yml" + system_tests_env_yaml = ( + project_dir / pathlib.Path("tests") / "system" / "env.yml" + ) + namespace = "default-tenant" git_url = "https://github.com/mlrun/mlrun.git" @@ -67,9 +75,12 @@ def __init__( access_key: str = None, iguazio_version: str = None, spark_service: str = None, - password: str = None, slack_webhook_url: str = None, + mysql_user: str = None, + mysql_password: str = None, + purge_db: bool = False, debug: bool = False, + branch: str = None, ): self._logger = logger self._debug = debug @@ -91,6 
+102,9 @@ def __init__( self._provctl_download_s3_access_key = provctl_download_s3_access_key self._provctl_download_s3_key_id = provctl_download_s3_key_id self._iguazio_version = iguazio_version + self._mysql_user = mysql_user + self._mysql_password = mysql_password + self._purge_db = purge_db self._env_config = { "MLRUN_DBPATH": mlrun_dbpath, @@ -100,9 +114,11 @@ def __init__( "V3IO_ACCESS_KEY": access_key, "MLRUN_SYSTEM_TESTS_DEFAULT_SPARK_SERVICE": spark_service, "MLRUN_SYSTEM_TESTS_SLACK_WEBHOOK_URL": slack_webhook_url, + "MLRUN_SYSTEM_TESTS_BRANCH": branch, + # Setting to MLRUN_SYSTEM_TESTS_GIT_TOKEN instead of GIT_TOKEN, to not affect tests which doesn't need it + # (e.g. tests which use public repos, therefor doesn't need that access token) + "MLRUN_SYSTEM_TESTS_GIT_TOKEN": github_access_token, } - if password: - self._env_config["V3IO_PASSWORD"] = password def prepare_local_env(self): self._prepare_env_local() @@ -123,6 +139,13 @@ def connect_to_remote(self): def run(self): self.connect_to_remote() + try: + logger.debug("installing dev utilities") + self._install_dev_utilities() + logger.debug("installing dev utilities - done") + except Exception as exp: + self._logger.error("error on install dev utilities", exception=str(exp)) + # for sanity clean up before starting the run self.clean_up_remote_workdir() @@ -134,6 +157,12 @@ def run(self): self._override_mlrun_api_env() + # purge of the database needs to be executed before patching mlrun so that the mlrun migrations + # that run as part of the patch would succeed even if we move from a newer version to an older one + # e.g from development branch which is (1.4.0) and has a newer alembic revision than 1.3.x which is (1.3.1) + if self._purge_db: + self._purge_mlrun_db() + self._patch_mlrun() def clean_up_remote_workdir(self): @@ -155,7 +184,7 @@ def _run_command( local: bool = False, detach: bool = False, verbose: bool = True, - ) -> str: + ) -> (bytes, bytes): workdir = workdir or 
str(self.Constants.workdir) stdout, stderr, exit_status = "", "", 0 @@ -170,10 +199,10 @@ def _run_command( workdir=workdir, ) if self._debug: - return "" + return b"", b"" try: if local: - stdout, stderr, exit_status = self._run_command_locally( + stdout, stderr, exit_status = run_command( command, args, workdir, stdin, live ) else: @@ -189,15 +218,19 @@ def _run_command( if exit_status != 0 and not suppress_errors: raise RuntimeError(f"Command failed with exit status: {exit_status}") except (paramiko.SSHException, RuntimeError) as exc: + err_log_kwargs = { + "error": str(exc), + "stdout": stdout, + "stderr": stderr, + "exit_status": exit_status, + } if verbose: - self._logger.error( - f"Failed running command {log_command_location}", - command=command, - error=exc, - stdout=stdout, - stderr=stderr, - exit_status=exit_status, - ) + err_log_kwargs["command"] = command + + self._logger.error( + f"Failed running command {log_command_location}", + **err_log_kwargs, + ) raise else: if verbose: @@ -208,7 +241,7 @@ def _run_command( stderr=stderr, exit_status=exit_status, ) - return stdout + return stdout, stderr def _run_command_remotely( self, @@ -256,45 +289,6 @@ def _run_command_remotely( return stdout, stderr, exit_status - @staticmethod - def _run_command_locally( - command: str, - args: list = None, - workdir: str = None, - stdin: str = None, - live: bool = True, - ) -> (str, str, int): - stdout, stderr, exit_status = "", "", 0 - if workdir: - command = f"cd {workdir}; " + command - if args: - command += " " + " ".join(args) - - process = subprocess.Popen( - command, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - stdin=subprocess.PIPE, - shell=True, - ) - - if stdin: - process.stdin.write(bytes(stdin, "ascii")) - process.stdin.close() - - if live: - for line in iter(process.stdout.readline, b""): - stdout += str(line) - sys.stdout.write(line.decode(sys.stdout.encoding)) - else: - stdout = process.stdout.read() - - stderr = process.stderr.read() - - 
exit_status = process.wait() - - return stdout, stderr, exit_status - def _prepare_env_remote(self): self._run_command( "mkdir", @@ -303,16 +297,20 @@ def _prepare_env_remote(self): ) def _prepare_env_local(self): - contents = yaml.safe_dump(self._env_config) filepath = str(self.Constants.system_tests_env_yaml) + backup_filepath = str(self.Constants.system_tests_env_yaml) + ".bak" self._logger.debug("Populating system tests env.yml", filepath=filepath) - self._run_command( - "cat > ", - workdir=".", - args=[filepath], - stdin=contents, - local=True, - ) + + # if filepath exists, backup the file first (to avoid overriding it) + if os.path.isfile(filepath) and not os.path.isfile(backup_filepath): + self._logger.debug( + "Backing up existing env.yml", destination=backup_filepath + ) + shutil.copy(filepath, backup_filepath) + + serialized_env_config = self._serialize_env_config() + with open(filepath, "w") as f: + f.write(serialized_env_config) def _override_mlrun_api_env(self): version_specifier = ( @@ -334,7 +332,10 @@ def _override_mlrun_api_env(self): "apiVersion": "v1", "data": data, "kind": "ConfigMap", - "metadata": {"name": "mlrun-override-env", "namespace": "default-tenant"}, + "metadata": { + "name": "mlrun-override-env", + "namespace": self.Constants.namespace, + }, } manifest_file_name = "override_mlrun_registry.yml" self._run_command( @@ -348,11 +349,33 @@ def _override_mlrun_api_env(self): args=["apply", "-f", manifest_file_name], ) + def _install_dev_utilities(self): + list_uninstall = [ + "dev_utilities.py", + "uninstall", + "--redis", + "--mysql", + "--redisinsight", + "--kafka", + ] + list_install = [ + "dev_utilities.py", + "install", + "--redis", + "--mysql", + "--redisinsight", + "--kafka", + "--ipadd", + os.environ.get("IP_ADDR_PREFIX", "localhost"), + ] + self._run_command("rm", args=["-rf", "/home/iguazio/dev_utilities"]) + self._run_command("python3", args=list_uninstall, workdir="/home/iguazio/") + self._run_command("python3", 
args=list_install, workdir="/home/iguazio/") + def _download_provctl(self): # extract bucket name, object name from s3 file path # https://.s3.amazonaws.com/ # s3:/// - parsed_url = urllib.parse.urlparse(self._provctl_download_url) if self._provctl_download_url.startswith("s3://"): object_name = parsed_url.path.lstrip("/") @@ -360,7 +383,6 @@ def _download_provctl(self): else: object_name = parsed_url.path.lstrip("/") bucket_name = parsed_url.netloc.split(".")[0] - # download provctl from s3 with tempfile.NamedTemporaryFile() as local_provctl_path: self._logger.debug( @@ -375,7 +397,6 @@ def _download_provctl(self): aws_access_key_id=self._provctl_download_s3_key_id, ) s3_client.download_file(bucket_name, object_name, local_provctl_path.name) - # upload provctl to data node self._logger.debug( "Uploading provctl to datanode", @@ -385,7 +406,6 @@ def _download_provctl(self): sftp_client = self._ssh_client.open_sftp() sftp_client.put(local_provctl_path.name, str(self.Constants.provctl_path)) sftp_client.close() - # make provctl executable self._run_command("chmod", args=["+x", str(self.Constants.provctl_path)]) @@ -473,6 +493,9 @@ def _patch_mlrun(self): self._data_cluster_ssh_password, "patch", "appservice", + # we force because by default provctl doesn't allow downgrading between version but due to system tests + # running on multiple branches this might occur. 
+ "--force", "mlrun", mlrun_archive, ], @@ -488,21 +511,118 @@ def _patch_mlrun(self): self._run_command(f"cat {provctl_patch_mlrun_log}") def _resolve_iguazio_version(self): - # iguazio version is optional, if not provided, we will try to resolve it from the data node if not self._iguazio_version: self._logger.info("Resolving iguazio version") - self._iguazio_version = self._run_command( + self._iguazio_version, _ = self._run_command( f"cat {self.Constants.igz_version_file}", verbose=False, live=False, - ).strip() - if isinstance(self._iguazio_version, bytes): - self._iguazio_version = self._iguazio_version.decode("utf-8") + ) + self._iguazio_version = self._iguazio_version.strip().decode() self._logger.info( "Resolved iguazio version", iguazio_version=self._iguazio_version ) + def _purge_mlrun_db(self): + """ + Purge mlrun db - exec into mlrun-db pod, delete the database and scale down mlrun pods + """ + self._delete_mlrun_db() + self._scale_down_mlrun_deployments() + + def _delete_mlrun_db(self): + self._logger.info("Deleting mlrun db") + + mlrun_db_pod_name_cmd = self._get_pod_name_command( + labels={ + "app.kubernetes.io/component": "db", + "app.kubernetes.io/instance": "mlrun", + }, + ) + if not mlrun_db_pod_name_cmd: + self._logger.info("No mlrun db pod found") + return + + self._logger.info( + "Deleting mlrun db pod", mlrun_db_pod_name_cmd=mlrun_db_pod_name_cmd + ) + + password = "" + if self._mysql_password: + password = f"-p {self._mysql_password} " + + drop_db_cmd = f"mysql --socket=/run/mysqld/mysql.sock -u {self._mysql_user} {password}-e 'DROP DATABASE mlrun;'" + self._run_kubectl_command( + args=[ + "exec", + "-n", + self.Constants.namespace, + "-it", + mlrun_db_pod_name_cmd, + "--", + drop_db_cmd, + ], + verbose=False, + ) + + def _get_pod_name_command(self, labels): + labels_selector = ",".join([f"{k}={v}" for k, v in labels.items()]) + pod_name, stderr = self._run_kubectl_command( + args=[ + "get", + "pods", + "--namespace", + 
self.Constants.namespace, + "--selector", + labels_selector, + "|", + "tail", + "-n", + "1", + "|", + "awk", + "'{print $1}'", + ], + ) + if b"No resources found" in stderr or not pod_name: + return None + return pod_name.strip() + + def _scale_down_mlrun_deployments(self): + # scaling down to avoid automatically deployments restarts and failures + self._logger.info("scaling down mlrun deployments") + self._run_kubectl_command( + args=[ + "scale", + "deployment", + "-n", + self.Constants.namespace, + "mlrun-api-chief", + "mlrun-api-worker", + "mlrun-db", + "--replicas=0", + ] + ) + + def _run_kubectl_command(self, args, verbose=True): + return self._run_command( + command="kubectl", + args=args, + verbose=verbose, + ) + + def _serialize_env_config(self, allow_none_values: bool = False): + env_config = self._env_config.copy() + + # we sanitize None values from config to avoid "null" values in yaml + if not allow_none_values: + for key in list(env_config): + if env_config[key] is None: + del env_config[key] + + return yaml.safe_dump(env_config) + @click.group() def main(): @@ -510,7 +630,7 @@ def main(): @main.command(context_settings=dict(ignore_unknown_options=True)) -@click.argument("mlrun-version", type=str, required=True) +@click.option("--mlrun-version") @click.option( "--override-image-registry", "-oireg", @@ -535,23 +655,25 @@ def main(): default=None, help="The commit (in mlrun/mlrun) of the tested mlrun version.", ) -@click.argument("data-cluster-ip", type=str, required=True) -@click.argument("data-cluster-ssh-username", type=str, required=True) -@click.argument("data-cluster-ssh-password", type=str, required=True) -@click.argument("app-cluster-ssh-password", type=str, required=True) -@click.argument("github-access-token", type=str, required=True) -@click.argument("provctl-download-url", type=str, required=True) -@click.argument("provctl-download-s3-access-key", type=str, required=True) -@click.argument("provctl-download-s3-key-id", type=str, required=True) 
-@click.argument("mlrun-dbpath", type=str, required=True) -@click.argument("webapi-direct-url", type=str, required=True) -@click.argument("framesd-url", type=str, required=True) -@click.argument("username", type=str, required=True) -@click.argument("access-key", type=str, required=True) -@click.argument("iguazio-version", type=str, default=None, required=True) -@click.argument("spark-service", type=str, required=True) -@click.argument("password", type=str, default=None, required=False) -@click.argument("slack-webhook-url", type=str, default=None, required=False) +@click.option("--data-cluster-ip", required=True) +@click.option("--data-cluster-ssh-username", required=True) +@click.option("--data-cluster-ssh-password", required=True) +@click.option("--app-cluster-ssh-password", required=True) +@click.option("--github-access-token", required=True) +@click.option("--provctl-download-url", required=True) +@click.option("--provctl-download-s3-access-key", required=True) +@click.option("--provctl-download-s3-key-id", required=True) +@click.option("--mlrun-dbpath", required=True) +@click.option("--webapi-direct-url", required=True) +@click.option("--framesd-url", required=True) +@click.option("--username", required=True) +@click.option("--access-key", required=True) +@click.option("--iguazio-version", default=None) +@click.option("--spark-service", required=True) +@click.option("--slack-webhook-url") +@click.option("--mysql-user") +@click.option("--mysql-password") +@click.option("--purge-db", "-pdb", is_flag=True, help="Purge mlrun db") @click.option( "--debug", "-d", @@ -579,8 +701,10 @@ def run( access_key: str, iguazio_version: str, spark_service: str, - password: str, slack_webhook_url: str, + mysql_user: str, + mysql_password: str, + purge_db: bool, debug: bool, ): system_test_preparer = SystemTestPreparer( @@ -604,8 +728,10 @@ def run( access_key, iguazio_version, spark_service, - password, slack_webhook_url, + mysql_user, + mysql_password, + purge_db, debug, ) try: 
@@ -616,20 +742,26 @@ def run( @main.command(context_settings=dict(ignore_unknown_options=True)) -@click.argument("mlrun-dbpath", type=str, required=True) -@click.argument("webapi-direct-url", type=str, required=True) -@click.argument("framesd-url", type=str, required=True) -@click.argument("username", type=str, required=True) -@click.argument("access-key", type=str, required=True) -@click.argument("spark-service", type=str, required=True) -@click.argument("password", type=str, default=None, required=False) -@click.argument("slack-webhook-url", type=str, default=None, required=False) +@click.option("--mlrun-dbpath", help="The mlrun api address", required=True) +@click.option("--webapi-direct-url", help="Iguazio webapi direct url") +@click.option("--framesd-url", help="Iguazio framesd url") +@click.option("--username", help="Iguazio running username") +@click.option("--access-key", help="Iguazio running user access key") +@click.option("--spark-service", help="Iguazio kubernetes spark service name") +@click.option( + "--slack-webhook-url", help="Slack webhook url to send tests notifications to" +) @click.option( "--debug", "-d", is_flag=True, help="Don't run the ci only show the commands that will be run", ) +@click.option("--branch", help="The mlrun branch to run the tests against") +@click.option( + "--github-access-token", + help="Github access token to use for fetching private functions", +) def env( mlrun_dbpath: str, webapi_direct_url: str, @@ -637,9 +769,10 @@ def env( username: str, access_key: str, spark_service: str, - password: str, slack_webhook_url: str, debug: bool, + branch: str, + github_access_token: str, ): system_test_preparer = SystemTestPreparer( mlrun_dbpath=mlrun_dbpath, @@ -648,9 +781,10 @@ def env( username=username, access_key=access_key, spark_service=spark_service, - password=password, debug=debug, slack_webhook_url=slack_webhook_url, + branch=branch, + github_access_token=github_access_token, ) try: 
system_test_preparer.prepare_local_env() @@ -659,5 +793,59 @@ def env( raise +def run_command( + command: str, + args: list = None, + workdir: str = None, + stdin: str = None, + live: bool = True, + log_file_handler: typing.IO[str] = None, +) -> (str, str, int): + if workdir: + command = f"cd {workdir}; " + command + if args: + command += " " + " ".join(args) + + process = subprocess.Popen( + command, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + stdin=subprocess.PIPE, + shell=True, + ) + + if stdin: + process.stdin.write(bytes(stdin, "ascii")) + process.stdin.close() + + stdout = _handle_command_stdout(process.stdout, log_file_handler, live) + stderr = process.stderr.read() + exit_status = process.wait() + + return stdout, stderr, exit_status + + +def _handle_command_stdout( + stdout_stream: typing.IO[bytes], + log_file_handler: typing.IO[str] = None, + live: bool = True, +) -> str: + def _write_to_log_file(text: bytes): + if log_file_handler: + log_file_handler.write(text.decode(sys.stdout.encoding)) + + stdout = "" + if live: + for line in iter(stdout_stream.readline, b""): + stdout += str(line) + sys.stdout.write(line.decode(sys.stdout.encoding)) + _write_to_log_file(line) + else: + stdout = stdout_stream.read() + _write_to_log_file(stdout) + + return stdout + + if __name__ == "__main__": main() diff --git a/conda-arm64-requirements.txt b/conda-arm64-requirements.txt new file mode 100644 index 000000000000..f984aba006d9 --- /dev/null +++ b/conda-arm64-requirements.txt @@ -0,0 +1,3 @@ +# with moving to arm64 for the new M1/M2 macs some packages are not yet compatible via pip and require +# conda which supports different architecture environments on the same machine +lightgbm>=3.0 diff --git a/dependencies.py b/dependencies.py index 1d4239022c7b..ab7a8278234a 100644 --- a/dependencies.py +++ b/dependencies.py @@ -31,21 +31,20 @@ def extra_requirements() -> typing.Dict[str, typing.List[str]]: # - We have a copy of these in extras-requirements.txt. 
If you modify these, make sure to change it # there as well extras_require = { - # from 1.17.107 boto3 requires botocore>=1.20.107,<1.21.0 which - # conflicts with s3fs 2021.8.1 that has aiobotocore~=1.4.0 - # which so far (1.4.1) has botocore>=1.20.106,<1.20.107 - # boto3 1.17.106 has botocore>=1.20.106,<1.21.0, so we must add botocore explicitly + # last version that supports python 3.7: fsspec: 2023.1.0, aiobotocore: 2.4.2, adlfs: 2022.2.0 + # selecting ~=2023.1.0 for fsspec and its implementations s3fs and gcsfs (adlfs pinned per comment above) + # s3fs 2023.1.0 requires aiobotocore 2.4.2 which requires botocore 1.27.59 + # requesting boto3 1.24.59, the only version that requires botocore 1.27.59 "s3": [ - "boto3~=1.9, <1.17.107", - "botocore>=1.20.106,<1.20.107", - "aiobotocore~=1.4.0", - "s3fs~=2021.8.1", + "boto3~=1.24.59", + "aiobotocore~=2.4.2", + "s3fs~=2023.1.0", ], "azure-blob-storage": [ "msrest~=0.6.21", "azure-core~=1.24", "azure-storage-blob~=12.13", - "adlfs~=2021.8.1", + "adlfs~=2022.2.0", "pyopenssl>=23", ], "azure-key-vault": [ @@ -69,9 +68,13 @@ def extra_requirements() -> typing.Dict[str, typing.List[str]]: "google-cloud-bigquery[pandas, bqstorage]~=3.2", "google-cloud~=0.34", ], - "google-cloud-storage": ["gcsfs~=2021.8.1"], + "google-cloud-storage": ["gcsfs~=2023.1.0"], "google-cloud-bigquery": ["google-cloud-bigquery[pandas, bqstorage]~=3.2"], - "kafka": ["kafka-python~=2.0"], + "kafka": [ + "kafka-python~=2.0", + # because confluent kafka supports avro format by default + "avro~=1.11", + ], "redis": ["redis~=4.3"], } diff --git a/dev-requirements.txt b/dev-requirements.txt index 46645ab0afaa..5f60e63bc9a0 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,6 +1,6 @@ -pytest~=6.0 +pytest~=7.0 twine~=3.1 -black~=22.0 +black[jupyter]~=22.0 flake8~=5.0 pytest-asyncio~=0.15.0 pytest-alembic~=0.9.1 @@ -17,6 +17,7 @@ avro~=1.11 # needed for mlutils tests scikit-learn~=1.0 # needed for frameworks tests -lightgbm~=3.0 +lightgbm~=3.0; 
platform_machine != 'arm64' xgboost~=1.1 sqlalchemy_utils~=0.39.0 +import-linter~=1.8 diff --git a/dockerfiles/base/requirements.txt b/dockerfiles/base/requirements.txt index 34d6615c2931..f5ba8cd5bda6 100644 --- a/dockerfiles/base/requirements.txt +++ b/dockerfiles/base/requirements.txt @@ -10,7 +10,7 @@ lifelines~=0.25.0 # so, it cannot be logged as artifact (raised UnicodeEncode error - ML-3255) plotly~=5.4, <5.12.0 pyod~=0.8.1 -pytest~=6.0 +pytest~=7.0 scikit-multiflow~=0.5.3 scikit-optimize~=0.8.1 scikit-image~=0.16.0 diff --git a/dockerfiles/jupyter/requirements.txt b/dockerfiles/jupyter/requirements.txt index 32934bf13df7..5747a632c67c 100644 --- a/dockerfiles/jupyter/requirements.txt +++ b/dockerfiles/jupyter/requirements.txt @@ -6,7 +6,7 @@ scikit-plot~=0.3.7 xgboost~=1.1 graphviz~=0.20.0 python-dotenv~=0.17.0 -nuclio-jupyter[jupyter-server]~=0.9.9 +nuclio-jupyter[jupyter-server]~=0.9.10 nbclassic>=0.2.8 # added to tackle security vulnerabilities notebook~=6.4 diff --git a/dockerfiles/mlrun-api/requirements.txt b/dockerfiles/mlrun-api/requirements.txt index f5f07ec6564c..3901bb83e1d9 100644 --- a/dockerfiles/mlrun-api/requirements.txt +++ b/dockerfiles/mlrun-api/requirements.txt @@ -3,3 +3,4 @@ dask-kubernetes~=0.11.0 apscheduler~=3.6 sqlite3-to-mysql~=1.4 objgraph~=3.5 +igz-mgmt~=0.0.8 diff --git a/dockerfiles/test-system/requirements.txt b/dockerfiles/test-system/requirements.txt index 4cf5c9bf2096..cd7bffd531a2 100644 --- a/dockerfiles/test-system/requirements.txt +++ b/dockerfiles/test-system/requirements.txt @@ -1,4 +1,4 @@ -pytest~=6.0 +pytest~=7.0 matplotlib~=3.5 graphviz~=0.20.0 scikit-learn~=1.0 diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index 5a6c03878999..c7f9f8d5724c 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -1,6 +1,6 @@ # Documenting mlrun -This document describe how to write the external documentation for `mlrun`, the +This document describes how to write the external documentation for `mlrun`, the one you 
can view at https://mlrun.readthedocs.io ## Technology diff --git a/docs/api/mlrun.db.rst b/docs/api/mlrun.db.rst index 32524a250d3a..2247fc752a4e 100644 --- a/docs/api/mlrun.db.rst +++ b/docs/api/mlrun.db.rst @@ -6,7 +6,7 @@ mlrun.db :show-inheritance: :undoc-members: -.. autoclass:: mlrun.api.schemas.secret::SecretProviderName +.. autoclass:: mlrun.common.schemas.secret::SecretProviderName :members: :show-inheritance: :undoc-members: diff --git a/docs/api/mlrun.feature_store.rst b/docs/api/mlrun.feature_store.rst index 8a7776e60d5f..a1328bc40a4d 100644 --- a/docs/api/mlrun.feature_store.rst +++ b/docs/api/mlrun.feature_store.rst @@ -9,6 +9,13 @@ mlrun.feature_store .. autoclass:: mlrun.feature_store.feature_set.FeatureSetSpec .. autoclass:: mlrun.feature_store.feature_set.FeatureSetStatus +.. autoclass:: mlrun.feature_store.steps.MLRunStep + :members: + :private-members: _do_pandas, _do_storey, _do_spark + .. automodule:: mlrun.feature_store.steps + :exclude-members: MLRunStep :members: :special-members: __init__ + + diff --git a/docs/api/mlrun.serving.rst b/docs/api/mlrun.serving.rst index aa36f96465de..6efd3dcdb17e 100644 --- a/docs/api/mlrun.serving.rst +++ b/docs/api/mlrun.serving.rst @@ -9,3 +9,6 @@ mlrun.serving .. automodule:: mlrun.serving.remote :members: :special-members: __init__ + +.. 
autoclass:: mlrun.serving.utils.StepToDict + :members: diff --git a/docs/architecture.md b/docs/architecture.md index 9840be4931fa..65d6677a872e 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -1,4 +1,4 @@ -(architecture)= +(mlrun-architecture)= # MLRun architecture diff --git a/docs/change-log/index.md b/docs/change-log/index.md index 7afb105a0a13..6d3103c0d842 100644 --- a/docs/change-log/index.md +++ b/docs/change-log/index.md @@ -1,6 +1,11 @@ (change-log)= # Change log +- [v1.3.3](#v1-3-3) +- [v1.3.2](#v1-3-2) +- [v1.3.1](#v1-3-1) - [v1.3.0](#v1-3-0) +- [v1.2.3](#v1-2-3) +- [v1.2.2](#v1-2-2) - [v1.2.1](#v1-2-1) - [v1.2.0](#v1-2-0) - [v1.1.3](#1-1-3) @@ -14,6 +19,39 @@ - [Limitations](#limitations) - [Deprecations](#deprecations) +## v1.3.3 + +### Closed issues + +| ID | Description | +| --- | ----------------------------------------------------------------- | +| ML-3940 | MLRun does not initiate log collection for runs in aborted state. [View in Git](https://github.com/mlrun/mlrun/pull/3698). | + +## v1.3.2 + +### Closed issues + +| ID | Description | +| --- | ----------------------------------------------------------------- | +| ML-3896 | Fixed: MLRun API failed to get pod logs. [View in Git](https://github.com/mlrun/mlrun/pull/3649). | +| ML-3865 | kubectl now returns logs as expected. [View in Git](https://github.com/mlrun/mlrun/pull/3660). | +| ML-3917 | Reduced number of logs. [View in Git](https://github.com/mlrun/mlrun/pull/3674). | +| ML-3934 | Logs are no longer collected for run pods in an unknown state [View in Git](https://github.com/mlrun/mlrun/pull/3690). | + +## v1.3.1 + +### Closed issues + +| ID | Description | +| --- | ----------------------------------------------------------------- | +| ML-3764 | Fixed the scikit-learn to 1.2 in the tutorial 02-model-training. (Previously pointed to 1.0.) [View in Git](https://github.com/mlrun/mlrun/pull/3437). | +| ML-3794 | Fixed a Mask detection demo notebook (3-automatic-pipeline.ipynb). 
[View in Git](https://github.com/mlrun/demos/releases/tag/v1.3.1-rc6). | +| ML-3819 | Reduce overly-verbose logs on the backend side. [View in Git](https://github.com/mlrun/mlrun/pull/3531). [View in Git](https://github.com/mlrun/mlrun/pull/3553). | +| ML-3823 | Optimized `/projects` endpoint to work faster. [View in Git](https://github.com/mlrun/mlrun/pull/3560). | + +### Documentation +New sections describing [Git best practices](../projects/git-best-practices.html) and an example [Nuclio function](../concepts/nuclio-real-time-functions.html#example-of-nuclio-function). + ## v1.3.0 ### Client/server matrix, prerequisites, and installing @@ -25,23 +63,24 @@ python 3.7 have the suffix: `-py37`. The correct version is automatically chosen MLRun is pre-installed in CE Jupyter. -To install on a **Python 3.9** client, run:
+To install on a **Python 3.9** environment, run:
``` ./align_mlrun.sh ``` -To install on a **Python 3.7** client, run: +To install on a **Python 3.7** environment (and optionally upgrade to python 3.9), run: -1. Configure the Jupyter service with the env variable`JUPYTER_PREFER_ENV_PATH=false`. -2. Within the Jupyter service, open a terminal and update conda and pip to have an up to date pip resolver. +1. Configure the Jupyter service with the env variable `JUPYTER_PREFER_ENV_PATH=false`. +2. Within the Jupyter service, open a terminal and update conda and pip to have an up-to-date pip resolver. -```$CONDA_HOME/bin/conda install -y conda=23.1.0 - $CONDA_HOME/bin/conda install -y pip ``` -3. If you are going to work with python 3.9, create a new conda env and activate it: +$CONDA_HOME/bin/conda install -y conda=23.1.0 +$CONDA_HOME/bin/conda install -y 'pip>=22.0' +``` +3. If you wish to upgrade to python 3.9, create a new conda env and activate it: ``` - conda create -n python39 python=3.9 ipykernel -y - conda activate python39 +conda create -n python39 python=3.9 ipykernel -y +conda activate python39 ``` 4. Install mlrun: ``` @@ -67,7 +106,7 @@ To install on a **Python 3.7** client, run: #### Logging data | ID | Description | | --- | ----------------------------------------------------------------- | -| ML-2845 | Logging data using `hints`. You can now pass data into MLRun and log it using log hints, instead of the decorator. This is the initial change in MLRun to simplify wrapping usable code into MLRun without having to modify it. Future releases will continue this paradigm shift. See [more details](../cheat-sheet.html#track-returning-values-using-returns-new-in-v1-3-0). | +| ML-2845 | Logging data using `hints`. You can now pass data into MLRun and log it using log hints, instead of the decorator. This is the initial change in MLRun to simplify wrapping usable code into MLRun without having to modify it. Future releases will continue this paradigm shift. 
See [more details](../cheat-sheet.html#track-returning-values-using-hints-and-returns). | #### Projects @@ -201,6 +240,23 @@ The `--ensure-project` flag of the `mlrun project` CLI command is deprecated and | ML-3446 | Fix: Failed MLRun Nuclio deploy needs better error messages. [View in Git](https://github.com/mlrun/mlrun/pull/3241). | | ML-3482 | Fixed model-monitoring incompatibility issue with mlrun client running v1.1.x and a server running v1.2.x. [View in Git](https://github.com/mlrun/mlrun/pull/3180). | +## v1.2.3 + +### Closed issues + +| ID | Description | +| --- | ----------------------------------------------------------------- | +| ML-3287 | UI now resets the cache upon MLRun upgrades, and the Projects page displays correctly. [View in Git](https://github.com/mlrun/ui/pull/1612). | +| ML-3801 | Optimized `/projects` endpoint to work faster [View in Git](https://github.com/mlrun/ui/pull/1715). | +| ML-3819 | Reduce overly-verbose logs on the backend side. [View in Git](https://github.com/mlrun/mlrun/pull/3531). | + +## v1.2.2 + +### Closed issues + +| ID | Description | +| --- | ----------------------------------------------------------------- | +| ML-3797, ML-3798 | Fixed presenting and serving large-sized projects. [View in Git](https://github.com/mlrun/mlrun/pull/3477). | ## v1.2.1 @@ -527,6 +583,7 @@ with a drill-down to view the steps and their details. [Tech Preview] | ML-2014 | Model deployment returns ResourceNotFoundException (Nuclio error that Service is invalid.) | Verify that all `metadata.labels` values are 63 characters or less. See the [Kubernetes limitation](https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set). | v1.0.0 | | ML-3315 | The feature store does not support an aggregation of aggregations | NA | v1.2.1 | | ML-3381 | Private repo is not supported as a marketplace hub | NA | v1.2.1 | +| ML-3824 | MLRun supports TensorFlow up to 2.11. 
| NA | v1.3.1 | ## Deprecations diff --git a/docs/cheat-sheet.md b/docs/cheat-sheet.md index dffea56e5e38..af9987d34cdb 100644 --- a/docs/cheat-sheet.md +++ b/docs/cheat-sheet.md @@ -268,7 +268,7 @@ fn.with_limits(mem="2G", cpu=2, gpus=1) # Nuclio/serving scaling fn.spec.replicas = 2 fn.spec.min_replicas = 1 -fn.spec.min_replicas = 4 +fn.spec.max_replicas = 4 ``` #### Mount persistent storage @@ -445,12 +445,17 @@ run_id = project.run( Docs: [MLRun execution context](./concepts/mlrun-execution-context.html) ```python -context.logger.debug(message="Debugging info") -context.logger.info(message="Something happened") -context.logger.warning(message="Something might go wrong") -context.logger.error(message="Something went wrong") +context.logger.debug(message="Debugging info") # logging all (debug, info, warning, error) +context.logger.info(message="Something happened") # logging info, warning and error +context.logger.warning(message="Something might go wrong") # logging warning and error +context.logger.error(message="Something went wrong") # logging only error ``` +```{admonition} Note +The real-time (nuclio) function uses default logger level `debug` (logging all) +``` + + ## Experiment tracking Docs: [MLRun execution context](./concepts/mlrun-execution-context.html), [Automated experiment tracking](./concepts/auto-logging-mlops.html), [Decorators and auto-logging](./concepts/decorators-and-auto-logging.html) @@ -654,7 +659,7 @@ redis_target = RedisNoSqlTarget(name="write", path="redis://1.2.3.4:6379") redis_target.write_dataframe(df=redis_df) # Kafka (see docs for writing online features) -kafka_target = KafkaSource( +kafka_target = KafkaTarget( name="write", bootstrap_servers='localhost:9092', topic='topic', @@ -784,7 +789,7 @@ fstore.ingest( #### Aggregations -Docs: [add_aggregation()](./api/mlrun.feature_store.html#mlrun.feature_store.FeatureSet.add_aggregation) +Docs: 
[add_aggregation()](./api/mlrun.feature_store.html#mlrun.feature_store.FeatureSet.add_aggregation), [Aggregations](./feature-store/transformations.html#aggregations) ```python quotes_set = fstore.FeatureSet("stock-quotes", entities=[fstore.Entity("ticker")]) @@ -1084,7 +1089,6 @@ dask_cluster.apply(mlrun.mount_v3io()) # add volume mounts dask_cluster.spec.service_type = "NodePort" # open interface to the dask UI dashboard dask_cluster.spec.replicas = 2 # define two containers uri = dask_cluster.save() -uri # Run parallel hyperparameter trials hp_tuning_run_dask = project.run_function( diff --git a/docs/concepts/functions-overview.md b/docs/concepts/functions-overview.md index 638a7b2b484c..d9c5aa5b8249 100644 --- a/docs/concepts/functions-overview.md +++ b/docs/concepts/functions-overview.md @@ -8,12 +8,12 @@ MLRun supports real-time and batch runtimes. Real-time runtimes: * **{ref}`nuclio `** - real-time serverless functions over Nuclio -* **{ref}`serving `** - higher level real-time Graph (DAG) over one or more Nuclio functions +* **{ref}`serving `** - deploy models and higher-level real-time Graph (DAG) over one or more Nuclio functions Batch runtimes: * **handler** - execute python handler (used automatically in notebooks or for debug) * **local** - execute a Python or shell program -* **job** - run the code in a Kubernetes Pod +* **{ref}`job `** - run the code in a Kubernetes Pod * **{ref}`dask `** - run the code as a Dask Distributed job (over Kubernetes) * **{ref}`mpijob `** - run distributed jobs and Horovod over the MPI job operator, used mainly for deep learning jobs * **{ref}`spark `** - run the job as a Spark job (using Spark Kubernetes Operator) @@ -52,6 +52,8 @@ The limits methods are different for Spark and Dask: ```{toctree} :maxdepth: 1 +../runtimes/job-function +../runtimes/serving-function ../runtimes/dask-overview ../runtimes/horovod ../runtimes/spark-operator diff --git a/docs/concepts/notifications.md b/docs/concepts/notifications.md new file 
mode 100644 index 000000000000..b26820f6fa04 --- /dev/null +++ b/docs/concepts/notifications.md @@ -0,0 +1,131 @@ +(notifications)= + +# Notifications + +MLRun supports configuring notifications on jobs and scheduled jobs. This section describes the SDK for notifications. + +- [The Notification Object](#the-notification-object) +- [Local vs Remote](#local-vs-remote) +- [Notification Params and Secrets](#notification-params-and-secrets) +- [Notification Kinds](#notification-kinds) +- [Configuring Notifications For Runs](#configuring-notifications-for-runs) +- [Configuring Notifications For Pipelines](#configuring-notifications-for-pipelines) +- [Setting Notifications on Live Runs](#setting-notifications-on-live-runs) +- [Setting Notifications on Scheduled Runs](#setting-notifications-on-scheduled-runs) +- [Notification Conditions](#notification-conditions) + + +## The Notification Object +The notification object's schema is: +- `kind`: str - notification kind (slack, git, etc...) +- `when`: list[str] - run states on which to send the notification (completed, error, aborted) +- `name`: str - notification name +- `message`: str - notification message +- `severity`: str - notification severity (info, warning, error, debug) +- `params`: dict - notification parameters (See definitions in [Notification Kinds](#notification-kinds)) +- `condition`: str - jinja template for a condition that determines whether the notification is sent or not (See [Notification Conditions](#notification-conditions)) + + +## Local vs Remote +Notifications can be sent either locally from the SDK, or remotely from the MLRun API. +Usually, a local run sends locally, and a remote run sends remotely. +However, there are several special cases where the notification is sent locally either way. +These cases are: +- Pipelines: To conserve backwards compatibility, the SDK sends the notifications as it did before adding the run + notifications mechanism. 
This means you need to watch the pipeline in order for its notifications to be sent. +- Dask: Dask runs are always local (against a remote dask cluster), so the notifications are sent locally as well. + +## Notification Params and Secrets +The notification parameters might contain sensitive information (slack webhook, git token, etc.). For this reason, +when a notification is created its params are masked in a kubernetes secret. The secret is named +`-` (or `-`) and is created in the namespace where mlrun is +installed. In the notification params the secret reference is stored under the `secret` key once masked. + +## Notification Kinds + +Currently, the supported notification kinds and their params are as follows: + +- `slack`: + - `webhook`: The slack webhook to which to send the notification. +- `git`: + - `token`: The git token to use for the git notification. + - `repo`: The git repo to which to send the notification. + - `issue`: The git issue to which to send the notification. + - `merge_request`: In gitlab (as opposed to github), merge requests and issues are separate entities. + If using merge request, the issue will be ignored, and vice versa. + - `server`: The git server to which to send the notification. + - `gitlab`: (bool) Whether the git server is gitlab or not. +- `console` (no params, local only) +- `ipython` (no params, local only) + +## Configuring Notifications For Runs + +In any `run` method you can configure the notifications via their model. For example: + +```python +notification = mlrun.model.Notification( + kind="slack", + when=["completed","error"], + name="notification-1", + message="completed", + severity="info", + params={"webhook": ""} +) +function.run(handler=handler, notifications=[notification]) +``` + +## Configuring Notifications For Pipelines +For pipelines, you configure the notifications on the project notifiers. 
For example: + +```python +project.notifiers.add_notification(notification_type="slack",params={"webhook":""}) +``` +Instead of passing the webhook in the notification params, it is also possible in a Jupyter notebook to use the ` %env` +magic command: +``` +%env SLACK_WEBHOOK= +``` + +## Setting Notifications on Live Runs +You can set notifications on live runs via the `set_run_notifications` method. For example: + +```python +import mlrun + +mlrun.get_run_db().set_run_notifications("", "", [notification1, notification2]) +``` + +Using the `set_run_notifications` method overrides any existing notifications on the run. To delete all notifications, pass an empty list. + +## Setting Notifications on Scheduled Runs +You can set notifications on scheduled runs via the `set_schedule_notifications` method. For example: + +```python +import mlrun + +mlrun.get_run_db().set_schedule_notifications("", "", [notification1, notification2]) +``` + +Using the `set_schedule_notifications` method overrides any existing notifications on the schedule. To delete all notifications, pass an empty list. + +## Notification Conditions +You can configure the notification to be sent only if the run meets certain conditions. This is done using the `condition` +parameter in the notification object. The condition is a string that is evaluated using a jinja templator with the run +object in its context. The jinja template should return a boolean value that determines whether the notification is sent or not. +If any other value is returned or if the template is malformed, the condition is ignored and the notification is sent +as normal. + +Take the case of a run that calculates and outputs model drift. 
This example code sets a notification to fire only +if the drift is above a certain threshold: + +```python +notification = mlrun.model.Notification( + kind="slack", + when=["completed","error"], + name="notification-1", + message="completed", + severity="info", + params={"webhook": ""}, + condition='{{ run["status"]["results"]["drift"] > 0.1 }}' +) +``` diff --git a/docs/concepts/nuclio-real-time-functions.ipynb b/docs/concepts/nuclio-real-time-functions.ipynb new file mode 100644 index 000000000000..de9c63e5e2fb --- /dev/null +++ b/docs/concepts/nuclio-real-time-functions.ipynb @@ -0,0 +1,118 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "887ae6fb", + "metadata": {}, + "source": [ + "(nuclio-real-time-functions)=\n", + "# Nuclio real-time functions\n", + "\n", + "Nuclio is a high-performance \"serverless\" framework focused on data, I/O, and compute intensive workloads. It is well integrated with popular \n", + "data science tools, such as Jupyter and Kubeflow; supports a variety of data and streaming sources; and supports execution over CPUs and GPUs. \n", + "\n", + "You can use Nuclio through a fully managed application service (in the cloud or on-prem) in the Iguazio MLOps Platform. MLRun serving \n", + "utilizes serverless Nuclio functions to create multi-stage real-time pipelines. \n", + "\n", + "The underlying Nuclio serverless engine uses a high-performance parallel processing engine that maximizes the utilization of CPUs and GPUs, \n", + "supports 13 protocols and invocation methods (for example, HTTP, Cron, Kafka, Kinesis), and includes dynamic auto-scaling for HTTP and \n", + "streaming. 
Nuclio and MLRun support the full life cycle, including auto-generation of micro-services, APIs, load-balancing, logging, \n", + "monitoring, and configuration management—such that developers can focus on code, and deploy to production faster with minimal work.\n", + "\n", + "Nuclio is extremely fast: a single function instance can process hundreds of thousands of HTTP requests or data records per second. To learn \n", + "more about how Nuclio works, see the Nuclio architecture [documentation](https://nuclio.io/docs/latest/concepts/architecture/). \n", + "\n", + "Nuclio is secure: Nuclio is integrated with Kaniko to allow a secure and production-ready way of building Docker images at run time.\n", + "\n", + "Read more in the [Nuclio documentation](https://nuclio.io/docs/latest/) and the open-source [MLRun library](https://github.com/mlrun/mlrun).\n", + "\n", + "## Example of Nuclio function\n", + "\n", + "You can create your own Nuclio function, for example a data processing function. The following code illustrates an example of an MLRun function, of kind 'nuclio', that can be deployed to the cluster." 
+ ] + }, + { + "cell_type": "markdown", + "id": "3c9b59b3", + "metadata": {}, + "source": [ + "Create a file `func.py` with the code of the function: \n", + "```\n", + "def handler(context, event):\n", + " return \"Hello\"\n", + "``` " + ] + }, + { + "cell_type": "markdown", + "id": "b2dcd26e", + "metadata": {}, + "source": [ + "Create the project and the Nuclio function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "105fb38e", + "metadata": {}, + "outputs": [], + "source": [ + "import mlrun" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc620518", + "metadata": {}, + "outputs": [], + "source": [ + "# Create the project\n", + "project = mlrun.get_or_create_project(\"nuclio-project\", \"./\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5dda40ef", + "metadata": {}, + "outputs": [], + "source": [ + "# Create a Nuclio function\n", + "project.set_function(\n", + " func=\"func.py\",\n", + " image=\"mlrun/mlrun\",\n", + " kind=\"nuclio\",\n", + " name=\"nuclio-func\",\n", + " handler=\"handler\",\n", + ")\n", + "# Save the function within the project\n", + "project.save()\n", + "# Deploy the function in the cluster\n", + "project.deploy_function(\"nuclio-func\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/concepts/nuclio-real-time-functions.md b/docs/concepts/nuclio-real-time-functions.md deleted file mode 100644 index b6d0acf227bf..000000000000 --- a/docs/concepts/nuclio-real-time-functions.md +++ /dev/null @@ -1,32 +0,0 @@ -(nuclio-real-time-functions)= -# Nuclio real-time functions 
- -Nuclio is a high-performance "serverless" framework focused on data, I/O, and compute intensive workloads. It is well integrated with popular -data science tools, such as Jupyter and Kubeflow; supports a variety of data and streaming sources; and supports execution over CPUs and GPUs. - -You can use Nuclio through a fully managed application service (in the cloud or on-prem) in the Iguazio MLOps Platform. MLRun serving -utilizes serverless Nuclio functions to create multi-stage real-time pipelines. - -The underlying Nuclio serverless engine uses a high-performance parallel processing engine that maximizes the utilization of CPUs and GPUs, -supports 13 protocols and invocation methods (for example, HTTP, Cron, Kafka, Kinesis), and includes dynamic auto-scaling for HTTP and -streaming. Nuclio and MLRun support the full life cycle, including auto-generation of micro-services, APIs, load-balancing, logging, -monitoring, and configuration management—such that developers can focus on code, and deploy to production faster with minimal work. - -Nuclio is extremely fast: a single function instance can process hundreds of thousands of HTTP requests or data records per second. To learn -more about how Nuclio works, see the Nuclio architecture [documentation](https://nuclio.io/docs/latest/concepts/architecture/). - -Nuclio is secure: Nuclio is integrated with Kaniko to allow a secure and production-ready way of building Docker images at run time. - -Read more in the [Nuclio documentation](https://nuclio.io/docs/latest/) and the open-source [MLRun library](https://github.com/mlrun/mlrun). - -## Why another "serverless" project? 
-None of the existing cloud and open-source serverless solutions addressed all the desired capabilities of a serverless framework: - -- Real-time processing with minimal CPU/GPU and I/O overhead and maximum parallelism -- Native integration with a large variety of data sources, triggers, processing models, and ML frameworks -- Stateful functions with data-path acceleration -- Simple debugging, regression testing, and multi-versioned CI/CD pipelines -- Portability across low-power devices, laptops, edge and on-prem clusters, and public clouds -- Open-source but designed for the enterprise (including logging, monitoring, security, and usability) - -Nuclio was created to fulfill these requirements. It was intentionally designed as an extendable open-source framework, using a modular and layered approach that supports constant addition of triggers and data sources, with the hope that many will join the effort of developing new modules, developer tools, and platforms for Nuclio. \ No newline at end of file diff --git a/docs/concepts/workflow-overview.md b/docs/concepts/workflow-overview.md index accb7b7e75c0..18bd41a06375 100644 --- a/docs/concepts/workflow-overview.md +++ b/docs/concepts/workflow-overview.md @@ -173,8 +173,7 @@ Instead of waiting for completion, you can set up a notification in Slack with a Use one of: ``` -# If you want to get slack notification after the run with the results summary, use -# project.notifiers.slack(webhook="https://") +project.notifiers.add_notification(notification_type="slack",params={"webhook":""}) ``` or in a Jupyter notebook with the` %env` magic command: ``` diff --git a/docs/conf.py b/docs/conf.py index 8809c559195b..272cd41eb630 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -63,9 +63,6 @@ def current_version(): "sphinx_reredirects", ] -# redirect paths due to filename changes -redirects = {"runtimes/load-from-marketplace": "load-from-hub.html"} - # Add any paths that contain templates here, relative to this directory. 
templates_path = [ "_templates", @@ -141,7 +138,12 @@ def current_version(): myst_url_schemes = ("http", "https", "mailto") myst_heading_anchors = 2 myst_all_links_external = True -myst_substitutions = {"version": version} + +myst_substitutions = { + "version": "version", + "ceversion": "v1.2.1", + "releasedocumentation": "docs.mlrun.org/en/v1.2.1/index.html", +} # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = True diff --git a/docs/contents.rst b/docs/contents.rst index 6b0e9edfb24d..eff70d5ec6e2 100644 --- a/docs/contents.rst +++ b/docs/contents.rst @@ -21,6 +21,7 @@ Table of Contents concepts/runs-workflows serving/serving-graph concepts/monitoring + concepts/notifications .. toctree:: :maxdepth: 1 diff --git a/docs/data-prep/ingest-data-fs.md b/docs/data-prep/ingest-data-fs.md index d8ef7895ac38..87a04210db21 100644 --- a/docs/data-prep/ingest-data-fs.md +++ b/docs/data-prep/ingest-data-fs.md @@ -16,6 +16,8 @@ When targets are not specified, data is stored in the configured default targets ```{admonition} Limitations - Do not name columns starting with either `_` or `aggr_`. They are reserved for internal use. See also general limitations in [Attribute name restrictions](https://www.iguazio.com/docs/latest-release/data-layer/objects/attributes/#attribute-names). +- Do not name columns to match the regex pattern `.*_[a-z]+_[0-9]+[smhd]$`, where [a-z]+ is an aggregation name, +one of: count, sum, sqr, max, min, first, last, avg, stdvar, stddev. E.g. x_count_1h. - When using the pandas engine, do not use spaces (` `) or periods (`.`) in the column names. These cause errors in the ingestion. 
``` @@ -174,7 +176,7 @@ either, pass the `db_uri` or overwrite the `MLRUN_SQL__URL` env var, in this for `mysql+pymysql://:@:/`, for example: ``` -source = SqlDBSource(table_name='my_table', +source = SQLSource(table_name='my_table', db_path="mysql+pymysql://abc:abc@localhost:3306/my_db", key_field='key', time_fields=['timestamp'], ) diff --git a/docs/deployment/batch_inference.ipynb b/docs/deployment/batch_inference.ipynb index 9db49a64d85f..fcf6d4702f62 100644 --- a/docs/deployment/batch_inference.ipynb +++ b/docs/deployment/batch_inference.ipynb @@ -90,7 +90,10 @@ "outputs": [], "source": [ "import mlrun\n", - "project = mlrun.get_or_create_project('batch-inference', context=\"./\", user_project=True)\n", + "\n", + "project = mlrun.get_or_create_project(\n", + " \"batch-inference\", context=\"./\", user_project=True\n", + ")\n", "batch_inference = mlrun.import_function(\"hub://batch_inference\")" ] }, @@ -109,12 +112,10 @@ "metadata": {}, "outputs": [], "source": [ - "model_path = mlrun.get_sample_path('models/batch-predict/model.pkl')\n", + "model_path = mlrun.get_sample_path(\"models/batch-predict/model.pkl\")\n", "\n", "model_artifact = project.log_model(\n", - " key=\"model\",\n", - " model_file=model_path,\n", - " framework=\"sklearn\"\n", + " key=\"model\", model_file=model_path, framework=\"sklearn\"\n", ")" ] }, @@ -133,7 +134,7 @@ "metadata": {}, "outputs": [], "source": [ - "prediction_set_path = mlrun.get_sample_path('data/batch-predict/prediction_set.parquet')" + "prediction_set_path = mlrun.get_sample_path(\"data/batch-predict/prediction_set.parquet\")" ] }, { @@ -415,7 +416,7 @@ " batch_inference,\n", " inputs={\"dataset\": prediction_set_path},\n", " params={\"model\": model_artifact.uri},\n", - " schedule='*/30 * * * *'\n", + " schedule=\"*/30 * * * *\",\n", ")" ] }, @@ -442,18 +443,17 @@ "metadata": {}, "outputs": [], "source": [ - "training_set_path = mlrun.get_sample_path('data/batch-predict/training_set.parquet')\n", + "training_set_path = 
mlrun.get_sample_path(\"data/batch-predict/training_set.parquet\")\n", "\n", "batch_run = project.run_function(\n", " batch_inference,\n", - " inputs={\n", - " \"dataset\": prediction_set_path,\n", - " \"sample_set\": training_set_path\n", + " inputs={\"dataset\": prediction_set_path, \"sample_set\": training_set_path},\n", + " params={\n", + " \"model\": model_artifact.uri,\n", + " \"label_columns\": \"label\",\n", + " \"perform_drift_analysis\": True,\n", " },\n", - " params={\"model\": model_artifact.uri,\n", - " \"label_columns\": \"label\",\n", - " \"perform_drift_analysis\" : True}\n", - ")\n" + ")" ] }, { diff --git a/docs/feature-store/basic-demo.ipynb b/docs/feature-store/basic-demo.ipynb index b4767501b4d8..a7af9a1d79a5 100644 --- a/docs/feature-store/basic-demo.ipynb +++ b/docs/feature-store/basic-demo.ipynb @@ -58,6 +58,7 @@ ], "source": [ "import mlrun\n", + "\n", "mlrun.get_or_create_project(\"stocks\", \"./\")" ] }, @@ -79,6 +80,7 @@ "outputs": [], "source": [ "import pandas as pd\n", + "\n", "quotes = pd.DataFrame(\n", " {\n", " \"time\": [\n", @@ -89,55 +91,49 @@ " pd.Timestamp(\"2016-05-25 13:30:00.048\"),\n", " pd.Timestamp(\"2016-05-25 13:30:00.049\"),\n", " pd.Timestamp(\"2016-05-25 13:30:00.072\"),\n", - " pd.Timestamp(\"2016-05-25 13:30:00.075\")\n", + " pd.Timestamp(\"2016-05-25 13:30:00.075\"),\n", " ],\n", - " \"ticker\": [\n", - " \"GOOG\",\n", - " \"MSFT\",\n", - " \"MSFT\",\n", - " \"MSFT\",\n", - " \"GOOG\",\n", - " \"AAPL\",\n", - " \"GOOG\",\n", - " \"MSFT\"\n", - " ],\n", - " \"bid\": [720.50, 51.95, 51.97, 51.99, 720.50, 97.99, 720.50, 52.01],\n", - " \"ask\": [720.93, 51.96, 51.98, 52.00, 720.93, 98.01, 720.88, 52.03]\n", + " \"ticker\": [\"GOOG\", \"MSFT\", \"MSFT\", \"MSFT\", \"GOOG\", \"AAPL\", \"GOOG\", \"MSFT\"],\n", + " \"bid\": [720.50, 51.95, 51.97, 51.99, 720.50, 97.99, 720.50, 52.01],\n", + " \"ask\": [720.93, 51.96, 51.98, 52.00, 720.93, 98.01, 720.88, 52.03],\n", " }\n", ")\n", "\n", "trades = pd.DataFrame(\n", - " 
{\n", - " \"time\": [\n", - " pd.Timestamp(\"2016-05-25 13:30:00.023\"),\n", - " pd.Timestamp(\"2016-05-25 13:30:00.038\"),\n", - " pd.Timestamp(\"2016-05-25 13:30:00.048\"),\n", - " pd.Timestamp(\"2016-05-25 13:30:00.048\"),\n", - " pd.Timestamp(\"2016-05-25 13:30:00.048\")\n", - " ],\n", - " \"ticker\": [\"MSFT\", \"MSFT\", \"GOOG\", \"GOOG\", \"AAPL\"],\n", - " \"price\": [51.95, 51.95, 720.77, 720.92, 98.0],\n", - " \"quantity\": [75, 155, 100, 100, 100]\n", - " }\n", + " {\n", + " \"time\": [\n", + " pd.Timestamp(\"2016-05-25 13:30:00.023\"),\n", + " pd.Timestamp(\"2016-05-25 13:30:00.038\"),\n", + " pd.Timestamp(\"2016-05-25 13:30:00.048\"),\n", + " pd.Timestamp(\"2016-05-25 13:30:00.048\"),\n", + " pd.Timestamp(\"2016-05-25 13:30:00.048\"),\n", + " ],\n", + " \"ticker\": [\"MSFT\", \"MSFT\", \"GOOG\", \"GOOG\", \"AAPL\"],\n", + " \"price\": [51.95, 51.95, 720.77, 720.92, 98.0],\n", + " \"quantity\": [75, 155, 100, 100, 100],\n", + " }\n", ")\n", "\n", "\n", "stocks = pd.DataFrame(\n", - " {\n", - " \"ticker\": [\"MSFT\", \"GOOG\", \"AAPL\"],\n", - " \"name\": [\"Microsoft Corporation\", \"Alphabet Inc\", \"Apple Inc\"],\n", - " \"exchange\": [\"NASDAQ\", \"NASDAQ\", \"NASDAQ\"]\n", - " }\n", + " {\n", + " \"ticker\": [\"MSFT\", \"GOOG\", \"AAPL\"],\n", + " \"name\": [\"Microsoft Corporation\", \"Alphabet Inc\", \"Apple Inc\"],\n", + " \"exchange\": [\"NASDAQ\", \"NASDAQ\", \"NASDAQ\"],\n", + " }\n", ")\n", "\n", "import datetime\n", + "\n", + "\n", "def move_date(df, col):\n", " max_date = df[col].max()\n", " now_date = datetime.datetime.now()\n", - " delta = now_date - max_date \n", - " df[col] = df[col] + delta \n", + " delta = now_date - max_date\n", + " df[col] = df[col] + delta\n", " return df\n", "\n", + "\n", "quotes = move_date(quotes, \"time\")\n", "trades = move_date(trades, \"time\")" ] @@ -529,7 +525,7 @@ } ], "source": [ - "# add feature set without time column (stock ticker metadata) \n", + "# add feature set without time column (stock ticker 
metadata)\n", "stocks_set = fstore.FeatureSet(\"stocks\", entities=[fstore.Entity(\"ticker\")])\n", "fstore.ingest(stocks_set, stocks, infer_options=fstore.InferOptions.default())" ] @@ -708,10 +704,9 @@ } ], "source": [ - "quotes_set.graph.to(\"MyMap\", multiplier=3)\\\n", - " .to(\"storey.Extend\", _fn=\"({'extra': event['bid'] * 77})\")\\\n", - " .to(\"storey.Filter\", \"filter\", _fn=\"(event['bid'] > 51.92)\")\\\n", - " .to(FeaturesetValidator())\n", + "quotes_set.graph.to(\"MyMap\", multiplier=3).to(\n", + " \"storey.Extend\", _fn=\"({'extra': event['bid'] * 77})\"\n", + ").to(\"storey.Filter\", \"filter\", _fn=\"(event['bid'] > 51.92)\").to(FeaturesetValidator())\n", "\n", "quotes_set.add_aggregation(\"ask\", [\"sum\", \"max\"], \"1h\", \"10m\", name=\"asks1\")\n", "quotes_set.add_aggregation(\"ask\", [\"sum\", \"max\"], \"5h\", \"10m\", name=\"asks5\")\n", @@ -1740,7 +1735,9 @@ " \"stocks.*\",\n", "]\n", "\n", - "vector = fstore.FeatureVector(\"stocks-vec\", features, description=\"stocks demo feature vector\")\n", + "vector = fstore.FeatureVector(\n", + " \"stocks-vec\", features, description=\"stocks demo feature vector\"\n", + ")\n", "vector.save()" ] }, @@ -1862,7 +1859,9 @@ } ], "source": [ - "resp = fstore.get_offline_features(vector, entity_rows=trades, entity_timestamp_column=\"time\")\n", + "resp = fstore.get_offline_features(\n", + " vector, entity_rows=trades, entity_timestamp_column=\"time\"\n", + ")\n", "resp.to_dataframe()" ] }, diff --git a/docs/feature-store/end-to-end-demo/01-ingest-datasources.ipynb b/docs/feature-store/end-to-end-demo/01-ingest-datasources.ipynb index cf8cbc8c40aa..0143bb803b30 100644 --- a/docs/feature-store/end-to-end-demo/01-ingest-datasources.ipynb +++ b/docs/feature-store/end-to-end-demo/01-ingest-datasources.ipynb @@ -89,7 +89,7 @@ "metadata": {}, "outputs": [], "source": [ - "project_name = 'fraud-demo'" + "project_name = \"fraud-demo\"" ] }, { @@ -109,7 +109,7 @@ "import mlrun\n", "\n", "# Initialize the MLRun 
project object\n", - "project = mlrun.get_or_create_project(project_name, context=\"./\", user_project=True) " + "project = mlrun.get_or_create_project(project_name, context=\"./\", user_project=True)" ] }, { @@ -147,33 +147,41 @@ "# while keeping the order of the selected events and\n", "# the relative distance from one event to the other\n", "\n", + "\n", "def date_adjustment(sample, data_max, new_max, old_data_period, new_data_period):\n", - " '''\n", - " Adjust a specific sample's date according to the original and new time periods\n", - " '''\n", - " sample_dates_scale = ((data_max - sample) / old_data_period)\n", + " \"\"\"\n", + " Adjust a specific sample's date according to the original and new time periods\n", + " \"\"\"\n", + " sample_dates_scale = (data_max - sample) / old_data_period\n", " sample_delta = new_data_period * sample_dates_scale\n", " new_sample_ts = new_max - sample_delta\n", " return new_sample_ts\n", "\n", - "def adjust_data_timespan(dataframe, timestamp_col='timestamp', new_period='2d', new_max_date_str='now'):\n", - " '''\n", - " Adjust the dataframe timestamps to the new time period\n", - " '''\n", + "\n", + "def adjust_data_timespan(\n", + " dataframe, timestamp_col=\"timestamp\", new_period=\"2d\", new_max_date_str=\"now\"\n", + "):\n", + " \"\"\"\n", + " Adjust the dataframe timestamps to the new time period\n", + " \"\"\"\n", " # Calculate old time period\n", " data_min = dataframe.timestamp.min()\n", " data_max = dataframe.timestamp.max()\n", - " old_data_period = data_max-data_min\n", - " \n", + " old_data_period = data_max - data_min\n", + "\n", " # Set new time period\n", " new_time_period = pd.Timedelta(new_period)\n", " new_max = pd.Timestamp(new_max_date_str)\n", - " new_min = new_max-new_time_period\n", - " new_data_period = new_max-new_min\n", - " \n", + " new_min = new_max - new_time_period\n", + " new_data_period = new_max - new_min\n", + "\n", " # Apply the timestamp change\n", " df = dataframe.copy()\n", - " 
df[timestamp_col] = df[timestamp_col].apply(lambda x: date_adjustment(x, data_max, new_max, old_data_period, new_data_period))\n", + " df[timestamp_col] = df[timestamp_col].apply(\n", + " lambda x: date_adjustment(\n", + " x, data_max, new_max, old_data_period, new_data_period\n", + " )\n", + " )\n", " return df" ] }, @@ -293,16 +301,19 @@ "import pandas as pd\n", "\n", "# Fetch the transactions dataset from the server\n", - "transactions_data = pd.read_csv('https://s3.wasabisys.com/iguazio/data/fraud-demo-mlrun-fs-docs/data.csv', parse_dates=['timestamp'])\n", + "transactions_data = pd.read_csv(\n", + " \"https://s3.wasabisys.com/iguazio/data/fraud-demo-mlrun-fs-docs/data.csv\",\n", + " parse_dates=[\"timestamp\"],\n", + ")\n", "\n", "# use only first 50k\n", - "transactions_data = transactions_data.sort_values(by='source', axis=0)[:10000]\n", + "transactions_data = transactions_data.sort_values(by=\"source\", axis=0)[:10000]\n", "\n", "# Adjust the samples timestamp for the past 2 days\n", - "transactions_data = adjust_data_timespan(transactions_data, new_period='2d')\n", + "transactions_data = adjust_data_timespan(transactions_data, new_period=\"2d\")\n", "\n", "# Sorting after adjusting timestamps\n", - "transactions_data = transactions_data.sort_values(by='timestamp', axis=0)\n", + "transactions_data = transactions_data.sort_values(by=\"timestamp\", axis=0)\n", "\n", "# Preview\n", "transactions_data.head(3)" @@ -345,10 +356,12 @@ "outputs": [], "source": [ "# Define the transactions FeatureSet\n", - "transaction_set = fstore.FeatureSet(\"transactions\", \n", - " entities=[fstore.Entity(\"source\")], \n", - " timestamp_key='timestamp', \n", - " description=\"transactions feature set\")" + "transaction_set = fstore.FeatureSet(\n", + " \"transactions\",\n", + " entities=[fstore.Entity(\"source\")],\n", + " timestamp_key=\"timestamp\",\n", + " description=\"transactions feature set\",\n", + ")" ] }, { @@ -464,35 +477,57 @@ ], "source": [ "# Define and add value 
mapping\n", - "main_categories = [\"es_transportation\", \"es_health\", \"es_otherservices\",\n", - " \"es_food\", \"es_hotelservices\", \"es_barsandrestaurants\",\n", - " \"es_tech\", \"es_sportsandtoys\", \"es_wellnessandbeauty\",\n", - " \"es_hyper\", \"es_fashion\", \"es_home\", \"es_contents\",\n", - " \"es_travel\", \"es_leisure\"]\n", + "main_categories = [\n", + " \"es_transportation\",\n", + " \"es_health\",\n", + " \"es_otherservices\",\n", + " \"es_food\",\n", + " \"es_hotelservices\",\n", + " \"es_barsandrestaurants\",\n", + " \"es_tech\",\n", + " \"es_sportsandtoys\",\n", + " \"es_wellnessandbeauty\",\n", + " \"es_hyper\",\n", + " \"es_fashion\",\n", + " \"es_home\",\n", + " \"es_contents\",\n", + " \"es_travel\",\n", + " \"es_leisure\",\n", + "]\n", "\n", "# One Hot Encode the newly defined mappings\n", - "one_hot_encoder_mapping = {'category': main_categories,\n", - " 'gender': list(transactions_data.gender.unique())}\n", + "one_hot_encoder_mapping = {\n", + " \"category\": main_categories,\n", + " \"gender\": list(transactions_data.gender.unique()),\n", + "}\n", "\n", "# Define the graph steps\n", - "transaction_set.graph\\\n", - " .to(DateExtractor(parts = ['hour', 'day_of_week'], timestamp_col = 'timestamp'))\\\n", - " .to(MapValues(mapping={'age': {'U': '0'}}, with_original_features=True))\\\n", - " .to(OneHotEncoder(mapping=one_hot_encoder_mapping))\n", + "transaction_set.graph.to(\n", + " DateExtractor(parts=[\"hour\", \"day_of_week\"], timestamp_col=\"timestamp\")\n", + ").to(MapValues(mapping={\"age\": {\"U\": \"0\"}}, with_original_features=True)).to(\n", + " OneHotEncoder(mapping=one_hot_encoder_mapping)\n", + ")\n", "\n", "\n", "# Add aggregations for 2, 12, and 24 hour time windows\n", - "transaction_set.add_aggregation(name='amount',\n", - " column='amount',\n", - " operations=['avg','sum', 'count','max'],\n", - " windows=['2h', '12h', '24h'],\n", - " period='1h')\n", + "transaction_set.add_aggregation(\n", + " name=\"amount\",\n", + " 
column=\"amount\",\n", + " operations=[\"avg\", \"sum\", \"count\", \"max\"],\n", + " windows=[\"2h\", \"12h\", \"24h\"],\n", + " period=\"1h\",\n", + ")\n", "\n", "\n", "# Add the category aggregations over a 14 day window\n", "for category in main_categories:\n", - " transaction_set.add_aggregation(name=category,column=f'category_{category}',\n", - " operations=['sum'], windows=['14d'], period='1d')\n", + " transaction_set.add_aggregation(\n", + " name=category,\n", + " column=f\"category_{category}\",\n", + " operations=[\"sum\"],\n", + " windows=[\"14d\"],\n", + " period=\"1d\",\n", + " )\n", "\n", "# Add default (offline-parquet & online-nosql) targets\n", "transaction_set.set_targets()\n", @@ -712,8 +747,9 @@ ], "source": [ "# Ingest your transactions dataset through your defined pipeline\n", - "transactions_df = fstore.ingest(transaction_set, transactions_data, \n", - " infer_options=fstore.InferOptions.default())\n", + "transactions_df = fstore.ingest(\n", + " transaction_set, transactions_data, infer_options=fstore.InferOptions.default()\n", + ")\n", "\n", "transactions_df.head(3)" ] @@ -825,11 +861,15 @@ ], "source": [ "# Fetch the user_events dataset from the server\n", - "user_events_data = pd.read_csv('https://s3.wasabisys.com/iguazio/data/fraud-demo-mlrun-fs-docs/events.csv', \n", - " index_col=0, quotechar=\"\\'\", parse_dates=['timestamp'])\n", + "user_events_data = pd.read_csv(\n", + " \"https://s3.wasabisys.com/iguazio/data/fraud-demo-mlrun-fs-docs/events.csv\",\n", + " index_col=0,\n", + " quotechar=\"'\",\n", + " parse_dates=[\"timestamp\"],\n", + ")\n", "\n", "# Adjust to the last 2 days to see the latest aggregations in the online feature vectors\n", - "user_events_data = adjust_data_timespan(user_events_data, new_period='2d')\n", + "user_events_data = adjust_data_timespan(user_events_data, new_period=\"2d\")\n", "\n", "# Preview\n", "user_events_data.head(3)" @@ -851,10 +891,12 @@ "metadata": {}, "outputs": [], "source": [ - "user_events_set 
= fstore.FeatureSet(\"events\",\n", - " entities=[fstore.Entity(\"source\")],\n", - " timestamp_key='timestamp', \n", - " description=\"user events feature set\")" + "user_events_set = fstore.FeatureSet(\n", + " \"events\",\n", + " entities=[fstore.Entity(\"source\")],\n", + " timestamp_key=\"timestamp\",\n", + " description=\"user events feature set\",\n", + ")" ] }, { @@ -934,7 +976,7 @@ ], "source": [ "# Define and add value mapping\n", - "events_mapping = {'event': list(user_events_data.event.unique())}\n", + "events_mapping = {\"event\": list(user_events_data.event.unique())}\n", "\n", "# One Hot Encode\n", "user_events_set.graph.to(OneHotEncoder(mapping=events_mapping))\n", @@ -1065,10 +1107,10 @@ "outputs": [], "source": [ "def create_labels(df):\n", - " labels = df[['fraud','timestamp']].copy()\n", + " labels = df[[\"fraud\", \"timestamp\"]].copy()\n", " labels = labels.rename(columns={\"fraud\": \"label\"})\n", - " labels['timestamp'] = labels['timestamp'].astype(\"datetime64[ms]\")\n", - " labels['label'] = labels['label'].astype(int)\n", + " labels[\"timestamp\"] = labels[\"timestamp\"].astype(\"datetime64[ms]\")\n", + " labels[\"label\"] = labels[\"label\"].astype(int)\n", " return labels" ] }, @@ -1140,17 +1182,21 @@ "import os\n", "\n", "# Define the \"labels\" feature set\n", - "labels_set = fstore.FeatureSet(\"labels\", \n", - " entities=[fstore.Entity(\"source\")], \n", - " timestamp_key='timestamp',\n", - " description=\"training labels\",\n", - " engine=\"pandas\")\n", + "labels_set = fstore.FeatureSet(\n", + " \"labels\",\n", + " entities=[fstore.Entity(\"source\")],\n", + " timestamp_key=\"timestamp\",\n", + " description=\"training labels\",\n", + " engine=\"pandas\",\n", + ")\n", "\n", "labels_set.graph.to(name=\"create_labels\", handler=create_labels)\n", "\n", "\n", "# specify only Parquet (offline) target since its not used for real-time\n", - "target = ParquetTarget(name='labels',path=f'v3io:///projects/{project.name}/target.parquet')\n", 
+ "target = ParquetTarget(\n", + " name=\"labels\", path=f\"v3io:///projects/{project.name}/target.parquet\"\n", + ")\n", "labels_set.set_targets([target], with_defaults=False)\n", "labels_set.plot(with_targets=True)" ] @@ -1273,7 +1319,7 @@ "outputs": [], "source": [ "# Create iguazio v3io stream and transactions push API endpoint\n", - "transaction_stream = f'v3io:///projects/{project.name}/streams/transaction'\n", + "transaction_stream = f\"v3io:///projects/{project.name}/streams/transaction\"\n", "transaction_pusher = mlrun.datastore.get_stream_pusher(transaction_stream)" ] }, @@ -1299,11 +1345,15 @@ "source": [ "# Define the source stream trigger (use v3io streams)\n", "# define the `key` and `time` fields (extracted from the Json message).\n", - "source = mlrun.datastore.sources.StreamSource(path=transaction_stream , key_field='source', time_field='timestamp')\n", + "source = mlrun.datastore.sources.StreamSource(\n", + " path=transaction_stream, key_field=\"source\", time_field=\"timestamp\"\n", + ")\n", "\n", "# Deploy the transactions feature set's ingestion service over a real-time (Nuclio) serverless function\n", "# you can use the run_config parameter to pass function/service specific configuration\n", - "transaction_set_endpoint = fstore.deploy_ingestion_service(featureset=transaction_set, source=source)" + "transaction_set_endpoint = fstore.deploy_ingestion_service(\n", + " featureset=transaction_set, source=source\n", + ")" ] }, { @@ -1355,8 +1405,10 @@ "import json\n", "\n", "# Select a sample from the dataset and serialize it to JSON\n", - "transaction_sample = json.loads(transactions_data.sample(1).to_json(orient='records'))[0]\n", - "transaction_sample['timestamp'] = str(pd.Timestamp.now())\n", + "transaction_sample = json.loads(transactions_data.sample(1).to_json(orient=\"records\"))[\n", + " 0\n", + "]\n", + "transaction_sample[\"timestamp\"] = str(pd.Timestamp.now())\n", "transaction_sample" ] }, @@ -1403,7 +1455,7 @@ "outputs": [], "source": [ 
"# Create iguazio v3io stream and transactions push API endpoint\n", - "events_stream = f'v3io:///projects/{project.name}/streams/events'\n", + "events_stream = f\"v3io:///projects/{project.name}/streams/events\"\n", "events_pusher = mlrun.datastore.get_stream_pusher(events_stream)" ] }, @@ -1427,11 +1479,15 @@ "source": [ "# Define the source stream trigger (use v3io streams)\n", "# define the `key` and `time` fields (extracted from the Json message).\n", - "source = mlrun.datastore.sources.StreamSource(path=events_stream , key_field='source', time_field='timestamp')\n", + "source = mlrun.datastore.sources.StreamSource(\n", + " path=events_stream, key_field=\"source\", time_field=\"timestamp\"\n", + ")\n", "\n", "# Deploy the transactions feature set's ingestion service over a real-time (Nuclio) serverless function\n", "# you can use the run_config parameter to pass function/service specific configuration\n", - "events_set_endpoint = fstore.deploy_ingestion_service(featureset=user_events_set, source=source)" + "events_set_endpoint = fstore.deploy_ingestion_service(\n", + " featureset=user_events_set, source=source\n", + ")" ] }, { @@ -1448,8 +1504,8 @@ "outputs": [], "source": [ "# Select a sample from the events dataset and serialize it to JSON\n", - "user_events_sample = json.loads(user_events_data.sample(1).to_json(orient='records'))[0]\n", - "user_events_sample['timestamp'] = str(pd.Timestamp.now())\n", + "user_events_sample = json.loads(user_events_data.sample(1).to_json(orient=\"records\"))[0]\n", + "user_events_sample[\"timestamp\"] = str(pd.Timestamp.now())\n", "user_events_sample" ] }, diff --git a/docs/feature-store/end-to-end-demo/02-create-training-model.ipynb b/docs/feature-store/end-to-end-demo/02-create-training-model.ipynb index b371f8f238f8..f5a56e748491 100644 --- a/docs/feature-store/end-to-end-demo/02-create-training-model.ipynb +++ b/docs/feature-store/end-to-end-demo/02-create-training-model.ipynb @@ -19,7 +19,7 @@ "metadata": {}, "outputs": 
[], "source": [ - "project_name = 'fraud-demo'" + "project_name = \"fraud-demo\"" ] }, { @@ -63,39 +63,41 @@ "outputs": [], "source": [ "# Define the list of features to use\n", - "features = ['events.*',\n", - " 'transactions.amount_max_2h', \n", - " 'transactions.amount_sum_2h', \n", - " 'transactions.amount_count_2h',\n", - " 'transactions.amount_avg_2h', \n", - " 'transactions.amount_max_12h', \n", - " 'transactions.amount_sum_12h',\n", - " 'transactions.amount_count_12h', \n", - " 'transactions.amount_avg_12h', \n", - " 'transactions.amount_max_24h',\n", - " 'transactions.amount_sum_24h', \n", - " 'transactions.amount_count_24h', \n", - " 'transactions.amount_avg_24h',\n", - " 'transactions.es_transportation_sum_14d', \n", - " 'transactions.es_health_sum_14d',\n", - " 'transactions.es_otherservices_sum_14d', \n", - " 'transactions.es_food_sum_14d',\n", - " 'transactions.es_hotelservices_sum_14d', \n", - " 'transactions.es_barsandrestaurants_sum_14d',\n", - " 'transactions.es_tech_sum_14d', \n", - " 'transactions.es_sportsandtoys_sum_14d',\n", - " 'transactions.es_wellnessandbeauty_sum_14d', \n", - " 'transactions.es_hyper_sum_14d',\n", - " 'transactions.es_fashion_sum_14d', \n", - " 'transactions.es_home_sum_14d', \n", - " 'transactions.es_travel_sum_14d', \n", - " 'transactions.es_leisure_sum_14d',\n", - " 'transactions.gender_F',\n", - " 'transactions.gender_M',\n", - " 'transactions.step', \n", - " 'transactions.amount', \n", - " 'transactions.timestamp_hour',\n", - " 'transactions.timestamp_day_of_week']" + "features = [\n", + " \"events.*\",\n", + " \"transactions.amount_max_2h\",\n", + " \"transactions.amount_sum_2h\",\n", + " \"transactions.amount_count_2h\",\n", + " \"transactions.amount_avg_2h\",\n", + " \"transactions.amount_max_12h\",\n", + " \"transactions.amount_sum_12h\",\n", + " \"transactions.amount_count_12h\",\n", + " \"transactions.amount_avg_12h\",\n", + " \"transactions.amount_max_24h\",\n", + " \"transactions.amount_sum_24h\",\n", + " 
\"transactions.amount_count_24h\",\n", + " \"transactions.amount_avg_24h\",\n", + " \"transactions.es_transportation_sum_14d\",\n", + " \"transactions.es_health_sum_14d\",\n", + " \"transactions.es_otherservices_sum_14d\",\n", + " \"transactions.es_food_sum_14d\",\n", + " \"transactions.es_hotelservices_sum_14d\",\n", + " \"transactions.es_barsandrestaurants_sum_14d\",\n", + " \"transactions.es_tech_sum_14d\",\n", + " \"transactions.es_sportsandtoys_sum_14d\",\n", + " \"transactions.es_wellnessandbeauty_sum_14d\",\n", + " \"transactions.es_hyper_sum_14d\",\n", + " \"transactions.es_fashion_sum_14d\",\n", + " \"transactions.es_home_sum_14d\",\n", + " \"transactions.es_travel_sum_14d\",\n", + " \"transactions.es_leisure_sum_14d\",\n", + " \"transactions.gender_F\",\n", + " \"transactions.gender_M\",\n", + " \"transactions.step\",\n", + " \"transactions.amount\",\n", + " \"transactions.timestamp_hour\",\n", + " \"transactions.timestamp_day_of_week\",\n", + "]" ] }, { @@ -108,13 +110,15 @@ "import mlrun.feature_store as fstore\n", "\n", "# Define the feature vector name for future reference\n", - "fv_name = 'transactions-fraud'\n", + "fv_name = \"transactions-fraud\"\n", "\n", "# Define the feature vector using the feature store (fstore)\n", - "transactions_fv = fstore.FeatureVector(fv_name, \n", - " features, \n", - " label_feature=\"labels.label\",\n", - " description='Predicting a fraudulent transaction')\n", + "transactions_fv = fstore.FeatureVector(\n", + " fv_name,\n", + " features,\n", + " label_feature=\"labels.label\",\n", + " description=\"Predicting a fraudulent transaction\",\n", + ")\n", "\n", "# Save the feature vector in the feature store\n", "transactions_fv.save()" @@ -391,7 +395,7 @@ "outputs": [], "source": [ "# Import the Sklearn classifier function from the functions hub\n", - "classifier_fn = mlrun.import_function('hub://auto_trainer')" + "classifier_fn = mlrun.import_function(\"hub://auto_trainer\")" ] }, { @@ -677,24 +681,30 @@ "source": [ "# 
Prepare the parameters list for the training function\n", "# you use 3 different models\n", - "training_params = {\"model_name\": ['transaction_fraud_rf', \n", - " 'transaction_fraud_xgboost', \n", - " 'transaction_fraud_adaboost'],\n", - " \n", - " \"model_class\": ['sklearn.ensemble.RandomForestClassifier',\n", - " 'sklearn.ensemble.GradientBoostingClassifier',\n", - " 'sklearn.ensemble.AdaBoostClassifier']}\n", + "training_params = {\n", + " \"model_name\": [\n", + " \"transaction_fraud_rf\",\n", + " \"transaction_fraud_xgboost\",\n", + " \"transaction_fraud_adaboost\",\n", + " ],\n", + " \"model_class\": [\n", + " \"sklearn.ensemble.RandomForestClassifier\",\n", + " \"sklearn.ensemble.GradientBoostingClassifier\",\n", + " \"sklearn.ensemble.AdaBoostClassifier\",\n", + " ],\n", + "}\n", "\n", "# Define the training task, including your feature vector, label and hyperparams definitions\n", - "train_task = mlrun.new_task('training', \n", - " inputs={'dataset': transactions_fv.uri},\n", - " params={'label_columns': 'label'}\n", - " )\n", + "train_task = mlrun.new_task(\n", + " \"training\",\n", + " inputs={\"dataset\": transactions_fv.uri},\n", + " params={\"label_columns\": \"label\"},\n", + ")\n", "\n", - "train_task.with_hyper_params(training_params, strategy='list', selector='max.accuracy')\n", + "train_task.with_hyper_params(training_params, strategy=\"list\", selector=\"max.accuracy\")\n", "\n", "# Specify your cluster image\n", - "classifier_fn.spec.image = 'mlrun/mlrun'\n", + "classifier_fn.spec.image = \"mlrun/mlrun\"\n", "\n", "# Run training\n", "classifier_fn.run(train_task, local=False)" @@ -954,19 +964,21 @@ } ], "source": [ - "feature_selection_fn = mlrun.import_function('hub://feature_selection')\n", + "feature_selection_fn = mlrun.import_function(\"hub://feature_selection\")\n", "\n", "feature_selection_run = feature_selection_fn.run(\n", - " params={\"k\": 18,\n", - " \"min_votes\": 2,\n", - " \"label_column\": 'label',\n", - " 
'output_vector_name':fv_name + \"-short\",\n", - " 'ignore_type_errors': True},\n", - " \n", - " inputs={'df_artifact': transactions_fv.uri},\n", - " name='feature_extraction',\n", - " handler='feature_selection',\n", - " local=False)" + " params={\n", + " \"k\": 18,\n", + " \"min_votes\": 2,\n", + " \"label_column\": \"label\",\n", + " \"output_vector_name\": fv_name + \"-short\",\n", + " \"ignore_type_errors\": True,\n", + " },\n", + " inputs={\"df_artifact\": transactions_fv.uri},\n", + " name=\"feature_extraction\",\n", + " handler=\"feature_selection\",\n", + " local=False,\n", + ")" ] }, { @@ -1156,7 +1168,7 @@ } ], "source": [ - "mlrun.get_dataitem(feature_selection_run.outputs['top_features_vector']).as_df().tail(5)" + "mlrun.get_dataitem(feature_selection_run.outputs[\"top_features_vector\"]).as_df().tail(5)" ] }, { @@ -1452,11 +1464,14 @@ ], "source": [ "# Define your training task, including your feature vector, label and hyperparams definitions\n", - "ensemble_train_task = mlrun.new_task('training', \n", - " inputs={'dataset': feature_selection_run.outputs['top_features_vector']},\n", - " params={'label_columns': 'label'}\n", - " )\n", - "ensemble_train_task.with_hyper_params(training_params, strategy='list', selector='max.accuracy')\n", + "ensemble_train_task = mlrun.new_task(\n", + " \"training\",\n", + " inputs={\"dataset\": feature_selection_run.outputs[\"top_features_vector\"]},\n", + " params={\"label_columns\": \"label\"},\n", + ")\n", + "ensemble_train_task.with_hyper_params(\n", + " training_params, strategy=\"list\", selector=\"max.accuracy\"\n", + ")\n", "\n", "classifier_fn.run(ensemble_train_task)" ] diff --git a/docs/feature-store/end-to-end-demo/03-deploy-serving-model.ipynb b/docs/feature-store/end-to-end-demo/03-deploy-serving-model.ipynb index 96a071749dac..8f3112b60810 100644 --- a/docs/feature-store/end-to-end-demo/03-deploy-serving-model.ipynb +++ b/docs/feature-store/end-to-end-demo/03-deploy-serving-model.ipynb @@ -71,7 +71,7 @@ 
"metadata": {}, "outputs": [], "source": [ - "project_name = 'fraud-demo'" + "project_name = \"fraud-demo\"" ] }, { @@ -122,17 +122,17 @@ "from cloudpickle import load\n", "from mlrun.serving.v2_serving import V2ModelServer\n", "\n", + "\n", "class ClassifierModel(V2ModelServer):\n", - " \n", " def load(self):\n", " \"\"\"load and initialize the model and/or other elements\"\"\"\n", - " model_file, extra_data = self.get_model('.pkl')\n", - " self.model = load(open(model_file, 'rb'))\n", - " \n", + " model_file, extra_data = self.get_model(\".pkl\")\n", + " self.model = load(open(model_file, \"rb\"))\n", + "\n", " def predict(self, body: dict) -> list:\n", " \"\"\"Generate model predictions from sample\"\"\"\n", " print(f\"Input -> {body['inputs']}\")\n", - " feats = np.asarray(body['inputs'])\n", + " feats = np.asarray(body[\"inputs\"])\n", " result: np.ndarray = self.model.predict(feats)\n", " return result.tolist()" ] @@ -257,19 +257,30 @@ ], "source": [ "# Create the serving function from your code above\n", - "serving_fn = mlrun.code_to_function('transaction-fraud', kind='serving', image=\"mlrun/mlrun\").apply(mlrun.auto_mount())\n", + "serving_fn = mlrun.code_to_function(\n", + " \"transaction-fraud\", kind=\"serving\", image=\"mlrun/mlrun\"\n", + ").apply(mlrun.auto_mount())\n", "\n", - "serving_fn.set_topology('router', 'mlrun.serving.routers.EnrichmentVotingEnsemble', name='VotingEnsemble',\n", - " feature_vector_uri=\"transactions-fraud-short\", impute_policy={\"*\": \"$mean\"})\n", + "serving_fn.set_topology(\n", + " \"router\",\n", + " \"mlrun.serving.routers.EnrichmentVotingEnsemble\",\n", + " name=\"VotingEnsemble\",\n", + " feature_vector_uri=\"transactions-fraud-short\",\n", + " impute_policy={\"*\": \"$mean\"},\n", + ")\n", "\n", "model_names = [\n", - "'transaction_fraud_rf',\n", - "'transaction_fraud_xgboost',\n", - "'transaction_fraud_adaboost'\n", + " \"transaction_fraud_rf\",\n", + " \"transaction_fraud_xgboost\",\n", + " 
\"transaction_fraud_adaboost\",\n", "]\n", "\n", "for i, name in enumerate(model_names, start=1):\n", - " serving_fn.add_model(name, class_name=\"ClassifierModel\", model_path=project.get_artifact_uri(f\"{name}#{i}:latest\"))\n", + " serving_fn.add_model(\n", + " name,\n", + " class_name=\"ClassifierModel\",\n", + " model_path=project.get_artifact_uri(f\"{name}#{i}:latest\"),\n", + " )\n", "\n", "# Plot the ensemble configuration\n", "serving_fn.spec.graph.plot()" @@ -343,13 +354,12 @@ ], "source": [ "# Choose an id for your test\n", - "sample_id = 'C1000148617'\n", + "sample_id = \"C1000148617\"\n", "\n", - "model_inference_path = '/v2/models/infer'\n", + "model_inference_path = \"/v2/models/infer\"\n", "\n", "# Send your sample ID for prediction\n", - "local_server.test(path=model_inference_path,\n", - " body={'inputs': [[sample_id]]})\n", + "local_server.test(path=model_inference_path, body={\"inputs\": [[sample_id]]})\n", "\n", "# notice the input vector is printed 3 times (once per child model) and is enriched with data from the feature store" ] @@ -397,10 +407,12 @@ "import mlrun.feature_store as fstore\n", "\n", "# Create the online feature service\n", - "svc = fstore.get_online_feature_service('transactions-fraud-short:latest', impute_policy={\"*\": \"$mean\"})\n", + "svc = fstore.get_online_feature_service(\n", + " \"transactions-fraud-short:latest\", impute_policy={\"*\": \"$mean\"}\n", + ")\n", "\n", "# Get sample feature vector\n", - "sample_fv = svc.get([{'source': sample_id}])\n", + "sample_fv = svc.get([{\"source\": sample_id}])\n", "sample_fv" ] }, @@ -448,7 +460,7 @@ "\n", "# Enable model monitoring\n", "serving_fn.set_tracking()\n", - "project.set_model_monitoring_credentials(os.getenv('V3IO_ACCESS_KEY'))\n", + "project.set_model_monitoring_credentials(os.getenv(\"V3IO_ACCESS_KEY\"))\n", "\n", "# Deploy the serving function\n", "serving_fn.deploy()" @@ -491,13 +503,12 @@ ], "source": [ "# Choose an id for your test\n", - "sample_id = 
'C1000148617'\n", + "sample_id = \"C1000148617\"\n", "\n", - "model_inference_path = '/v2/models/infer'\n", + "model_inference_path = \"/v2/models/infer\"\n", "\n", "# Send your sample ID for prediction\n", - "serving_fn.invoke(path=model_inference_path,\n", - " body={'inputs': [[sample_id]]})" + "serving_fn.invoke(path=model_inference_path, body={\"inputs\": [[sample_id]]})" ] }, { @@ -521,13 +532,15 @@ "outputs": [], "source": [ "# Load the dataset\n", - "data = mlrun.get_dataitem('https://s3.wasabisys.com/iguazio/data/fraud-demo-mlrun-fs-docs/data.csv').as_df()\n", + "data = mlrun.get_dataitem(\n", + " \"https://s3.wasabisys.com/iguazio/data/fraud-demo-mlrun-fs-docs/data.csv\"\n", + ").as_df()\n", "\n", "# use only first 10k\n", - "data = data.sort_values(by='source', axis=0)[:10000]\n", + "data = data.sort_values(by=\"source\", axis=0)[:10000]\n", "\n", "# keys\n", - "sample_ids = data['source'].to_list()" + "sample_ids = data[\"source\"].to_list()" ] }, { @@ -570,7 +583,9 @@ "for _ in range(10):\n", " data_point = choice(sample_ids)\n", " try:\n", - " resp = serving_fn.invoke(path=model_inference_path, body={'inputs': [[data_point]]})\n", + " resp = serving_fn.invoke(\n", + " path=model_inference_path, body={\"inputs\": [[data_point]]}\n", + " )\n", " print(resp)\n", " sleep(uniform(0.2, 1.7))\n", " except OSError:\n", diff --git a/docs/feature-store/end-to-end-demo/04-pipeline.ipynb b/docs/feature-store/end-to-end-demo/04-pipeline.ipynb index db3296b9aae1..415e6c8fab06 100644 --- a/docs/feature-store/end-to-end-demo/04-pipeline.ipynb +++ b/docs/feature-store/end-to-end-demo/04-pipeline.ipynb @@ -42,7 +42,7 @@ "outputs": [], "source": [ "# Set the base project name\n", - "project_name = 'fraud-demo'" + "project_name = \"fraud-demo\"" ] }, { @@ -130,9 +130,9 @@ } ], "source": [ - "project.set_function('hub://feature_selection', 'feature_selection')\n", - "project.set_function('hub://auto_trainer','train')\n", - "project.set_function('hub://v2_model_server', 
'serving')" + "project.set_function(\"hub://feature_selection\", \"feature_selection\")\n", + "project.set_function(\"hub://auto_trainer\", \"train\")\n", + "project.set_function(\"hub://v2_model_server\", \"serving\")" ] }, { @@ -153,7 +153,7 @@ ], "source": [ "# set project level parameters and save\n", - "project.spec.params = {'label_column': 'label'}\n", + "project.spec.params = {\"label_column\": \"label\"}\n", "project.save()" ] }, @@ -387,7 +387,7 @@ "outputs": [], "source": [ "# Register the workflow file as \"main\"\n", - "project.set_workflow('main', 'workflow.py')" + "project.set_workflow(\"main\", \"workflow.py\")" ] }, { @@ -572,10 +572,7 @@ } ], "source": [ - "run_id = project.run(\n", - " 'main',\n", - " arguments={}, \n", - " dirty=True, watch=True)" + "run_id = project.run(\"main\", arguments={}, dirty=True, watch=True)" ] }, { @@ -626,15 +623,14 @@ ], "source": [ "# Define your serving function\n", - "serving_fn = project.get_function('serving')\n", + "serving_fn = project.get_function(\"serving\")\n", "\n", "# Choose an id for your test\n", - "sample_id = 'C1000148617'\n", - "model_inference_path = '/v2/models/fraud/infer'\n", + "sample_id = \"C1000148617\"\n", + "model_inference_path = \"/v2/models/fraud/infer\"\n", "\n", "# Send our sample ID for predcition\n", - "serving_fn.invoke(path=model_inference_path,\n", - " body={'inputs': [[sample_id]]})" + "serving_fn.invoke(path=model_inference_path, body={\"inputs\": [[sample_id]]})" ] }, { diff --git a/docs/feature-store/feature-sets.md b/docs/feature-store/feature-sets.md index 9f35fb25385a..11fedcb7a785 100644 --- a/docs/feature-store/feature-sets.md +++ b/docs/feature-store/feature-sets.md @@ -33,7 +33,7 @@ Create a {py:class}`~mlrun.feature_store.FeatureSet` with the base definitions: * **name** — The feature set name is a unique name within a project. * **entities** — Each feature set must be associated with one or more index column. 
When joining feature sets, the key columns - are determined by the the relations field if it exists, and otherwise by the entities. + are determined by the relations field if it exists, and otherwise by the entities. * **timestamp_key** — (optional) Used for specifying the time field when joining by time. * **engine** — The processing engine type: - Spark @@ -124,7 +124,7 @@ df = fstore.ingest(stocks_set, stocks_df) The graph steps can use built-in transformation classes, simple python classes, or function handlers. -See more details in [Feature set transformations](transformations.html) and See more details in {ref}`transformations`. +See more details in {ref}`Feature set transformations `. ## Simulate and debug the data pipeline with a small dataset During the development phase it's pretty common to check the feature set definition and to simulate the creation of the feature set before diff --git a/docs/feature-store/feature-store-overview.md b/docs/feature-store/feature-store-overview.md index ce090a6452d7..2b5e7d5b144d 100644 --- a/docs/feature-store/feature-store-overview.md +++ b/docs/feature-store/feature-store-overview.md @@ -4,18 +4,18 @@ In machine-learning scenarios, generating a new feature, called feature engineering, takes a tremendous amount of work. The same features must be used both for training, based on historical data, and for the model prediction based on the online or real-time data. This creates a significant additional engineering effort, and leads to model inaccuracy when the online and offline features do not match. Furthermore, -monitoring solutions must be built to track features and results and send alerts of data or model drift. +monitoring solutions must be built to track features and results, and to send alerts upon data or model drift. Consider a scenario in which you train a model and one of its features is a comparison of the current amount to the average amount spent -during the last 3 months by the same person. 
Creating such a feature is easy when you have the full dataset in training, but in serving, +during the last 3 months by the same person. Creating such a feature is easy when you have the full dataset in training, but for serving this feature must be calculated in an online manner. The "brute-force" way to address this is to have an ML engineer create an online -pipeline that reimplements all the feature calculations done in the offline process. This is not just time-consuming and error-prone, but +pipeline that re-implements all the feature calculations that comprise the offline process. This is not just time-consuming and error-prone, but very difficult to maintain over time, and results in a lengthy deployment time. This is exacerbated when having to deal with thousands of -features with an increasing number of data engineers and data scientists that are creating and using the features. +features, and an increasing number of data engineers and data scientists that are creating and using the features. ![Challenges managing features](../_static/images/challenges_managing_features.png) -With MLRun's feature store you can easily define features during the training, that are deployable to serving, without having to define all the +With MLRun's feature store you can easily define features during the training, which are deployable to serving, without having to define all the "glue" code. You simply create the necessary building blocks to define features and integration, with offline and online storage systems to access the features. ![Feature store diagram](../_static/images/feature_store_diagram.png) @@ -26,11 +26,11 @@ This can be raw data (e.g., transaction amount, image pixel, etc.) or a calculat from average, pattern on image, etc.). - **{ref}`feature-sets`** — A grouping of features that are ingested together and stored in a logical group. 
Feature sets take data from offline or online sources, build a list of features through a set of transformations, and store the resulting features, along with the -associated metadata and statistics. For example, a transaction may be grouped by the ID of a person performing the transfer or by the device +associated metadata and statistics. For example, transactions could be grouped by the ID of a person performing the transfer or by the device identifier used to perform the transaction. You can also define in the timestamp source in the feature set, and ingest data into a feature set. - **[Execution](./feature-sets.html#add-transformations)** — A set of operations performed on the data while it is -ingested. The graph contains steps that represent data sources and targets, and can also contain steps that transform and enrich the data that is passed through the feature set. For a deeper dive, see {ref}`transformations`. +ingested. The transformation graph contains steps that represent data sources and targets, and can also include steps that transform and enrich the data that is passed through the feature set. For a deeper dive, see {ref}`transformations`. - **{ref}`Feature vectors `** — A set of features, taken from one or more feature sets. The feature vector is defined prior to model training and serves as the input to the model training process. During model serving, the feature values in the vector are obtained from an online service. @@ -40,9 +40,10 @@ training and serves as the input to the model training process. During model ser The common flow when working with the feature store is to first define the feature set with its source, transformation graph, and targets. MLRun's robust transformation engine performs complex operations with just a few lines of Python code. To test the -execution process, call the `infer` method with a sample DataFrame. This runs all operations in memory without storing the results. 
Once the -graph is defined, it's time to ingest the data. +execution process, call the `infer` method with a sample DataFrame. This runs all operations in memory without storing the results. +Once the +graph is defined, it's time to ingest the data. You can ingest data directly from a DataFrame, by calling the feature set {py:class}`~mlrun.feature_store.ingest` method. You can also define an ingestion process that runs as a Kubernetes job. This is useful if there is a large ingestion process, or if there is a recurrent ingestion and you want to schedule the job. @@ -61,20 +62,20 @@ Next, extract a versioned **offline** static dataset for training, based on the model with the feature vector data by providing the input in the form of `'store://feature-vectors/{project}/{feature_vector_name}'`. Training functions generate models and various model statistics. Use MLRun's auto logging capabilities to store the models along with all -the relevant data, metadata and measurements. +the relevant data, metadata, and measurements. MLRun can apply all the MLOps functionality by using the framework specific `apply_mlrun()` method, which manages the training process and -automatically logs all the framework specific model details, data, metadata and metrics. +automatically logs all the framework specific model details, data, metadata, and metrics. The training job automatically generates a set of results and versioned artifacts (run `train_run.outputs` to view the job outputs). -For serving, once you validate the feature vector, use the **online** feature service, based on the -nosql target defined in the feature set for real-time serving. For serving, you define a serving class derived from +After you validate the feature vector, use the **online** feature service, based on the +nosql target defined in the feature set, for real-time serving. For serving, you define a serving class derived from `mlrun.serving.V2ModelServer`. 
In the class `load` method, call the {py:meth}`~mlrun.feature_store.get_online_feature_service` function with the vector name, which returns a feature service object. In the class `preprocess` method, call the feature service `get` method to get the values of those features. -Using this feature store centric process, using one computation graph definition for a feature set, you receive an automatic online and -offline implementation for the feature vectors, with data versioning both in terms of the actual graph that was used to calculate each data +This feature store centric process, using one computation graph definition for a feature set, gives you an automatic online and +offline implementation for the feature vectors with data versioning, both in terms of the actual graph that was used to calculate each data point, and the offline datasets that were created to train each model. See more information in {ref}`training with the feature store ` and {ref}`training-serving`. diff --git a/docs/feature-store/feature-store.md b/docs/feature-store/feature-store.md index f3d90daf7ee6..9fd70c05bc60 100644 --- a/docs/feature-store/feature-store.md +++ b/docs/feature-store/feature-store.md @@ -2,7 +2,7 @@ # Feature store A feature store provides a single pane of glass for sharing all available features across -the organization along with their metadata. MLRun Feature store support security, versioning, +the organization along with their metadata. The MLRun feature store supports security, versioning, and data snapshots, enabling better data lineage, compliance, and manageability. As illustrated in the diagram below, @@ -10,9 +10,9 @@ feature stores provide a mechanism (**`Feature Sets`**) to read data from variou conduct a set of data transformations, and persist the data in online and offline storage. 
Features are stored and cataloged along with all their metadata (schema, labels, statistics, etc.), allowing users to compose **`Feature Vectors`** and use them for training -or serving. The Feature Vectors are generated when needed, taking into account data versioning and time +or serving. The feature vectors are generated when needed, taking into account data versioning and time correctness (time traveling). Different function kinds (Nuclio, Spark, Dask) are used for feature retrieval, real-time -engine for serving, and batch one for training. +engines for serving, and batch for training.
feature-store
diff --git a/docs/feature-store/feature-vectors.md b/docs/feature-store/feature-vectors.md index 52e0eab88d60..369a64b76af8 100644 --- a/docs/feature-store/feature-vectors.md +++ b/docs/feature-store/feature-vectors.md @@ -85,17 +85,14 @@ Defaults to return as a return value to the caller. - **with_indexes** return vector with index columns and timestamp_key from the feature sets. Default is False. - **update_stats** — update features statistics from the requested feature sets on the vector. Default is False. - **engine** — processing engine kind ("local", "dask", or "spark") -- **engine_args** — kwargs for the processing engine -- **query** — The query string used to filter rows -- **spark_service** — Name of the spark service to be used (when using a remote-spark runtime) -- **join_type** — (optional) Indicates the join type: `{'left', 'right', 'outer', 'inner'}, default 'inner'`. The Spark retrieval engine only supports entities-based `inner` join (ie. no support for `relations`, no support for `outer`, `left`, `right` joins) - - left: use only keys from left frame (SQL: left outer join) - - right: use only keys from right frame (SQL: right outer join) - - outer: use union of keys from both frames (SQL: full outer join) - - inner: use intersection of keys from both frames (SQL: inner join). - -You can add a time-based filter condition when running `get_offline_feature` with a given vector. You can also filter with the query -argument on all the other features as relevant. +- **engine_args** — kwargs for the processing engine. +- **query** — The query string used to filter rows on the output. +- **spark_service** — Name of the spark service to be used (when using a remote-spark runtime) +- **order_by** — Name or list of names to order by. The name or the names in the list can be the feature name or the alias of the +feature you pass in the feature list. +- **timestamp_for_filtering** — (optional) Used to configure the columns that a time-based filter filters by. 
By default, the time-based filter is executed using the timestamp_key of each feature set. +Specifying the `timestamp_for_filtering` param overwrites this default: if it's a str it specifies the timestamp column to use in all the feature sets. If it's a dictionary (`{<feature set name>: <timestamp column name>, …}`) it indicates the timestamp column name +for each feature set. The time filtering is performed on each feature set (using `start_time` and `end_time`) before the merge process. You can create a feature vector that comprises different feature sets, while joining the data based on specific fields and not the entity. For example: @@ -104,7 +101,7 @@ For example: You can build a feature vector that comprises fields in feature set A and get the count distinct for the email from feature set B. The join in this case is based on the email column. -Here's an example of a new dataset from a parquet target: +Here's an example of a new dataset from a Parquet target: ```python # Import the Parquet Target, so you can build your dataset from a parquet file @@ -117,8 +114,8 @@ offline_fv = fstore.get_offline_features(feature_vector_name, target=ParquetTarg dataset = offline_fv.to_dataframe() ``` -Once an offline feature vector is created with a static target (such as {py:class}`~mlrun.datastore.targets.ParquetTarget()`) the -reference to this dataset is saved as part of the feature vector's metadata and can now be referenced directly through the store +After you create an offline feature vector with a static target (such as {py:class}`~mlrun.datastore.targets.ParquetTarget()`) the +reference to this dataset is saved as part of the feature vector's metadata and can be referenced directly through the store
For example: @@ -136,10 +133,10 @@ task = mlrun.new_task('training', run = fn.run(task) ``` -You can see a full example of using the offline feature vector to create an ML model in [part 2 of the end-to-end demo](./end-to-end-demo/02-create-training-model.html). +See a full example of using the offline feature vector to create an ML model in [part 2 of the end-to-end demo](./end-to-end-demo/02-create-training-model.html). You can use `get_offline_features` for a feature vector whose data is not ingested. See -[Create a feature set without ingesting its data](..feature-store/feature-sets.html#create-a-feature-set-without-ingesting-its-data). +[Create a feature set without ingesting its data](./feature-sets.html#create-a-feature-set-without-ingesting-its-data). #### Using joins in an offline feature vector @@ -177,7 +174,6 @@ vector.save() resp = fs.get_offline_features( vector, - join_type='outer', # one of following values: "inner" (as with current code), "outer", "right", "left" engine_args=engine_args, with_indexes=True, ) @@ -214,7 +210,6 @@ vector = fs.FeatureVector( resp = fs.get_offline_features( vector, - join_type='inner', # one of following values: "inner" (as with current code), "outer", "right", "left" engine_args=engine_args, with_indexes=False, ) diff --git a/docs/feature-store/training-serving.md b/docs/feature-store/training-serving.md index 909c0a451e62..c438aca4fe81 100644 --- a/docs/feature-store/training-serving.md +++ b/docs/feature-store/training-serving.md @@ -31,7 +31,7 @@ You define a serving model class and the computational graph required to run you To embed the online feature service in your model server, just create the feature vector service once when the model initializes, and then use it to retrieve the feature vectors of incoming keys. -You can import ready-made classes and functions from the MLRun [Function Hub](https://www.mlrun.org/marketplace/) or write your own. 
+You can import ready-made classes and functions from the MLRun [Function Hub](https://www.mlrun.org/hub/) or write your own. As example of a scikit-learn based model server: diff --git a/docs/feature-store/transformations.md b/docs/feature-store/transformations.md index a3cbc0f9ad31..e23591f9d936 100644 --- a/docs/feature-store/transformations.md +++ b/docs/feature-store/transformations.md @@ -44,9 +44,51 @@ to the [feature store example](./basic-demo.html). Aggregations, being a common tool in data preparation and ML feature engineering, are available directly through the MLRun {py:class}`~mlrun.feature_store.FeatureSet` class. These transformations add a new feature to the -feature-set that is created by performing an aggregate function over the feature's values. You can use aggregation for time-based -sliding windows and fixed windows. In general, sliding windows are used for real time data, while fixed windows are used for historical -aggregations. +feature-set, which is created by performing an aggregate function over the feature's values. + +If the `name` parameter is not specified, features are generated in the format `{column_name}_{operation}_{window}`. +If you supply the optional `name` parameter, features are generated in the format `{name}_{operation}_{window}`. + + +Feature names, which are generated internally, must match this regex pattern to be treated as aggregations: +`.*_[a-z]+_[0-9]+[smhd]$`,
+where `[a-z]+` is the name of an aggregation. + +```{admonition} Warning +You must ensure that your features will not conflict with the automatically generated feature names. For example, +when using `add_aggregation()` on a feature X, you may get a generated feature name of `X_count_1h`. +But if your dataset already contains `X_count_1h`, this would result in either unreliable aggregations or errors. +``` + +If either the pattern or the condition is not met, the feature is treated as a static (or "regular") feature. + +These features can be fed into predictive models or can be used for additional processing and feature generation. + +```{admonition} Notes +- Internally, the graph step that is created to perform these aggregations is named `"Aggregates"`. If more than one + aggregation step is needed, a unique name must be provided to each, using the `step_name` parameter. +- The timestamp column must be part of the feature set definition (for aggregation). +``` + +Aggregations that are supported using this function are: +- `count` +- `sum` +- `sqr` (sum of squares) +- `max` +- `min` +- `first` +- `last` +- `avg` +- `stdvar` (variance) +- `stddev` (standard deviation) + +For a full description of this function, see the {py:func}`~mlrun.feature_store.FeatureSet.add_aggregation` +documentation. + +### Windows + +You can use aggregation for time-based sliding windows and fixed windows. In general, sliding windows are used for real time data, +while fixed windows are used for historical aggregations. A window can be measured in years, days, hours, seconds, minutes. A window can be a single window, e.g. ‘1h’, ‘1d’, or a @@ -97,40 +139,14 @@ All time windows are aligned to the epoch (1970-01-01T00:00:00Z). quotes_set = fstore.FeatureSet("stock-quotes", entities=[fstore.Entity("ticker")]) quotes_set.add_aggregation("bid", ["min", "max"], ["1h"] name="price") ``` - This code generates two new features: `bid_min_1h` and `bid_max_1h` once per hour. 
- - -If the `name` parameter is not specified, features are generated in the format `{column_name}_{operation}_{window}`. -If you supply the optional `name` parameter, features are generated in the format `{name}_{operation}_{window}`. - -These features can be fed into predictive models or be used for additional processing and feature generation. - -```{admonition} Notes -- Internally, the graph step that is created to perform these aggregations is named `"Aggregates"`. If more than one - aggregation steps are needed, a unique name must be provided to each, using the `state_name` parameter. -- The timestamp column must be part of the feature set definition (for aggregation). -``` + This code generates two new features: `bid_min_1h` and `bid_max_1h` once per hour. -Aggregations that are supported using this function are: -- `count` -- `sum` -- `sqr` (sum of squares) -- `max` -- `min` -- `first` -- `last` -- `avg` -- `stdvar` -- `stddev` - -For a full documentation of this function, see the {py:func}`~mlrun.feature_store.FeatureSet.add_aggregation` -documentation. ## Built-in transformations MLRun, and the associated `storey` package, have a built-in library of [transformation functions](../serving/available-steps.html) that can be -applied as steps in the feature-set's internal execution graph. In order to add steps to the graph, it should be -referenced from the {py:class}`~mlrun.feature_store.FeatureSet` object by using the +applied as steps in the feature-set's internal execution graph. To add steps to the graph, +reference them from the {py:class}`~mlrun.feature_store.FeatureSet` object by using the {py:attr}`~mlrun.feature_store.FeatureSet.graph` property. Then, new steps can be added to the graph using the functions in {py:mod}`storey.transformations` (follow the link to browse the documentation and the list of existing functions). The transformations are also accessible directly from the `storey` module. 
@@ -140,10 +156,9 @@ See the [built-in steps](../serving/available-steps.html). ```{admonition} Note Internally, MLRun makes use of functions defined in the `storey` package for various purposes. When creating a feature-set and configuring it with sources and targets, what MLRun does behind the scenes is to add steps to the -execution graph that wraps methods and classes that perform the actions. When defining an async execution graph, - +execution graph that wraps methods and classes that perform the actions. When defining an async execution graph, `storey` classes are used. For example, when defining a Parquet data-target in MLRun, a graph step is created that -wraps storey's {py:func}`~storey.targets.ParquetTarget` function. +wraps storey's [ParquetTarget function](https://storey.readthedocs.io/en/latest/api.html#storey.targets.ParquetTarget). ``` To use a function: @@ -190,4 +205,83 @@ quotes_set.graph.add_step("MyMap", "multi", after="filter", multiplier=3) ``` This uses the `add_step` function of the graph to add a step called `multi` utilizing `MyMap` after the `filter` step -that was added previously. The class is initialized with a multiplier of 3. \ No newline at end of file +that was added previously. The class is initialized with a multiplier of 3. + +## Supporting multiple engines + +MLRun supports multiple processing engines for executing graphs. These engines differ in the way they invoke graph +steps. When implementing custom transformations, the code has to support all engines that are expected to run it. + +```{admonition} Note +The vast majority of MLRun's built-in transformations support all engines. The support matrix is available +[here](../serving/available-steps.html#data-transformations). 
+``` + +The following are the main differences between transformation steps executing on different engines: + +* `storey` - the step receives a single event (either as a dictionary or as an Event object, depending on whether + `full_event` is configured for the step). The step is expected to process the event and return the modified event. +* `spark` - the step receives a Spark dataframe object. Steps are expected to add their processing and calculations to + the dataframe (either in-place or not) and return the resulting dataframe without materializing the data. +* `pandas` - the step receives a Pandas dataframe, processes it, and returns the dataframe. + +To support multiple engines, extend the {py:class}`~mlrun.feature_store.steps.MLRunStep` class with a custom +transformation. This class allows implementing engine-specific code by overriding the following methods: +{py:func}`~mlrun.feature_store.steps.MLRunStep._do_storey`, {py:func}`~mlrun.feature_store.steps.MLRunStep._do_pandas` +and {py:func}`~mlrun.feature_store.steps.MLRunStep._do_spark`. To add support for a given engine, the relevant `do` +method needs to be implemented. + +When a graph is executed, each step is a single instance of the relevant class that gets invoked as events flow through +the graph. For `spark` and `pandas` engines, this only happens once per ingestion, since the entire data-frame is fed to +the graph. For the `storey` engine the same instance's {py:func}`~mlrun.feature_store.steps.MLRunStep._do_storey` +function will be invoked per input row. As the graph is initialized, this class instance can receive global parameters +in its `__init__` method that determines its behavior. + +The following example class multiplies a feature by a value and adds it to the event. (For simplicity, data type +checks and validations were omitted as well as needed imports.) 
Note that the class also extends +{py:class}`~mlrun.serving.utils.StepToDict` - this class implements generic serialization of graph steps to +a python dictionary. This functionality allows passing instances of this class to `graph.to()` and `graph.add_step()`: + +```python +class MultiplyFeature(StepToDict, MLRunStep): + def __init__(self, feature: str, value: int, **kwargs): + super().__init__(**kwargs) + self._feature = feature + self._value = value + self._new_feature = f"{feature}_times_{value}" + + def _do_storey(self, event): + # event is a single row represented by a dictionary + event[self._new_feature] = event[self._feature] * self._value + return event + + def _do_pandas(self, event): + # event is a pandas.DataFrame + event[self._new_feature] = event[self._feature].multiply(self._value) + return event + + def _do_spark(self, event): + # event is a pyspark.sql.DataFrame + return event.withColumn(self._new_feature, + col(self._feature) * lit(self._value) + ) +``` + +The following example uses this step in a feature-set graph with the `pandas` engine. This example adds a feature called +`number1_times_4` with the value of the `number1` feature multiplied by 4. 
Note how the global parameters are passed +when creating the graph step: + +```python +import mlrun.feature_store as fstore + +feature_set = fstore.FeatureSet("fs-new", + entities=[fstore.Entity("id")], + engine="pandas", + ) +# Adding multiply step, with specific parameters +feature_set.graph.to(MultiplyFeature(feature="number1", value=4)) +df_pandas = fstore.ingest(feature_set, data) +``` + + + diff --git a/docs/feature-store/using-spark-engine.md b/docs/feature-store/using-spark-engine.md index ae19d51b4593..f27d8b8ba0b3 100644 --- a/docs/feature-store/using-spark-engine.md +++ b/docs/feature-store/using-spark-engine.md @@ -208,7 +208,7 @@ One-time setup: secrets = {'s3_access_key': AWS_ACCESS_KEY, 's3_secret_key': AWS_SECRET_KEY} mlrun.get_run_db().create_project_secrets( project = "uhuh-proj", - provider=mlrun.api.schemas.SecretProviderName.kubernetes, + provider=mlrun.common.schemas.SecretProviderName.kubernetes, secrets=secrets ) ``` diff --git a/docs/hyper-params.ipynb b/docs/hyper-params.ipynb index 5a947aa867f4..328f3c51f4e8 100644 --- a/docs/hyper-params.ipynb +++ b/docs/hyper-params.ipynb @@ -393,8 +393,10 @@ } ], "source": [ - "grid_params = {\"p1\": [2,4,1], \"p2\": [10,20]}\n", - "task = mlrun.new_task(\"grid-demo\").with_hyper_params(grid_params, selector=\"max.multiplier\")\n", + "grid_params = {\"p1\": [2, 4, 1], \"p2\": [10, 20]}\n", + "task = mlrun.new_task(\"grid-demo\").with_hyper_params(\n", + " grid_params, selector=\"max.multiplier\"\n", + ")\n", "run = mlrun.new_function().run(task, handler=hyper_func)" ] }, @@ -664,9 +666,11 @@ } ], "source": [ - "grid_params = {\"p1\": [2,4,1,3], \"p2\": [10,20,30]}\n", + "grid_params = {\"p1\": [2, 4, 1, 3], \"p2\": [10, 20, 30]}\n", "task = mlrun.new_task(\"random-demo\")\n", - "task.with_hyper_params(grid_params, selector=\"max.multiplier\", strategy=\"random\", max_iterations=4)\n", + "task.with_hyper_params(\n", + " grid_params, selector=\"max.multiplier\", strategy=\"random\", max_iterations=4\n", 
+ ")\n", "run = mlrun.new_function().run(task, handler=hyper_func)" ] }, @@ -925,9 +929,13 @@ } ], "source": [ - "list_params = {\"p1\": [2,3,7,4,5], \"p2\": [15,10,10,20,30]}\n", + "list_params = {\"p1\": [2, 3, 7, 4, 5], \"p2\": [15, 10, 10, 20, 30]}\n", "task = mlrun.new_task(\"list-demo\").with_hyper_params(\n", - " list_params, selector=\"max.multiplier\", strategy=\"list\", stop_condition=\"multiplier>=70\")\n", + " list_params,\n", + " selector=\"max.multiplier\",\n", + " strategy=\"list\",\n", + " stop_condition=\"multiplier>=70\",\n", + ")\n", "run = mlrun.new_function().run(task, handler=hyper_func)" ] }, @@ -951,14 +959,14 @@ " for param in param_list:\n", " with context.get_child_context(**param) as child:\n", " hyper_func(child, **child.parameters)\n", - " multiplier = child.results['multiplier']\n", + " multiplier = child.results[\"multiplier\"]\n", " total += multiplier\n", " if multiplier > best_multiplier:\n", " child.mark_as_best()\n", " best_multiplier = multiplier\n", "\n", " # log result at the parent\n", - " context.log_result('avg_multiplier', total / len(param_list))" + " context.log_result(\"avg_multiplier\", total / len(param_list))" ] }, { @@ -1205,7 +1213,7 @@ } ], "source": [ - "param_list = [{\"p1\":2, \"p2\":10}, {\"p1\":3, \"p2\":30}, {\"p1\":4, \"p2\":7}]\n", + "param_list = [{\"p1\": 2, \"p2\": 10}, {\"p1\": 3, \"p2\": 30}, {\"p1\": 4, \"p2\": 7}]\n", "run = mlrun.new_function().run(handler=handler, params={\"param_list\": param_list})" ] }, @@ -1252,6 +1260,8 @@ "source": [ "import socket\n", "import pandas as pd\n", + "\n", + "\n", "def hyper_func2(context, data, p1, p2, p3):\n", " print(data.as_df().head())\n", " context.logger.info(f\"p2={p2}, p3={p3}, r1={p2 * p3} at {socket.gethostname()}\")\n", @@ -1307,10 +1317,10 @@ } ], "source": [ - "dask_cluster = mlrun.new_function(\"dask-cluster\", kind='dask', image='mlrun/ml-models')\n", - "dask_cluster.apply(mlrun.mount_v3io()) # add volume mounts\n", - 
"dask_cluster.spec.service_type = \"NodePort\" # open interface to the dask UI dashboard\n", - "dask_cluster.spec.replicas = 2 # define two containers\n", + "dask_cluster = mlrun.new_function(\"dask-cluster\", kind=\"dask\", image=\"mlrun/ml-models\")\n", + "dask_cluster.apply(mlrun.mount_v3io()) # add volume mounts\n", + "dask_cluster.spec.service_type = \"NodePort\" # open interface to the dask UI dashboard\n", + "dask_cluster.spec.replicas = 2 # define two containers\n", "uri = dask_cluster.save()\n", "uri" ] @@ -1425,10 +1435,18 @@ } ], "source": [ - "grid_params = {\"p2\": [2,1,4,1], \"p3\": [10,20]}\n", - "task = mlrun.new_task(params={\"p1\": 8}, inputs={'data': 'https://s3.wasabisys.com/iguazio/data/iris/iris_dataset.csv'})\n", + "grid_params = {\"p2\": [2, 1, 4, 1], \"p3\": [10, 20]}\n", + "task = mlrun.new_task(\n", + " params={\"p1\": 8},\n", + " inputs={\"data\": \"https://s3.wasabisys.com/iguazio/data/iris/iris_dataset.csv\"},\n", + ")\n", "task.with_hyper_params(\n", - " grid_params, selector=\"r1\", strategy=\"grid\", parallel_runs=4, dask_cluster_uri=uri, teardown_dask=True\n", + " grid_params,\n", + " selector=\"r1\",\n", + " strategy=\"grid\",\n", + " parallel_runs=4,\n", + " dask_cluster_uri=uri,\n", + " teardown_dask=True,\n", ")" ] }, @@ -1445,7 +1463,7 @@ "metadata": {}, "outputs": [], "source": [ - "fn = mlrun.code_to_function(name='hyper-tst', kind='job', image='mlrun/ml-models')" + "fn = mlrun.code_to_function(name=\"hyper-tst\", kind=\"job\", image=\"mlrun/ml-models\")" ] }, { @@ -1844,7 +1862,7 @@ } ], "source": [ - "fn = mlrun.code_to_function(name='hyper-tst2', kind='nuclio:mlrun', image='mlrun/mlrun')\n", + "fn = mlrun.code_to_function(name=\"hyper-tst2\", kind=\"nuclio:mlrun\", image=\"mlrun/mlrun\")\n", "# replicas * workers need to match or exceed parallel_runs\n", "fn.spec.replicas = 2\n", "fn.with_http(workers=2)\n", @@ -1867,6 +1885,7 @@ "# this is required to fix Jupyter issue with asyncio (not required outside of Jupyter)\n", 
"# run it only once\n", "import nest_asyncio\n", + "\n", "nest_asyncio.apply()" ] }, @@ -2144,8 +2163,11 @@ } ], "source": [ - "grid_params = {\"p2\": [2,1,4,1], \"p3\": [10,20]}\n", - "task = mlrun.new_task(params={\"p1\": 8}, inputs={'data': 'https://s3.wasabisys.com/iguazio/data/iris/iris_dataset.csv'})\n", + "grid_params = {\"p2\": [2, 1, 4, 1], \"p3\": [10, 20]}\n", + "task = mlrun.new_task(\n", + " params={\"p1\": 8},\n", + " inputs={\"data\": \"https://s3.wasabisys.com/iguazio/data/iris/iris_dataset.csv\"},\n", + ")\n", "task.with_hyper_params(\n", " grid_params, selector=\"r1\", strategy=\"grid\", parallel_runs=4, max_errors=3\n", ")\n", diff --git a/docs/index.md b/docs/index.md index 65196f958f96..c32bc4a12aac 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,4 +1,4 @@ -(architecture)= +(using-mlrun)= # Using MLRun ```{div} full-width @@ -127,7 +127,8 @@ MLRun rapidly deploys and manages production-grade real-time or batch applicatio Observability is built into the different MLRun objects (data, functions, jobs, models, pipelines, etc.), eliminating the need for complex integrations and code instrumentation. With MLRun, you can observe the application/model resource usage and model behavior (drift, performance, etc.), define custom app metrics, and trigger alerts or retraining jobs. {bdg-link-primary-line}`more... <./monitoring/index.html>` -`````{div} full-width{octicon}`mortar-board` **Docs:** +`````{div} full-width +{octicon}`mortar-board` **Docs:** {bdg-link-info}`Model monitoring overview <./monitoring/model-monitoring-deployment.html>` , {octicon}`code-square` **Tutorials:** {bdg-link-primary}`Model monitoring & drift detection <./tutorial/05-model-monitoring.html>` diff --git a/docs/install.md b/docs/install.md index 13bf2bb03601..5b3561be1eba 100644 --- a/docs/install.md +++ b/docs/install.md @@ -29,7 +29,7 @@ There are several deployment options: ## Set up your client -You can work with your favorite IDE (e.g. 
Pycharm, VScode, Jupyter, Colab etc..). Read how to configure your client against the deployed +You can work with your favorite IDE (e.g. Pycharm, VScode, Jupyter, Colab, etc.). Read how to configure your client against the deployed MLRun server in {ref}`install-remote`. Once you have installed and configured MLRun, follow the [Quick Start tutorial](https://docs.mlrun.org/en/latest/tutorial/01-mlrun-basics.html) and additional {ref}`Tutorials and Examples` to learn how to use MLRun to develop and deploy machine learning applications to production. diff --git a/docs/install/aws-install.md b/docs/install/aws-install.md index c92095b0864f..57465fab695e 100644 --- a/docs/install/aws-install.md +++ b/docs/install/aws-install.md @@ -3,6 +3,20 @@ For AWS users, the easiest way to install MLRun is to use a native AWS deployment. This option deploys MLRun on an AWS EKS service using a CloudFormation stack. +```{admonition} Note +These instructions install the community edition, which currently includes MLRun {{ ceversion }}. See the {{ '[release documentation](https://{})'.format(releasedocumentation) }}. +``` + +**In this section** +- [Prerequisites](#prerequisites) +- [Post deployment expectations](#post-deployment-expectations) +- [Configuration settings](#configuration-settings) +- [Getting started](#getting-started) +- [Storage resources](#storage-resources) +- [Configuring the online features store](#configuring-the-online-feature-store) +- [Streaming support](#streaming-support) +- [Cleanup](#cleanup) + ## Prerequisites 1. An AWS account with permissions that include the ability to: @@ -24,10 +38,10 @@ For AWS users, the easiest way to install MLRun is to use a native AWS deploymen For more information, see [how to create a new AWS account](https://aws.amazon.com/premiumsupport/knowledge-center/create-and-activate-aws-account/) and [policies and permissions in IAM](https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html). -2. 
You need to have a Route53 domain configured in the same AWS account and specify the full domain name in **Route 53 hosted DNS domain** configuration (See [Step 11](#route53_config) below). External domain registration is currently not supported. For more information see [What is Amazon Route 53?](https://docs.aws.amazon.com/Route53/latest/DeveloperGuide/Welcome.html). +2. A Route53 domain configured in the same AWS account, and with the full domain name specified in **Route 53 hosted DNS domain** configuration (See [Step 11](#route53_config) below). External domain registration is currently not supported. For more information see [What is Amazon Route 53?](https://docs.aws.amazon.com/Route53/latest/DeveloperGuide/Welcome.html). ```{admonition} Notes -The MLRun software is free of charge, however, there is a cost for the AWS infrastructure services such as EKS, EC2, S3 and ECR. The actual pricing depends on a large set of factors including, for example, the region, the number of EC2 instances, the amount of storage consumed, and the data transfer costs. Other factors include, for example, reserved instance configuration, saving plan, and AWS credits you have associated with your account. It is recommended to use the [AWS pricing calculator](https://calculator.aws) to calculate the expected cost, as well as the [AWS Cost Explorer](https://aws.amazon.com/aws-cost-management/aws-cost-explorer/) to manage the cost, monitor and set-up alerts. +The MLRun software is free of charge, however, there is a cost for the AWS infrastructure services such as EKS, EC2, S3 and ECR. The actual pricing depends on a large set of factors including, for example, the region, the number of EC2 instances, the amount of storage consumed, and the data transfer costs. Other factors include, for example, reserved instance configuration, saving plan, and AWS credits you have associated with your account. 
It is recommended to use the [AWS pricing calculator](https://calculator.aws) to calculate the expected cost, as well as the [AWS Cost Explorer](https://aws.amazon.com/aws-cost-management/aws-cost-explorer/) to manage the cost, monitor, and set-up alerts. ``` ## Post deployment expectations @@ -65,9 +79,9 @@ You must fill in fields marked as mandatory (m) for the configuration to complet **VPC network Configuration** -3. **Number of Availability Zones** (m) — number of availability zones. The default is set to 3. Choose from the dropdown to change the number. The minimum is 2. +3. **Number of Availability Zones** (m) — The default is set to 3. Choose from the dropdown to change the number. The minimum is 2. 4. **Availability zones** (m) — select a zone from the dropdown. The list is based on the region of the instance. The number of zones must match the number of zones Number of Availability Zones. -5. **Allowed external access CIDR** (m) — range of IP address allowed to access the cluster. Addresses that are not in this range are not able to access the cluster. Contact your IT manager/network administrator if you are not sure what to fill here. +5. **Allowed external access CIDR** (m) — range of IP addresses allowed to access the cluster. Addresses that are not in this range are not able to access the cluster. Contact your IT manager/network administrator if you are not sure what to fill in here. **Amazon EKS configuration** @@ -77,9 +91,9 @@ You must fill in fields marked as mandatory (m) for the configuration to complet **Amazon EC2 configuration** -9. **SSH key name** (o) — Users who wish to access the EC2 instance via SSH can enter an existing key. If left empty, it is possible to access the EC2 instance using the AWS Systems Manager Session Manager. For more information about SSH Keys see [Amazon EC2 key pairs and Linux instances](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-key-pairs.html). +9. 
**SSH key name** (o) — To access the EC2 instance via SSH, enter an existing key. If left empty, it is possible to access the EC2 instance using the AWS Systems Manager Session Manager. For more information about SSH Keys see [Amazon EC2 key pairs and Linux instances](https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/ec2-key-pairs.html). -10. **Provision bastion host** (m) — create a bastion host for SSH access to the Kubernetes nodes. The default is enabled. This allows ssh access to your EKS EC2 instances through a public IP. +10. **Provision bastion host** (m) — create a bastion host for SSH access to the Kubernetes nodes. The default is enabled. This allows SSH access to your EKS EC2 instances through a public IP. **Iguazio MLRun configuration** @@ -119,7 +133,7 @@ When installing the MLRun Community Edition via Cloud Formation, several storage -## How to configure the online feature store +## Configuring the online feature store The feature store can store data on a fast key-value database table for quick serving. This online feature store capability requires an external key-value database. @@ -127,13 +141,13 @@ Currently the MLRun feature store supports the following options: - Redis - Iguazio key-value database -To use Redis, you must install Redis separately and provide the Redis URL when configuring the AWS CloudFormation stack. Refer to the [Redis getting-started page](https://redis.io/docs/getting-started/). for information about Redis installation. +To use Redis, you must install Redis separately and provide the Redis URL when configuring the AWS CloudFormation stack. Refer to the [Redis getting-started page](https://redis.io/docs/getting-started/) for information about Redis installation. ## Streaming support For online serving, it is often convenient to use MLRun graph with a streaming engine. This allows managing queues between steps and functions. MLRun supports Kafka streams as well as Iguazio V3IO streams. 
-See the examples on how to configure the MLRun serving graph with {ref}`kafka` and {ref}`V3IO`. +See the examples on how to configure the MLRun serving graph with {ref}`Kafka` and {ref}`V3IO`. ## Cleanup diff --git a/docs/install/compose.with-jupyter.yaml b/docs/install/compose.with-jupyter.yaml index 3aeccdffb0b8..55614666e957 100644 --- a/docs/install/compose.with-jupyter.yaml +++ b/docs/install/compose.with-jupyter.yaml @@ -1,6 +1,6 @@ services: init_nuclio: - image: alpine:3.16 + image: alpine:3.18 command: - "/bin/sh" - "-c" diff --git a/docs/install/compose.yaml b/docs/install/compose.yaml index aab890e9e4f5..584928b69f67 100644 --- a/docs/install/compose.yaml +++ b/docs/install/compose.yaml @@ -1,6 +1,6 @@ services: init_nuclio: - image: alpine:3.16 + image: alpine:3.18 command: - "/bin/sh" - "-c" diff --git a/docs/install/kubernetes.md b/docs/install/kubernetes.md index 8c994b299ba9..6f7718447343 100644 --- a/docs/install/kubernetes.md +++ b/docs/install/kubernetes.md @@ -1,22 +1,26 @@ (install-on-kubernetes)= # Install MLRun on Kubernetes +```{admonition} Note +These instructions install the community edition, which currently includes MLRun {{ ceversion }}. See the {{ '[release documentation](https://{})'.format(releasedocumentation) }}. 
+``` + **In this section** - [Prerequisites](#prerequisites) -- [Community Edition Flavors](#community-edition-flavors) -- [Installing the Chart](#installing-the-chart) -- [Configuring Online Feature Store](#configuring-online-feature-store) +- [Community Edition flavors](#community-edition-flavors) +- [Installing the chart](#installing-the-chart) +- [Configuring the online feature store](#configuring-the-online-feature-store) - [Usage](#usage) - [Start working](#start-working) - [Configuring the remote environment](#configuring-the-remote-environment) - [Advanced chart configuration](#advanced-chart-configuration) -- [Storage Resources](#storage-resources) +- [Storage resources](#storage-resources) - [Uninstalling the chart](#uninstalling-the-chart) - [Upgrading the chart](#upgrading-the-chart) ## Prerequisites -- Access to a Kubernetes cluster. You must have administrator permissions in order to install MLRun on your cluster. MLRun fully supports k8s releases 1.22 and 1.23. For local installation +- Access to a Kubernetes cluster. To install MLRun on your cluster, you must have administrator permissions. MLRun fully supports k8s releases 1.22 and 1.23. For local installation on Windows or Mac, [Docker Desktop](https://www.docker.com/products/docker-desktop) is recommended. - The Kubernetes command-line tool (kubectl) compatible with your Kubernetes cluster is installed. Refer to the [kubectl installation instructions](https://kubernetes.io/docs/tasks/tools/install-kubectl/) for more information. @@ -24,11 +28,11 @@ instructions](https://kubernetes.io/docs/tasks/tools/install-kubectl/) for more - An accessible docker-registry (such as [Docker Hub](https://hub.docker.com)). The registry's URL and credentials are consumed by the applications via a pre-created secret. - Storage: - 8Gi - - It is also required to set a default storage class for the kubernetes cluster in order for the pods to have persistent storage. 
Please see the [Kubernetes documentation](https://kubernetes.io/docs/concepts/storage/storage-classes/#the-storageclass-resource) for more information. + - Set a default storage class for the kubernetes cluster, in order for the pods to have persistent storage. See the [Kubernetes documentation](https://kubernetes.io/docs/concepts/storage/storage-classes/#the-storageclass-resource) for more information. - RAM: A minimum of 8Gi is required for running all the initial MLRun components. The amount of RAM required for running MLRun jobs depends on the job's requirements. ``` {admonition} Note -The MLRun Community Edition resources are configured initially with the default cluster/namespace resources limits. You can modify the resources from outside if needed. +The MLRun Community Edition resources are configured initially with the default cluster/namespace resource limits. You can modify the resources from outside if needed. ``` ## Community Edition flavors @@ -50,7 +54,7 @@ The MLRun CE (Community Edition) includes the following components: -## Installing the Chart +## Installing the chart ```{admonition} Note These instructions use `mlrun` as the namespace (`-n` parameter). You can choose a different namespace in your kubernetes cluster. @@ -65,7 +69,7 @@ kubectl create namespace mlrun Add the Community Edition helm chart repo: ```bash -helm repo add mlrun-ce https://mlrun.github.io/ce +helm repo add mlrun-ce https://mlrun.github.io/ce ``` Run the following command to ensure that the repo is installed and available: @@ -76,7 +80,7 @@ helm repo list It should output something like: ```bash NAME URL -mlrun-ce https://mlrun.github.io/ce +mlrun-ce https://mlrun.github.io/ce ``` Update the repo to make sure you're getting the latest chart: @@ -106,8 +110,8 @@ Where: - `` is your Docker email. 
```{admonition} Note -First-time MLRun users will experience a relatively longer installation time because all required images -are being pulled locally for the first time (it will take an average of 10-15 minutes mostly depends on +First-time MLRun users experience a relatively longer installation time because all required images +are pulled locally for the first time (it takes an average of 10-15 minutes, mostly depending on your internet speed). ``` @@ -129,18 +133,18 @@ Where: - `` is the registry URL that can be authenticated by the `registry-credentials` secret (e.g., `index.docker.io/` for Docker Hub). - `` is the IP address of the host machine (or `$(minikube ip)` if using minikube). -When the installation is complete, the helm command prints the URLs and Ports of all the MLRun CE services. +When the installation is complete, the helm command prints the URLs and ports of all the MLRun CE services. > **Note:** > There is currently a known issue with installing the chart on Macs using Apple Silicon (M1). The current pipelines > mysql database fails to start. The workaround for now is to opt out of pipelines by installing the chart with the > `--set pipelines.mysql.enabled=false`. -## Configuring Online Feature Store -The MLRun Community Edition now supports the online feature store. To enable it, you need to first deploy a REDIS service that is accessible to your MLRun CE cluster. -To deploy a REDIS service, refer to the following [link](https://redis.io/docs/getting-started/). +## Configuring the online feature store +The MLRun Community Edition now supports the online feature store. To enable it, you need to first deploy a Redis service that is accessible to your MLRun CE cluster. +To deploy a Redis service, refer to the [Redis documentation](https://redis.io/docs/getting-started/). 
-When you have a REDIS service deployed, you can configure MLRun CE to use it by adding the following helm value configuration to your helm install command: +When you have a Redis service deployed, you can configure MLRun CE to use it by adding the following helm value configuration to your helm install command: ```bash --set mlrun.api.extraEnvKeyValue.MLRUN_REDIS__URL= ``` @@ -148,34 +152,34 @@ When you have a REDIS service deployed, you can configure MLRun CE to use it by ## Usage Your applications are now available in your local browser: -- jupyter-notebook - `http://:30040` -- nuclio - `http://:30050` -- mlrun UI - `http://:30060` -- mlrun API (external) - `http://:30070` -- minio API - `http://:30080` -- minio UI - `http://:30090` -- pipeline UI - `http://:30100` -- grafana UI - `http://:30110` +- Jupyter Notebook - `http://:30040` +- Nuclio - `http://:30050` +- MLRun UI - `http://:30060` +- MLRun API (external) - `http://:30070` +- MinIO API - `http://:30080` +- MinIO UI - `http://:30090` +- Pipeline UI - `http://:30100` +- Grafana UI - `http://:30110` ```{admonition} Check state -You can check current state of installation via command `kubectl -n mlrun get pods`, where the main information -is in columns `Ready` and `State`. If all images have already been pulled locally, typically it will take +You can check the current state of the installation via the command `kubectl -n mlrun get pods`, where the main information +is in columns `Ready` and `State`. If all images have already been pulled locally, typically it takes a minute for all services to start. ``` ```{admonition} Note You can change the ports by providing values to the helm install command. -You can add and configure a k8s ingress-controller for better security and control over external access. +You can add and configure a Kubernetes ingress-controller for better security and control over external access. 
``` -## Start Working +## Start working Open the Jupyter notebook on [**jupyter-notebook UI**](http://localhost:30040) and run the code in the [**examples/mlrun_basics.ipynb**](https://github.com/mlrun/mlrun/blob/master/examples/mlrun_basics.ipynb) notebook. ```{admonition} Important -Make sure to save your changes in the `data` folder within the Jupyter Lab. The root folder and any other folder do not retain the changes when you restart the Jupyter Lab. +Make sure to save your changes in the `data` folder within the Jupyter Lab. The root folder and any other folders do not retain the changes when you restart the Jupyter Lab. ``` ## Configuring the remote environment @@ -187,8 +191,8 @@ You can use your code on a local machine while running your functions on a remot Configurable values are documented in the `values.yaml`, and the `values.yaml` of all sub charts. Override those [in the normal methods](https://helm.sh/docs/chart_template_guide/values_files/). ### Opt out of components -The chart installs many components. You might not need them all in your deployment depending on your use cases. -In order to opt out of some of the components, you can use the following helm values: +The chart installs many components. You may not need them all in your deployment depending on your use cases. +To opt out of some of the components, use the following helm values: ```bash ... --set pipelines.enabled=false \ @@ -209,11 +213,12 @@ Docker Desktop is available for Mac and Windows. For download information, syste Docker Desktop includes a standalone Kubernetes server and client, as well as Docker CLI integration that runs on your machine. The Kubernetes server runs locally within your Docker instance. To enable Kubernetes support and install a standalone instance of Kubernetes -running as a Docker container, go to **Preferences** > **Kubernetes** and then click **Enable Kubernetes**. Click **Apply & Restart** to -save the settings and then click **Install** to confirm. 
This instantiates the images that are required to run the Kubernetes server as +running as a Docker container, go to **Preferences** > **Kubernetes** and then press **Enable Kubernetes**. Press **Apply & Restart** to +save the settings and then press **Install** to confirm. This instantiates the images that are required to run the Kubernetes server as containers, and installs the `/usr/local/bin/kubectl` command on your machine. For more information, see [the Kubernetes documentation](https://docs.docker.com/desktop/kubernetes/). -It's recommended to limit the amount of memory allocated to Kubernetes. If you're using Windows and WSL 2, you can configure global WSL options by placing a `.wslconfig` file into the root directory of your users folder: `C:\Users\\.wslconfig`. Keep in mind that you might need to run `wsl --shutdown` to shut down the WSL 2 VM and then restart your WSL instance for these changes to take effect. +It's recommended to limit the amount of memory allocated to Kubernetes. If you're using Windows and WSL 2, you can configure global WSL options by placing a `.wslconfig` file into the root directory of +your users folder: `C:\Users\\.wslconfig`. Keep in mind that you might need to run `wsl --shutdown` to shut down the WSL 2 VM and then restart your WSL instance for these changes to take effect. ``` console [wsl2] @@ -229,10 +234,12 @@ To learn about the various UI options and their usage, see: When installing the MLRun Community Edition, several storage resources are created: -- **PVs via default configured storage class**: Used to hold the file system of the stacks pods, including the MySQL database of MLRun, Minio for artifacts and Pipelines Storage and more. These are not deleted when the stack is uninstalled to allow upgrades without losing data. 
-- **Container Images in the configured docker-registry**: When building and deploying MLRun and Nuclio functions via the MLRun Community Edition, the function images are stored in the given configured docker registry. These images persist in the docker registry and are not deleted. +- **PVs via default configured storage class**: Holds the file system of the stacks pods, including the MySQL database of MLRun, Minio for artifacts and Pipelines Storage and more. +These are not deleted when the stack is uninstalled, which allows upgrading without losing data. +- **Container Images in the configured docker-registry**: When building and deploying MLRun and Nuclio functions via the MLRun Community Edition, the function images are +stored in the given configured docker registry. These images persist in the docker registry and are not deleted. -## Uninstalling the Chart +## Uninstalling the chart The following command deletes the pods, deployments, config maps, services and roles+role bindings associated with the chart and release. @@ -248,8 +255,8 @@ helm --namespace mlrun uninstall mlrun-ce ### Note on terminating pods and hanging resources This chart generates several persistent volume claims that provide persistency (via PVC) out of the box. -Upon uninstallation, any hanging / terminating pods will hold the PVCs and PVs respectively, as those prevent their safe removal. -Since pods that are stuck in terminating state seem to be a never-ending plague in k8s, note this, +Upon uninstallation, any hanging / terminating pods hold the PVCs and PVs respectively, as those prevent their safe removal. +Since pods that are stuck in terminating state seem to be a never-ending plague in Kubernetes, note this, and remember to clean the remaining PVs and PVCs. 
### Handing stuck-at-terminating pods: @@ -283,14 +290,31 @@ $ kubectl --namespace mlrun delete pvc ## Upgrading the chart -In order to upgrade to the latest version of the chart, first make sure you have the latest helm repo +To upgrade to the latest version of the chart, first make sure you have the latest helm repo ```bash helm repo update ``` -Then upgrade the chart: +Then try to upgrade the chart: + +```bash +helm upgrade --install --reuse-values mlrun-ce --namespace mlrun mlrun-ce/mlrun-ce +``` + +If it fails, you should reinstall the chart: +1. remove current mlrun-ce +```bash +mkdir ~/tmp +helm get values -n mlrun mlrun-ce > ~/tmp/mlrun-ce-values.yaml +helm uninstall -n mlrun mlrun-ce +``` +2. reinstall mlrun-ce, reuse values ```bash -helm upgrade --install --reuse-values mlrun-ce mlrun-ce/mlrun-ce +helm install -n mlrun --values ~/tmp/mlrun-ce-values.yaml mlrun-ce mlrun-ce/mlrun-ce --devel +``` + +```{admonition} Note +If your values have fixed mlrun service versions (e.g., mlrun:1.2.1) then you might want to remove it from the values file to allow newer chart defaults to kick in ``` \ No newline at end of file diff --git a/docs/install/remote.md b/docs/install/remote.md index 8f7cf27d1dc2..b070944941d5 100644 --- a/docs/install/remote.md +++ b/docs/install/remote.md @@ -9,7 +9,7 @@ You can write your code on a local machine while running your functions on a rem - [Configure remote environment](#configure-remote-environment) - [Using `mlrun config set` command in MLRun CLI](#using-mlrun-config-set-command-in-mlrun-cli) - [Using `mlrun.set_environment` command in MLRun SDK](#using-mlrun-set-environment-command-in-mlrun-sdk) - - [Using your IDE (e.g PyCharm or VSCode)](#using-your-ide-e-g-pycharm-or-vscode) + - [Using your IDE (e.g. PyCharm or VSCode)](#using-your-ide-e-g-pycharm-or-vscode) ## Prerequisites @@ -41,13 +41,27 @@ To install a specific version, use the command: `pip install mlrun==`. 
- To install all extras, run: ```pip install mlrun[complete]``` See the full list [here](https://github.com/mlrun/mlrun/blob/development/setup.py#L75).
-2. Alternatively, if you already installed a previous version of MLRun, upgrade it by running: +3. Alternatively, if you already installed a previous version of MLRun, upgrade it by running: ```sh pip install -U mlrun== ``` -3. Ensure that you have remote access to your MLRun service (i.e., to the service URL on the remote Kubernetes cluster). +4. Ensure that you have remote access to your MLRun service (i.e., to the service URL on the remote Kubernetes cluster). +5. When installing other python packages on top of MLRun, make sure to install them with mlrun in the same command/requirement file to avoid version conflicts. For example: + ```sh + pip install mlrun + ``` + or + ```sh + pip install -r requirements.txt + ``` + where `requirements.txt` contains: + ``` + mlrun + + ``` + Do so even if you already have MLRun installed so that pip will take MLRun requirements into consideration when installing the other package. ## Configure remote environment You have a few options to configure your remote environment: diff --git a/docs/monitoring/dashboards/iguazio-3.5.2-and-older/model-monitoring-details.json b/docs/monitoring/dashboards/iguazio-3.5.2-and-older/model-monitoring-details.json new file mode 100644 index 000000000000..d5d8a440e38b --- /dev/null +++ b/docs/monitoring/dashboards/iguazio-3.5.2-and-older/model-monitoring-details.json @@ -0,0 +1,791 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": "-- Grafana --", + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "editable": true, + "gnetId": null, + "graphTooltip": 0, + "id": 45, + "iteration": 1679739783082, + "links": [ + { + "icon": "external link", + "includeVars": true, + "keepTime": true, + "tags": [], + "targetBlank": false, + "title": "Model Monitoring - Performance", + "type": "link", + "url": "/d/9CazA-UGz/model-monitoring-performance" + }, + { + "icon": "dashboard", + "includeVars": 
true, + "keepTime": true, + "tags": [], + "targetBlank": false, + "title": "Model Monitoring - Overview", + "tooltip": "", + "type": "link", + "url": "/d/g0M4uh0Mz" + } + ], + "panels": [ + { + "datasource": "model-monitoring", + "description": "", + "fieldConfig": { + "defaults": { + "custom": { + "align": null, + "displayMode": "auto", + "filterable": false + }, + "mappings": [ + { + "from": "", + "id": 0, + "text": "", + "to": "", + "type": 1 + } + ], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "First Request" + }, + "properties": [ + { + "id": "unit", + "value": "dateTimeFromNow" + }, + { + "id": "custom.align", + "value": "center" + }, + { + "id": "custom.width", + "value": null + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Last Request" + }, + "properties": [ + { + "id": "unit", + "value": "dateTimeFromNow" + }, + { + "id": "custom.align", + "value": "center" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Endpoint ID" + }, + "properties": [ + { + "id": "custom.align", + "value": "center" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Model" + }, + "properties": [ + { + "id": "custom.align", + "value": "center" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Function URI" + }, + "properties": [ + { + "id": "custom.align", + "value": "center" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Model Class" + }, + "properties": [ + { + "id": "custom.align", + "value": "center" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Predictions/s (5 minute avg)" + }, + "properties": [ + { + "id": "custom.align", + "value": "center" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Average Latency (1 hour)" + }, + "properties": [ + { + "id": "custom.align", + "value": "center" + }, + { + "id": 
"unit", + "value": "µs" + } + ] + } + ] + }, + "gridPos": { + "h": 3, + "w": 24, + "x": 0, + "y": 0 + }, + "id": 12, + "options": { + "showHeader": true, + "sortBy": [] + }, + "pluginVersion": "7.2.0", + "targets": [ + { + "hide": false, + "rawQuery": true, + "refId": "A", + "target": "project=$PROJECT;target_endpoint=list_endpoints", + "type": "table" + } + ], + "timeFrom": null, + "timeShift": null, + "title": "", + "transformations": [ + { + "id": "organize", + "options": { + "excludeByName": { + "accuracy": true, + "drift_status": true, + "endpoint_function": false, + "endpoint_model": false, + "error_count": true + }, + "indexByName": { + "accuracy": 4, + "drift_status": 6, + "endpoint_function": 2, + "endpoint_id": 0, + "endpoint_model": 1, + "endpoint_model_class": 10, + "error_count": 5, + "first_request": 9, + "last_request": 3, + "latency_avg_1h": 8, + "predictions_per_second": 7 + }, + "renameByName": { + "endpoint_function": "Function URI", + "endpoint_id": "Endpoint ID", + "endpoint_model": "Model", + "endpoint_model_class": "Model Class", + "first_request": "First Request", + "function": "Function", + "function_uri": "Function URI", + "last_request": "Last Request", + "latency_avg_1h": "Average Latency (1 hour)", + "latency_avg_1s": "Average Latency", + "latency_avg_5m": "Average Latency (1 hour)", + "model": "Model", + "model_class": "Model Class", + "predictions_per_second": "Predictions/s (5 minute avg)", + "predictions_per_second_count_1s": "Predictions/sec", + "tag": "Tag" + } + } + } + ], + "transparent": true, + "type": "table" + }, + { + "datasource": "model-monitoring", + "description": "", + "fieldConfig": { + "defaults": { + "custom": { + "align": null, + "displayMode": "auto", + "filterable": false + }, + "mappings": [ + { + "from": "", + "id": 0, + "text": "", + "to": "", + "type": 1 + } + ], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + 
}, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "tvd_sum" + }, + "properties": [ + { + "id": "displayName", + "value": "TVD (sum)" + }, + { + "id": "custom.align", + "value": "center" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "tvd_mean" + }, + "properties": [ + { + "id": "displayName", + "value": "TVD (mean)" + }, + { + "id": "custom.align", + "value": "center" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "hellinger_sum" + }, + "properties": [ + { + "id": "displayName", + "value": "Hellinger (sum)" + }, + { + "id": "custom.align", + "value": "center" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "hellinger_mean" + }, + "properties": [ + { + "id": "displayName", + "value": "Hellinger (mean)" + }, + { + "id": "custom.align", + "value": "center" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "kld_sum" + }, + "properties": [ + { + "id": "displayName", + "value": "KLD (sum)" + }, + { + "id": "custom.align", + "value": "center" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "kld_mean" + }, + "properties": [ + { + "id": "displayName", + "value": "KLD (mean)" + }, + { + "id": "custom.align", + "value": "center" + } + ] + } + ] + }, + "gridPos": { + "h": 3, + "w": 24, + "x": 0, + "y": 3 + }, + "id": 21, + "options": { + "showHeader": true, + "sortBy": [ + { + "desc": false, + "displayName": "name" + } + ] + }, + "pluginVersion": "7.2.0", + "targets": [ + { + "hide": false, + "rawQuery": true, + "refId": "A", + "target": "target_endpoint=overall_feature_analysis;endpoint_id=$MODEL;project=$PROJECT", + "type": "table" + } + ], + "timeFrom": null, + "timeShift": null, + "title": "Overall Drift Analysis", + "transformations": [ + { + "id": "organize", + "options": { + "excludeByName": {}, + "indexByName": {}, + "renameByName": { + "endpoint_id": "Endpoint ID", + "first_request": "First Request", + "function": "Function", + "last_request": "Last Request", + 
"latency_avg_1s": "Average Latency", + "model": "Model", + "model_class": "Model Class", + "predictions_per_second_count_1s": "Predictions/sec", + "tag": "Tag" + } + } + } + ], + "transparent": true, + "type": "table" + }, + { + "datasource": "model-monitoring", + "description": "Feature analysis of the latest batch", + "fieldConfig": { + "defaults": { + "custom": { + "align": "center", + "displayMode": "auto", + "filterable": false + }, + "mappings": [ + { + "from": "", + "id": 0, + "text": "", + "to": "", + "type": 1 + } + ], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "Feature" + }, + "properties": [] + }, + { + "matcher": { + "id": "byName", + "options": "Actual Min" + }, + "properties": [] + }, + { + "matcher": { + "id": "byName", + "options": "Expected Min" + }, + "properties": [ + { + "id": "noValue", + "value": "N/A" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Expected Mean" + }, + "properties": [ + { + "id": "noValue", + "value": "N/A" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Expected Max" + }, + "properties": [ + { + "id": "noValue", + "value": "N/A" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "tvd" + }, + "properties": [ + { + "id": "displayName", + "value": "TVD" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "hellinger" + }, + "properties": [ + { + "id": "displayName", + "value": "Hellinger" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "kld" + }, + "properties": [ + { + "id": "displayName", + "value": "KLD" + } + ] + } + ] + }, + "gridPos": { + "h": 7, + "w": 24, + "x": 0, + "y": 6 + }, + "id": 14, + "options": { + "showHeader": true, + "sortBy": [ + { + "desc": false, + "displayName": "Feature" + } + ] + }, + "pluginVersion": "7.2.0", + "targets": [ + { + "rawQuery": true, + "refId": 
"A", + "target": "target_endpoint=individual_feature_analysis;endpoint_id=$MODEL;project=$PROJECT", + "type": "table" + } + ], + "timeFrom": null, + "timeShift": null, + "title": "Features Analysis", + "transformations": [ + { + "id": "organize", + "options": { + "excludeByName": { + "count": true, + "idx": true, + "model": true + }, + "indexByName": { + "actual_max": 3, + "actual_mean": 2, + "actual_min": 1, + "expected_max": 4, + "expected_mean": 5, + "expected_min": 6, + "feature_name": 0 + }, + "renameByName": { + "actual_max": "Actual Max", + "actual_mean": "Actual Mean", + "actual_min": "Actual Min", + "expected_max": "Expected Min", + "expected_mean": "Expected Mean", + "expected_min": "Expected Max", + "feature_name": "Feature" + } + } + } + ], + "transparent": true, + "type": "table" + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "iguazio", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 1, + "gridPos": { + "h": 7, + "w": 24, + "x": 0, + "y": 13 + }, + "hiddenSeries": false, + "id": 16, + "legend": { + "alignAsTable": true, + "avg": false, + "current": true, + "max": false, + "min": false, + "rightSide": true, + "show": true, + "sideWidth": 250, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.2.0", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "rawQuery": true, + "refId": "A", + "target": "backend=tsdb;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/events;\nfilter=endpoint_id=='$MODEL' AND record_type=='endpoint_features';", + "type": "timeserie" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Incoming 
Features", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "transformations": [], + "transparent": true, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + } + ], + "refresh": "30s", + "schemaVersion": 26, + "style": "dark", + "tags": [], + "templating": { + "list": [ + { + "datasource": "model-monitoring", + "definition": "target_endpoint=list_projects", + "hide": 0, + "includeAll": false, + "label": "Project", + "multi": false, + "name": "PROJECT", + "options": [], + "query": "target_endpoint=list_projects", + "refresh": 1, + "regex": "", + "skipUrlSync": false, + "sort": 0, + "tagValuesQuery": "", + "tags": [], + "tagsQuery": "", + "type": "query", + "useTags": false + }, + { + "allValue": null, + "current": {}, + "datasource": "iguazio", + "definition": "backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfields=uid;", + "hide": 0, + "includeAll": false, + "label": "Model", + "multi": false, + "name": "MODEL", + "options": [], + "query": "backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfields=uid;", + "refresh": 1, + "regex": "", + "skipUrlSync": false, + "sort": 0, + "tagValuesQuery": "", + "tags": [], + "tagsQuery": "", + "type": "query", + "useTags": false + } + ] + }, + "time": { + "from": "now-1h", + "to": "now" + }, + "timepicker": {}, + "timezone": "", + "title": "Model Monitoring - Details", + "uid": "AohIXhAMk", + "version": 3 +} \ No newline at end of file diff --git a/docs/monitoring/dashboards/iguazio-3.5.2-and-older/model-monitoring-overview.json 
b/docs/monitoring/dashboards/iguazio-3.5.2-and-older/model-monitoring-overview.json new file mode 100644 index 000000000000..bb979bb18468 --- /dev/null +++ b/docs/monitoring/dashboards/iguazio-3.5.2-and-older/model-monitoring-overview.json @@ -0,0 +1,836 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": "-- Grafana --", + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "editable": true, + "gnetId": null, + "graphTooltip": 0, + "id": 37, + "iteration": 1679742399589, + "links": [ + { + "icon": "external link", + "includeVars": true, + "keepTime": true, + "tags": [], + "title": "Model Monitoring - Performance", + "type": "link", + "url": "/d/9CazA-UGz/model-monitoring-performance" + }, + { + "icon": "external link", + "includeVars": true, + "keepTime": true, + "tags": [], + "targetBlank": false, + "title": "Model Monitoring - Details", + "type": "link", + "url": "d/AohIXhAMk/model-monitoring-details" + } + ], + "panels": [ + { + "datasource": "iguazio", + "fieldConfig": { + "defaults": { + "custom": {}, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 3, + "w": 5, + "x": 0, + "y": 0 + }, + "id": 6, + "options": { + "colorMode": "value", + "graphMode": "none", + "justifyMode": "center", + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "mean" + ], + "fields": "", + "values": false + }, + "textMode": "value" + }, + "pluginVersion": "7.2.0", + "targets": [ + { + "rawQuery": true, + "refId": "A", + "target": "backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfields=uid;", + "type": "table" + } + ], + "timeFrom": null, + "timeShift": null, + "title": "Endpoints", + "transformations": [ + { + "id": "reduce", + "options": { + "reducers": [ + "count" + ] 
+ } + } + ], + "transparent": true, + "type": "stat" + }, + { + "datasource": "model-monitoring", + "fieldConfig": { + "defaults": { + "custom": {}, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 3, + "w": 5, + "x": 6, + "y": 0 + }, + "id": 8, + "options": { + "colorMode": "value", + "graphMode": "none", + "justifyMode": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "mean" + ], + "fields": "", + "values": false + }, + "textMode": "auto" + }, + "pluginVersion": "7.2.0", + "targets": [ + { + "hide": false, + "rawQuery": true, + "refId": "A", + "target": "project=$PROJECT;target_endpoint=list_endpoints", + "type": "table" + } + ], + "timeFrom": null, + "timeShift": null, + "title": "Predictions/s (5 Minute Average)", + "transformations": [ + { + "id": "organize", + "options": { + "excludeByName": { + "accuracy": true, + "drift_status": true, + "endpoint_function": true, + "endpoint_id": true, + "endpoint_model": true, + "endpoint_model_class": true, + "error_count": true, + "first_request": true, + "last_request": true, + "latency_avg_1h": true + }, + "indexByName": {}, + "renameByName": {} + } + } + ], + "transparent": true, + "type": "stat" + }, + { + "datasource": "model-monitoring", + "fieldConfig": { + "defaults": { + "custom": {}, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + }, + "unit": "µs" + }, + "overrides": [] + }, + "gridPos": { + "h": 3, + "w": 5, + "x": 12, + "y": 0 + }, + "id": 25, + "options": { + "colorMode": "value", + "graphMode": "none", + "justifyMode": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "mean" + ], + "fields": "", + "values": false + }, + "textMode": "auto" + }, + "pluginVersion": "7.2.0", + "targets": [ + { + "hide": false, + "rawQuery": true, + 
"refId": "A", + "target": "project=$PROJECT;target_endpoint=list_endpoints", + "type": "table" + } + ], + "timeFrom": null, + "timeShift": null, + "title": "Average Latency (Last Hour)", + "transformations": [ + { + "id": "organize", + "options": { + "excludeByName": { + "accuracy": true, + "drift_status": true, + "endpoint_function": true, + "endpoint_id": true, + "endpoint_model": true, + "endpoint_model_class": true, + "error_count": true, + "first_request": true, + "last_request": true, + "latency_avg_1h": false, + "predictions_per_second": true + }, + "indexByName": {}, + "renameByName": { + "latency_avg_1h": "Average Latency (Last Hour)" + } + } + } + ], + "transparent": true, + "type": "stat" + }, + { + "datasource": "iguazio", + "fieldConfig": { + "defaults": { + "custom": {}, + "mappings": [], + "noValue": "0", + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 3, + "w": 6, + "x": 18, + "y": 0 + }, + "id": 12, + "options": { + "colorMode": "value", + "graphMode": "none", + "justifyMode": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "mean" + ], + "fields": "", + "values": false + }, + "textMode": "auto" + }, + "pluginVersion": "7.2.0", + "targets": [ + { + "rawQuery": true, + "refId": "A", + "target": "backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfields=error_count;", + "type": "table" + } + ], + "timeFrom": null, + "timeShift": null, + "title": "Errors", + "transparent": true, + "type": "stat" + }, + { + "datasource": "iguazio", + "description": "", + "fieldConfig": { + "defaults": { + "custom": { + "align": "center", + "displayMode": "auto", + "filterable": true + }, + "mappings": [ + { + "from": "", + "id": 0, + "text": "", + "to": "", + "type": 1 + } + ], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + 
}, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "First Request" + }, + "properties": [ + { + "id": "unit", + "value": "dateTimeFromNow" + }, + { + "id": "custom.align", + "value": "center" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Last Request" + }, + "properties": [ + { + "id": "unit", + "value": "dateTimeFromNow" + }, + { + "id": "custom.align", + "value": "center" + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Endpoint ID" + }, + "properties": [ + { + "id": "links", + "value": [ + { + "title": "Endpoint ID Details", + "url": "/d/AohIXhAMk/model-monitoring-details?orgId=1&refresh=1m&var-PROJECT=$PROJECT&var-MODEL=${__value.text}" + } + ] + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Drift Status" + }, + "properties": [ + { + "id": "mappings", + "value": [ + { + "from": "", + "id": 0, + "text": "0", + "to": "", + "type": 1, + "value": "NO_DRIFT" + }, + { + "from": "", + "id": 1, + "text": "1", + "to": "", + "type": 1, + "value": "POSSIBLE_DRIFT" + }, + { + "from": "", + "id": 2, + "text": "2", + "to": "", + "type": 1, + "value": "DRIFT_DETECTED" + } + ] + } + ] + } + ] + }, + "gridPos": { + "h": 13, + "w": 24, + "x": 0, + "y": 3 + }, + "id": 24, + "options": { + "showHeader": true, + "sortBy": [ + { + "desc": false, + "displayName": "Name" + } + ] + }, + "pluginVersion": "7.2.0", + "targets": [ + { + "hide": false, + "rawQuery": true, + "refId": "A", + "target": "backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfields=uid,function_uri,model,model_class,first_request,last_request,error_count,drift_status", + "type": "table" + } + ], + "timeFrom": null, + "timeShift": null, + "title": "Models", + "transformations": [ + { + "id": "organize", + "options": { + "excludeByName": { + "model_hash": false + }, + "indexByName": { + "drift_status": 7, + "error_count": 6, + "first_request": 4, + "function_uri": 
1, + "last_request": 5, + "model": 2, + "model_class": 3, + "uid": 0 + }, + "renameByName": { + "drift_status": "Drift Status", + "endpoint_function": "Function", + "endpoint_model": "Model", + "endpoint_model_class": "Model Class", + "endpoint_tag": "Tag", + "error_count": "Error Count", + "first_request": "First Request", + "function": "Function", + "function_uri": "Function", + "last_request": "Last Request", + "latency_avg_1s": "Average Latency", + "model": "Model", + "model_class": "Class", + "predictions_per_second_count_1s": "Predictions/1s", + "tag": "Tag", + "uid": "Endpoint ID" + } + } + } + ], + "type": "table" + }, + { + "cards": { + "cardPadding": null, + "cardRound": null + }, + "color": { + "cardColor": "#b4ff00", + "colorScale": "sqrt", + "colorScheme": "interpolatePlasma", + "exponent": 0.5, + "mode": "spectrum" + }, + "dataFormat": "timeseries", + "datasource": "iguazio", + "fieldConfig": { + "defaults": { + "custom": {}, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 6, + "w": 8, + "x": 0, + "y": 16 + }, + "heatmap": {}, + "hideZeroBuckets": false, + "highlightCards": true, + "id": 18, + "legend": { + "show": false + }, + "pluginVersion": "7.2.0", + "reverseYBuckets": false, + "targets": [ + { + "rawQuery": true, + "refId": "A", + "target": "backend=tsdb;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/events;\nfields=predictions_per_second;", + "type": "timeserie" + } + ], + "timeFrom": null, + "timeShift": null, + "title": "Predictions/s (5 Minute Average)", + "tooltip": { + "show": true, + "showHistogram": false + }, + "transparent": true, + "type": "heatmap", + "xAxis": { + "show": true + }, + "xBucketNumber": null, + "xBucketSize": null, + "yAxis": { + "decimals": null, + "format": "short", + "logBase": 1, + "max": null, + "min": null, + "show": true, + 
"splitFactor": null + }, + "yBucketBound": "auto", + "yBucketNumber": null, + "yBucketSize": null + }, + { + "cards": { + "cardPadding": null, + "cardRound": null + }, + "color": { + "cardColor": "#b4ff00", + "colorScale": "sqrt", + "colorScheme": "interpolatePlasma", + "exponent": 0.5, + "mode": "spectrum" + }, + "dataFormat": "timeseries", + "datasource": "iguazio", + "fieldConfig": { + "defaults": { + "custom": {}, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [ + { + "matcher": { + "id": "byType", + "options": "number" + }, + "properties": [ + { + "id": "unit", + "value": "µs" + } + ] + } + ] + }, + "gridPos": { + "h": 6, + "w": 8, + "x": 8, + "y": 16 + }, + "heatmap": {}, + "hideZeroBuckets": false, + "highlightCards": true, + "id": 19, + "legend": { + "show": false + }, + "pluginVersion": "7.2.0", + "reverseYBuckets": false, + "targets": [ + { + "rawQuery": true, + "refId": "A", + "target": "backend=tsdb;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/events;\nfields=latency_avg_1h;", + "type": "timeserie" + } + ], + "timeFrom": null, + "timeShift": null, + "title": "Average Latency (1 Hour)", + "tooltip": { + "show": true, + "showHistogram": false + }, + "transparent": true, + "type": "heatmap", + "xAxis": { + "show": true + }, + "xBucketNumber": null, + "xBucketSize": null, + "yAxis": { + "decimals": null, + "format": "short", + "logBase": 1, + "max": null, + "min": null, + "show": true, + "splitFactor": null + }, + "yBucketBound": "auto", + "yBucketNumber": null, + "yBucketSize": null + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "iguazio", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 6, + "w": 8, + "x": 16, + "y": 16 + }, + "hiddenSeries": false, + "id": 20, + 
"legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.2.0", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "refId": "A", + "target": "select metric", + "type": "timeserie" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Errors", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "transparent": true, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + } + ], + "refresh": "30s", + "schemaVersion": 26, + "style": "dark", + "tags": [], + "templating": { + "list": [ + { + "allValue": null, + "current": {}, + "datasource": "model-monitoring", + "definition": "target_endpoint=list_projects", + "hide": 0, + "includeAll": false, + "label": "Project", + "multi": false, + "name": "PROJECT", + "options": [], + "query": "target_endpoint=list_projects", + "refresh": 1, + "regex": "", + "skipUrlSync": false, + "sort": 0, + "tagValuesQuery": "", + "tags": [], + "tagsQuery": "", + "type": "query", + "useTags": false + } + ] + }, + "time": { + "from": "now-3h", + "to": "now" + }, + "timepicker": {}, + "timezone": "", + "title": "Model Monitoring - Overview", + "uid": "g0M4uh0Mz", + "version": 13 +} \ No newline at end of file diff --git 
a/docs/monitoring/dashboards/iguazio-3.5.2-and-older/model-monitoring-performance.json b/docs/monitoring/dashboards/iguazio-3.5.2-and-older/model-monitoring-performance.json new file mode 100644 index 000000000000..14bb34e8d319 --- /dev/null +++ b/docs/monitoring/dashboards/iguazio-3.5.2-and-older/model-monitoring-performance.json @@ -0,0 +1,593 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": "-- Grafana --", + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "editable": true, + "gnetId": null, + "graphTooltip": 0, + "id": 9, + "iteration": 1627466092078, + "links": [ + { + "asDropdown": true, + "icon": "external link", + "includeVars": true, + "keepTime": true, + "tags": [], + "title": "Model Monitoring - Overview", + "type": "link", + "url": "d/g0M4uh0Mz/model-monitoring-overview" + }, + { + "icon": "external link", + "includeVars": true, + "keepTime": true, + "tags": [], + "targetBlank": false, + "title": "Model Monitoring - Details", + "type": "link", + "url": "d/AohIXhAMk/model-monitoring-details" + } + ], + "panels": [ + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "iguazio", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 1, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 0 + }, + "hiddenSeries": false, + "id": 5, + "legend": { + "alignAsTable": true, + "avg": false, + "current": true, + "max": false, + "min": false, + "rightSide": true, + "show": true, + "sideWidth": 250, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.2.0", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ 
+ { + "rawQuery": true, + "refId": "A", + "target": "backend=tsdb;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/events;\nfilter=endpoint_id=='$MODEL' AND record_type=='drift_measures';", + "type": "timeserie" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Drift Measures", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "transformations": [], + "transparent": true, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "iguazio", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 8, + "w": 12, + "x": 12, + "y": 0 + }, + "hiddenSeries": false, + "id": 6, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.2.0", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "rawQuery": true, + "refId": "A", + "target": "backend=tsdb;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/events;\nfields=latency_avg_5m,latency_avg_1h;\nfilter=endpoint_id=='$MODEL';", + "type": "timeserie" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": 
"Average Latency", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "transparent": true, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "iguazio", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 7, + "w": 12, + "x": 0, + "y": 8 + }, + "hiddenSeries": false, + "id": 2, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.2.0", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "rawQuery": true, + "refId": "A", + "target": "backend=tsdb;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/events;\nfields=predictions_per_second;\nfilter=endpoint_id=='$MODEL';", + "type": "timeserie" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Predictions/s (5 minute average)", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "transparent": true, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, 
+ "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "iguazio", + "fieldConfig": { + "defaults": { + "custom": {} + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 0, + "gridPos": { + "h": 7, + "w": 12, + "x": 12, + "y": 8 + }, + "hiddenSeries": false, + "id": 7, + "legend": { + "avg": false, + "current": false, + "max": false, + "min": false, + "show": true, + "total": false, + "values": false + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.2.0", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "rawQuery": true, + "refId": "A", + "target": "backend=tsdb;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/events;\nfields=predictions_count_5m,predictions_count_1h;\nfilter=endpoint_id=='$MODEL';", + "type": "timeserie" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Predictions Count", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "transparent": true, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "iguazio", + "fieldConfig": { + "defaults": { + 
"custom": {} + }, + "overrides": [] + }, + "fill": 1, + "fillGradient": 1, + "gridPos": { + "h": 7, + "w": 24, + "x": 0, + "y": 15 + }, + "hiddenSeries": false, + "id": 4, + "legend": { + "alignAsTable": true, + "avg": false, + "current": true, + "max": false, + "min": false, + "rightSide": true, + "show": true, + "sideWidth": 250, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "nullPointMode": "null", + "options": { + "alertThreshold": true + }, + "percentage": false, + "pluginVersion": "7.2.0", + "pointradius": 2, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "rawQuery": true, + "refId": "A", + "target": "backend=tsdb; container=users; table=pipelines/$PROJECT/model-endpoints/events; filter=endpoint_id=='$MODEL' AND record_type=='custom_metrics';", + "type": "timeserie" + } + ], + "thresholds": [], + "timeFrom": null, + "timeRegions": [], + "timeShift": null, + "title": "Custom Metrics", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "transformations": [], + "transparent": true, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + } + ], + "refresh": "30s", + "schemaVersion": 26, + "style": "dark", + "tags": [], + "templating": { + "list": [ + { + "allValue": null, + "current": {}, + "datasource": "model-monitoring", + "definition": "target_endpoint=list_projects", + "hide": 0, + "includeAll": false, + "label": "Project", + "multi": false, + "name": "PROJECT", + "options": [], + "query": "target_endpoint=list_projects", + "refresh": 1, + 
"regex": "", + "skipUrlSync": false, + "sort": 0, + "tagValuesQuery": "", + "tags": [], + "tagsQuery": "", + "type": "query", + "useTags": false + }, + { + "allValue": null, + "current": {}, + "datasource": "iguazio", + "definition": "backend=kv;container=users;table=pipelines/$PROJECT/model-endpoints/endpoints;fields=uid;", + "hide": 0, + "includeAll": false, + "label": "Model", + "multi": false, + "name": "MODEL", + "options": [], + "query": "backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfields=uid;", + "refresh": 0, + "regex": "", + "skipUrlSync": false, + "sort": 0, + "tagValuesQuery": "", + "tags": [], + "tagsQuery": "", + "type": "query", + "useTags": false + } + ] + }, + "time": { + "from": "now-1h", + "to": "now" + }, + "timepicker": {}, + "timezone": "", + "title": "Model Monitoring - Performance", + "uid": "9CazA-UGz", + "version": 2 +} \ No newline at end of file diff --git a/docs/monitoring/dashboards/model-monitoring-details.json b/docs/monitoring/dashboards/model-monitoring-details.json index 475c43464f4f..4868e67c83d6 100644 --- a/docs/monitoring/dashboards/model-monitoring-details.json +++ b/docs/monitoring/dashboards/model-monitoring-details.json @@ -3,43 +3,50 @@ "list": [ { "builtIn": 1, - "datasource": "-- Grafana --", + "datasource": { + "type": "datasource", + "uid": "grafana" + }, "enable": true, "hide": true, "iconColor": "rgba(0, 211, 255, 1)", "name": "Annotations & Alerts", + "target": { + "limit": 100, + "matchAny": false, + "tags": [], + "type": "dashboard" + }, "type": "dashboard" } ] }, "editable": true, - "gnetId": null, + "fiscalYearStartMonth": 0, "graphTooltip": 0, - "id": 8, - "iteration": 1627466479152, + "id": 33, "links": [ { "icon": "external link", "includeVars": true, "keepTime": true, "tags": [], - "targetBlank": true, "title": "Model Monitoring - Performance", "type": "link", "url": "/d/9CazA-UGz/model-monitoring-performance" }, { - "icon": "dashboard", - "includeVars": false, + 
"asDropdown": true, + "icon": "external link", + "includeVars": true, "keepTime": true, "tags": [], - "targetBlank": true, "title": "Model Monitoring - Overview", - "tooltip": "", "type": "link", - "url": "/d/g0M4uh0Mz" + "url": "d/g0M4uh0Mz/model-monitoring-overview" } ], + "liveNow": false, "panels": [ { "datasource": "iguazio", @@ -47,25 +54,17 @@ "fieldConfig": { "defaults": { "custom": { - "align": null, + "align": "auto", "displayMode": "auto", - "filterable": false + "filterable": false, + "inspect": false }, - "mappings": [ - { - "from": "", - "id": 0, - "text": "", - "to": "", - "type": 1 - } - ], + "mappings": [], "thresholds": { "mode": "absolute", "steps": [ { - "color": "green", - "value": null + "color": "green" }, { "color": "red", @@ -191,8 +190,15 @@ "x": 0, "y": 0 }, - "id": 12, + "id": 22, "options": { + "footer": { + "fields": "", + "reducer": [ + "sum" + ], + "show": false + }, "showHeader": true, "sortBy": [ { @@ -201,24 +207,40 @@ } ] }, - "pluginVersion": "7.2.0", + "pluginVersion": "9.2.2", "targets": [ { + "datasource": "iguazio", "hide": false, "rawQuery": true, "refId": "A", - "target": "backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfilter=endpoint_id==\"$MODEL\";\nfields=endpoint_id,model,function_uri,model_class,predictions_per_second,latency_avg_1h,first_request,last_request;", + "target": "backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfilter=uid==\"$MODELENDPOINT\";\nfields=uid,model,function_uri,model_class,first_request,metrics,last_request;", "type": "table" } ], - "timeFrom": null, - "timeShift": null, - "title": "", "transformations": [ + { + "id": "extractFields", + "options": { + "source": "metrics" + } + }, + { + "id": "extractFields", + "options": { + "source": "generic" + } + }, { "id": "organize", "options": { - "excludeByName": {}, + "excludeByName": { + "generic": true, + "latency_avg_5m": true, + "metrics": true, + "predictions_count_1h": true, 
+ "predictions_count_5m": true + }, "indexByName": {}, "renameByName": { "endpoint_id": "Endpoint ID", @@ -228,12 +250,14 @@ "last_request": "Last Request", "latency_avg_1h": "Average Latency (1 hour)", "latency_avg_1s": "Average Latency", - "latency_avg_5m": "Average Latency (1 hour)", + "latency_avg_5m": "", + "metrics": "", "model": "Model", "model_class": "Model Class", "predictions_per_second": "Predictions/s (5 minute avg)", "predictions_per_second_count_1s": "Predictions/sec", - "tag": "Tag" + "tag": "Tag", + "uid": "Endpoint ID" } } } @@ -242,30 +266,22 @@ "type": "table" }, { - "datasource": "model-monitoring", + "datasource": "iguazio", "description": "", "fieldConfig": { "defaults": { "custom": { - "align": null, + "align": "auto", "displayMode": "auto", - "filterable": false + "filterable": false, + "inspect": false }, - "mappings": [ - { - "from": "", - "id": 0, - "text": "", - "to": "", - "type": 1 - } - ], + "mappings": [], "thresholds": { "mode": "absolute", "steps": [ { - "color": "green", - "value": null + "color": "green" }, { "color": "red", @@ -370,6 +386,22 @@ "value": "center" } ] + }, + { + "matcher": { + "id": "byName", + "options": "drift_measures" + }, + "properties": [ + { + "id": "custom.hidden", + "value": false + }, + { + "id": "mappings", + "value": [] + } + ] } ] }, @@ -379,8 +411,15 @@ "x": 0, "y": 3 }, - "id": 21, + "id": 25, "options": { + "footer": { + "fields": "", + "reducer": [ + "sum" + ], + "show": false + }, "showHeader": true, "sortBy": [ { @@ -389,35 +428,39 @@ } ] }, - "pluginVersion": "7.2.0", + "pluginVersion": "9.2.2", "targets": [ { + "datasource": "iguazio", "hide": false, "rawQuery": true, "refId": "A", - "target": "target_endpoint=overall_feature_analysis;endpoint_id=$MODEL;project=$PROJECT", + "target": "backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfilter=uid==\"$MODELENDPOINT\";\nfields=drift_measures;", "type": "table" } ], - "timeFrom": null, - "timeShift": null, 
"title": "Overall Drift Analysis", "transformations": [ { - "id": "organize", + "id": "extractFields", "options": { - "excludeByName": {}, - "indexByName": {}, - "renameByName": { - "endpoint_id": "Endpoint ID", - "first_request": "First Request", - "function": "Function", - "last_request": "Last Request", - "latency_avg_1s": "Average Latency", - "model": "Model", - "model_class": "Model Class", - "predictions_per_second_count_1s": "Predictions/sec", - "tag": "Tag" + "format": "json", + "replace": false, + "source": "drift_measures" + } + }, + { + "id": "filterFieldsByName", + "options": { + "include": { + "names": [ + "tvd_sum", + "tvd_mean", + "hellinger_sum", + "hellinger_mean", + "kld_sum", + "kld_mean" + ] } } } @@ -426,30 +469,22 @@ "type": "table" }, { - "datasource": "model-monitoring", + "datasource": "iguazio", "description": "Feature analysis of the latest batch", "fieldConfig": { "defaults": { "custom": { "align": "center", "displayMode": "auto", - "filterable": false + "filterable": false, + "inspect": false }, - "mappings": [ - { - "from": "", - "id": 0, - "text": "", - "to": "", - "type": 1 - } - ], + "mappings": [], "thresholds": { "mode": "absolute", "steps": [ { - "color": "green", - "value": null + "color": "green" }, { "color": "red", @@ -459,20 +494,6 @@ } }, "overrides": [ - { - "matcher": { - "id": "byName", - "options": "Feature" - }, - "properties": [] - }, - { - "matcher": { - "id": "byName", - "options": "Actual Min" - }, - "properties": [] - }, { "matcher": { "id": "byName", @@ -553,55 +574,349 @@ "x": 0, "y": 6 }, - "id": 14, + "id": 29, "options": { + "footer": { + "fields": "", + "reducer": [ + "sum" + ], + "show": false + }, "showHeader": true, "sortBy": [ { - "desc": false, - "displayName": "Feature" + "desc": true, + "displayName": "current_stats" } ] }, - "pluginVersion": "7.2.0", + "pluginVersion": "9.2.2", "targets": [ { + "datasource": "iguazio", "rawQuery": true, "refId": "A", - "target": 
"target_endpoint=individual_feature_analysis;endpoint_id=$MODEL;project=$PROJECT", + "target": "backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfilter=uid==\"$MODELENDPOINT\";\nfields= current_stats;", + "type": "table" + }, + { + "datasource": "iguazio", + "hide": false, + "refId": "B", + "target": "backend=kv; container=users; table=pipelines/$PROJECT/model-endpoints/endpoints; filter=uid==\"$MODELENDPOINT\"; fields= feature_stats;", + "type": "table" + }, + { + "datasource": "iguazio", + "hide": false, + "refId": "C", + "target": "backend=kv; container=users; table=pipelines/$PROJECT/model-endpoints/endpoints; filter=uid==\"$MODELENDPOINT\"; fields= drift_measures;", "type": "table" } ], - "timeFrom": null, - "timeShift": null, "title": "Features Analysis", "transformations": [ + { + "id": "extractFields", + "options": { + "format": "auto", + "replace": false, + "source": "current_stats" + } + }, + { + "id": "extractFields", + "options": { + "format": "auto", + "source": "feature_stats" + } + }, + { + "id": "extractFields", + "options": { + "replace": false, + "source": "drift_measures" + } + }, + { + "id": "merge", + "options": {} + }, + { + "id": "reduce", + "options": { + "includeTimeField": false, + "labelsToFields": false, + "mode": "seriesToRows", + "reducers": [ + "allValues" + ] + } + }, + { + "id": "filterByValue", + "options": { + "filters": [ + { + "config": { + "id": "equal", + "options": { + "value": "feature_stats" + } + }, + "fieldName": "Field" + }, + { + "config": { + "id": "equal", + "options": { + "value": "current_stats" + } + }, + "fieldName": "Field" + }, + { + "config": { + "id": "equal", + "options": { + "value": "timestamp" + } + }, + "fieldName": "Field" + }, + { + "config": { + "id": "equal", + "options": { + "value": "drift_measures" + } + }, + "fieldName": "Field" + }, + { + "config": { + "id": "equal", + "options": { + "value": "kld_sum" + } + }, + "fieldName": "Field" + }, + { + "config": { + 
"id": "equal", + "options": { + "value": "kld_mean" + } + }, + "fieldName": "Field" + }, + { + "config": { + "id": "equal", + "options": { + "value": "tvd_mean" + } + }, + "fieldName": "Field" + }, + { + "config": { + "id": "equal", + "options": { + "value": "tvd_sum" + } + }, + "fieldName": "Field" + }, + { + "config": { + "id": "equal", + "options": { + "value": "hellinger_sum" + } + }, + "fieldName": "Field" + }, + { + "config": { + "id": "equal", + "options": { + "value": "hellinger_mean" + } + }, + "fieldName": "Field" + } + ], + "match": "any", + "type": "exclude" + } + }, + { + "id": "extractFields", + "options": { + "replace": false, + "source": "All values" + } + }, + { + "id": "filterFieldsByName", + "options": { + "include": { + "names": [ + "Field", + "0", + "1", + "2" + ] + } + } + }, + { + "id": "extractFields", + "options": { + "replace": false, + "source": "0" + } + }, + { + "id": "filterByValue", + "options": { + "filters": [ + { + "config": { + "id": "isNull", + "options": {} + }, + "fieldName": "1" + }, + { + "config": { + "id": "greater", + "options": { + "value": 0 + } + }, + "fieldName": "2" + } + ], + "match": "any", + "type": "exclude" + } + }, + { + "id": "extractFields", + "options": { + "format": "json", + "source": "1" + } + }, + { + "id": "extractFields", + "options": { + "source": "2" + } + }, + { + "id": "filterFieldsByName", + "options": { + "include": { + "names": [ + "Field", + "mean 1", + "min 1", + "max 1", + "mean 2", + "min 2", + "max 2", + "tvd", + "hellinger", + "kld" + ] + } + } + }, + { + "id": "renameByRegex", + "options": { + "regex": "mean 1", + "renamePattern": "Actual Mean" + } + }, + { + "id": "renameByRegex", + "options": { + "regex": "min 1", + "renamePattern": "Actual Min" + } + }, + { + "id": "renameByRegex", + "options": { + "regex": "max 1", + "renamePattern": "Actual Max" + } + }, + { + "id": "renameByRegex", + "options": { + "regex": "mean 2", + "renamePattern": "Expected Mean" + } + }, + { + "id": 
"renameByRegex", + "options": { + "regex": "min 2", + "renamePattern": "Expected Min" + } + }, + { + "id": "renameByRegex", + "options": { + "regex": "max 2", + "renamePattern": "Expected Max" + } + }, + { + "id": "renameByRegex", + "options": { + "regex": "tvd", + "renamePattern": "TVD" + } + }, + { + "id": "renameByRegex", + "options": { + "regex": "hellinger", + "renamePattern": "Hellinger" + } + }, + { + "id": "renameByRegex", + "options": { + "regex": "kld", + "renamePattern": "KLD" + } + }, { "id": "organize", "options": { - "excludeByName": { - "count": true, - "idx": true, - "model": true - }, + "excludeByName": {}, "indexByName": { - "actual_max": 3, - "actual_mean": 2, - "actual_min": 1, - "expected_max": 4, - "expected_mean": 5, - "expected_min": 6, - "feature_name": 0 + "Actual Max": 6, + "Actual Mean": 2, + "Actual Min": 4, + "Expected Max": 5, + "Expected Mean": 1, + "Expected Min": 3, + "Field": 0, + "Hellinger": 8, + "KLD": 9, + "TVD": 7 }, - "renameByName": { - "actual_max": "Actual Max", - "actual_mean": "Actual Mean", - "actual_min": "Actual Min", - "expected_max": "Expected Min", - "expected_mean": "Expected Mean", - "expected_min": "Expected Max", - "feature_name": "Feature" - } + "renameByName": {} } } ], @@ -614,12 +929,6 @@ "dashLength": 10, "dashes": false, "datasource": "iguazio", - "fieldConfig": { - "defaults": { - "custom": {} - }, - "overrides": [] - }, "fill": 1, "fillGradient": 1, "gridPos": { @@ -649,7 +958,7 @@ "alertThreshold": true }, "percentage": false, - "pluginVersion": "7.2.0", + "pluginVersion": "9.2.2", "pointradius": 2, "points": false, "renderer": "flot", @@ -659,16 +968,15 @@ "steppedLine": false, "targets": [ { + "datasource": "iguazio", "rawQuery": true, "refId": "A", - "target": "backend=tsdb;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/events;\nfilter=endpoint_id=='$MODEL' AND record_type=='endpoint_features';", + "target": 
"backend=tsdb;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/events;\nfilter=endpoint_id=='$MODELENDPOINT' AND record_type=='endpoint_features';", "type": "timeserie" } ], "thresholds": [], - "timeFrom": null, "timeRegions": [], - "timeShift": null, "title": "Incoming Features", "tooltip": { "shared": true, @@ -679,44 +987,34 @@ "transparent": true, "type": "graph", "xaxis": { - "buckets": null, "mode": "time", - "name": null, "show": true, "values": [] }, "yaxes": [ { "format": "short", - "label": null, "logBase": 1, - "max": null, - "min": null, "show": true }, { "format": "short", - "label": null, "logBase": 1, - "max": null, - "min": null, "show": true } ], "yaxis": { - "align": false, - "alignLevel": null + "align": false } } ], - "refresh": "1m", - "schemaVersion": 26, + "refresh": "30s", + "schemaVersion": 37, "style": "dark", "tags": [], "templating": { "list": [ { - "allValue": null, "current": {}, "datasource": "model-monitoring", "definition": "target_endpoint=list_projects", @@ -727,34 +1025,31 @@ "name": "PROJECT", "options": [], "query": "target_endpoint=list_projects", - "refresh": 0, + "refresh": 1, "regex": "", "skipUrlSync": false, "sort": 0, "tagValuesQuery": "", - "tags": [], "tagsQuery": "", "type": "query", "useTags": false }, { - "allValue": null, "current": {}, "datasource": "iguazio", - "definition": "backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfields=endpoint_id;", + "definition": "backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfields=uid;", "hide": 0, "includeAll": false, - "label": "Model", + "label": "Model Endpoint", "multi": false, - "name": "MODEL", + "name": "MODELENDPOINT", "options": [], - "query": "backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfields=endpoint_id;", + "query": "backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfields=uid;", "refresh": 1, "regex": "", 
"skipUrlSync": false, "sort": 0, "tagValuesQuery": "", - "tags": [], "tagsQuery": "", "type": "query", "useTags": false @@ -769,5 +1064,6 @@ "timezone": "", "title": "Model Monitoring - Details", "uid": "AohIXhAMk", - "version": 3 -} + "version": 9, + "weekStart": "" +} \ No newline at end of file diff --git a/docs/monitoring/dashboards/model-monitoring-overview.json b/docs/monitoring/dashboards/model-monitoring-overview.json index 0821a9225537..7f3829118aee 100644 --- a/docs/monitoring/dashboards/model-monitoring-overview.json +++ b/docs/monitoring/dashboards/model-monitoring-overview.json @@ -3,20 +3,28 @@ "list": [ { "builtIn": 1, - "datasource": "-- Grafana --", + "datasource": { + "type": "datasource", + "uid": "grafana" + }, "enable": true, "hide": true, "iconColor": "rgba(0, 211, 255, 1)", "name": "Annotations & Alerts", + "target": { + "limit": 100, + "matchAny": false, + "tags": [], + "type": "dashboard" + }, "type": "dashboard" } ] }, "editable": true, - "gnetId": null, + "fiscalYearStartMonth": 0, "graphTooltip": 0, - "id": 7, - "iteration": 1627466285618, + "id": 31, "links": [ { "icon": "external link", @@ -28,20 +36,22 @@ "url": "/d/9CazA-UGz/model-monitoring-performance" }, { - "icon": "info", + "icon": "external link", + "includeVars": true, "keepTime": true, "tags": [], - "title": "Model Alerts", + "targetBlank": false, + "title": "Model Monitoring - Details", "type": "link", - "url": "/d/q6GvXh0Gz/model-alerts" + "url": "d/AohIXhAMk/model-monitoring-details" } ], + "liveNow": false, "panels": [ { "datasource": "iguazio", "fieldConfig": { "defaults": { - "custom": {}, "mappings": [], "thresholds": { "mode": "absolute", @@ -80,17 +90,16 @@ }, "textMode": "value" }, - "pluginVersion": "7.2.0", + "pluginVersion": "9.2.2", "targets": [ { + "datasource": "iguazio", "rawQuery": true, "refId": "A", - "target": "backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfields=endpoint_id;", + "target": 
"backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfields=uid;", "type": "table" } ], - "timeFrom": null, - "timeShift": null, "title": "Endpoints", "transformations": [ { @@ -109,7 +118,6 @@ "datasource": "iguazio", "fieldConfig": { "defaults": { - "custom": {}, "mappings": [], "thresholds": { "mode": "absolute", @@ -148,19 +156,49 @@ }, "textMode": "auto" }, - "pluginVersion": "7.2.0", + "pluginVersion": "9.2.2", "targets": [ { + "datasource": "iguazio", "hide": false, "rawQuery": true, "refId": "A", - "target": "backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfields=predictions_per_second;", + "target": "backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfields=metrics;", "type": "table" } ], - "timeFrom": null, - "timeShift": null, "title": "Predictions/s (5 Minute Average)", + "transformations": [ + { + "id": "extractFields", + "options": { + "source": "metrics" + } + }, + { + "id": "extractFields", + "options": { + "source": "generic" + } + }, + { + "id": "organize", + "options": { + "excludeByName": { + "generic": true, + "latency_avg_1h": true, + "latency_avg_5m": true, + "metrics": true, + "predictions_count_1h": true, + "predictions_count_5m": true + }, + "indexByName": {}, + "renameByName": { + "predictions_per_second": "Predictions/s (5 Minute Average)" + } + } + } + ], "transparent": true, "type": "stat" }, @@ -168,7 +206,6 @@ "datasource": "iguazio", "fieldConfig": { "defaults": { - "custom": {}, "mappings": [], "thresholds": { "mode": "absolute", @@ -186,10 +223,10 @@ "gridPos": { "h": 3, "w": 5, - "x": 12, + "x": 11, "y": 0 }, - "id": 10, + "id": 23, "options": { "colorMode": "value", "graphMode": "none", @@ -204,25 +241,48 @@ }, "textMode": "auto" }, - "pluginVersion": "7.2.0", + "pluginVersion": "9.2.2", "targets": [ { + "datasource": "iguazio", + "hide": false, "rawQuery": true, "refId": "A", - "target": 
"backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfields=latency_avg_1h;", + "target": "backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfields=metrics;", "type": "table" } ], - "timeFrom": null, - "timeShift": null, "title": "Average Latency (Last Hour)", "transformations": [ { - "id": "reduce", + "id": "extractFields", "options": { - "reducers": [ - "mean" - ] + "source": "metrics" + } + }, + { + "id": "extractFields", + "options": { + "source": "generic" + } + }, + { + "id": "organize", + "options": { + "excludeByName": { + "generic": true, + "latency_avg_1h": false, + "latency_avg_5m": true, + "metrics": true, + "predictions_count_1h": true, + "predictions_count_5m": true, + "predictions_per_second": true + }, + "indexByName": {}, + "renameByName": { + "latency_avg_1h": "Average Latency (Last Hour)", + "predictions_per_second": "Predictions/s (5 Minute Average)" + } } } ], @@ -233,7 +293,6 @@ "datasource": "iguazio", "fieldConfig": { "defaults": { - "custom": {}, "mappings": [], "noValue": "0", "thresholds": { @@ -266,47 +325,39 @@ "orientation": "auto", "reduceOptions": { "calcs": [ - "mean" + "sum" ], "fields": "", "values": false }, "textMode": "auto" }, - "pluginVersion": "7.2.0", + "pluginVersion": "9.2.2", "targets": [ { + "datasource": "iguazio", "rawQuery": true, "refId": "A", "target": "backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfields=error_count;", "type": "table" } ], - "timeFrom": null, - "timeShift": null, "title": "Errors", "transparent": true, "type": "stat" }, { - "datasource": "model-monitoring", + "datasource": "iguazio", "description": "", "fieldConfig": { "defaults": { "custom": { "align": "center", "displayMode": "auto", - "filterable": true + "filterable": true, + "inspect": false }, - "mappings": [ - { - "from": "", - "id": 0, - "text": "", - "to": "", - "type": 1 - } - ], + "mappings": [], "thresholds": { "mode": "absolute", 
"steps": [ @@ -325,7 +376,7 @@ { "matcher": { "id": "byName", - "options": "Function" + "options": "function_uri" }, "properties": [ { @@ -452,36 +503,28 @@ "id": "mappings", "value": [ { - "from": "", - "id": 0, - "text": "0", - "to": "", - "type": 1, - "value": "NO_DRIFT" - }, - { - "from": "", - "id": 1, - "text": "1", - "to": "", - "type": 1, - "value": "POSSIBLE_DRIFT" - }, - { - "from": "", - "id": 2, - "text": "2", - "to": "", - "type": 1, - "value": "DRIFT_DETECTED" - }, - { - "from": "", - "id": 3, - "text": "-1", - "to": "", - "type": 1, - "value": "N\\A" + "options": { + "DRIFT_DETECTED": { + "color": "red", + "index": 3, + "text": "2" + }, + "NO_DRIFT": { + "color": "green", + "index": 2, + "text": "0" + }, + "N\\A": { + "index": 1, + "text": "-1" + }, + "POSSIBLE_DRIFT": { + "color": "yellow", + "index": 0, + "text": "1" + } + }, + "type": "value" } ] }, @@ -543,6 +586,13 @@ }, "id": 22, "options": { + "footer": { + "fields": "", + "reducer": [ + "sum" + ], + "show": false + }, "showHeader": true, "sortBy": [ { @@ -551,18 +601,17 @@ } ] }, - "pluginVersion": "7.2.0", + "pluginVersion": "9.2.2", "targets": [ { + "datasource": "iguazio", "hide": false, "rawQuery": true, "refId": "A", - "target": "project=$PROJECT;target_endpoint=list_endpoints", + "target": "backend=kv;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/endpoints;\nfields=uid,model,function_uri,model_class,first_request,last_request,error_count,drift_status;", "type": "table" } ], - "timeFrom": null, - "timeShift": null, "title": "Models", "transformations": [ { @@ -572,16 +621,14 @@ "model_hash": false }, "indexByName": { - "accuracy": 8, - "drift_status": 9, - "endpoint_function": 1, - "endpoint_id": 0, - "endpoint_model": 2, - "endpoint_model_class": 3, - "endpoint_tag": 4, - "error_count": 7, - "first_request": 5, - "last_request": 6 + "drift_status": 7, + "error_count": 6, + "first_request": 4, + "function_uri": 1, + "last_request": 5, + "model": 2, + "model_class": 3, + 
"uid": 0 }, "renameByName": { "accuracy": "Accuracy", @@ -594,12 +641,14 @@ "error_count": "Error Count", "first_request": "First Request", "function": "Function", + "function_uri": "Function", "last_request": "Last Request", "latency_avg_1s": "Average Latency", "model": "Model", "model_class": "Class", "predictions_per_second_count_1s": "Predictions/1s", - "tag": "Tag" + "tag": "Tag", + "uid": "Endpoint ID" } } } @@ -607,10 +656,7 @@ "type": "table" }, { - "cards": { - "cardPadding": null, - "cardRound": null - }, + "cards": {}, "color": { "cardColor": "#b4ff00", "colorScale": "sqrt", @@ -622,20 +668,15 @@ "datasource": "iguazio", "fieldConfig": { "defaults": { - "custom": {}, - "mappings": [], - "thresholds": { - "mode": "absolute", - "steps": [ - { - "color": "green", - "value": null - }, - { - "color": "red", - "value": 80 - } - ] + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false + }, + "scaleDistribution": { + "type": "linear" + } } }, "overrides": [] @@ -653,18 +694,54 @@ "legend": { "show": false }, - "pluginVersion": "7.2.0", + "options": { + "calculate": true, + "calculation": {}, + "cellGap": 2, + "cellValues": {}, + "color": { + "exponent": 0.5, + "fill": "#b4ff00", + "mode": "scheme", + "reverse": false, + "scale": "exponential", + "scheme": "Plasma", + "steps": 128 + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" + }, + "filterValues": { + "le": 1e-9 + }, + "legend": { + "show": false + }, + "rowsFrame": { + "layout": "auto" + }, + "showValue": "never", + "tooltip": { + "show": true, + "yHistogram": false + }, + "yAxis": { + "axisPlacement": "left", + "reverse": false, + "unit": "short" + } + }, + "pluginVersion": "9.2.2", "reverseYBuckets": false, "targets": [ { + "datasource": "iguazio", "rawQuery": true, "refId": "A", "target": "backend=tsdb;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/events;\nfields=predictions_per_second;", "type": "timeserie" } ], - "timeFrom": null, - "timeShift": null, 
"title": "Predictions/s (5 Minute Average)", "tooltip": { "show": true, @@ -675,26 +752,15 @@ "xAxis": { "show": true }, - "xBucketNumber": null, - "xBucketSize": null, "yAxis": { - "decimals": null, "format": "short", "logBase": 1, - "max": null, - "min": null, - "show": true, - "splitFactor": null + "show": true }, - "yBucketBound": "auto", - "yBucketNumber": null, - "yBucketSize": null + "yBucketBound": "auto" }, { - "cards": { - "cardPadding": null, - "cardRound": null - }, + "cards": {}, "color": { "cardColor": "#b4ff00", "colorScale": "sqrt", @@ -706,36 +772,18 @@ "datasource": "iguazio", "fieldConfig": { "defaults": { - "custom": {}, - "mappings": [], - "thresholds": { - "mode": "absolute", - "steps": [ - { - "color": "green", - "value": null - }, - { - "color": "red", - "value": 80 - } - ] - } - }, - "overrides": [ - { - "matcher": { - "id": "byType", - "options": "number" + "custom": { + "hideFrom": { + "legend": false, + "tooltip": false, + "viz": false }, - "properties": [ - { - "id": "unit", - "value": "µs" - } - ] + "scaleDistribution": { + "type": "linear" + } } - ] + }, + "overrides": [] }, "gridPos": { "h": 6, @@ -750,18 +798,54 @@ "legend": { "show": false }, - "pluginVersion": "7.2.0", + "options": { + "calculate": true, + "calculation": {}, + "cellGap": 2, + "cellValues": {}, + "color": { + "exponent": 0.5, + "fill": "#b4ff00", + "mode": "scheme", + "reverse": false, + "scale": "exponential", + "scheme": "Plasma", + "steps": 128 + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" + }, + "filterValues": { + "le": 1e-9 + }, + "legend": { + "show": false + }, + "rowsFrame": { + "layout": "auto" + }, + "showValue": "never", + "tooltip": { + "show": true, + "yHistogram": false + }, + "yAxis": { + "axisPlacement": "left", + "reverse": false, + "unit": "short" + } + }, + "pluginVersion": "9.2.2", "reverseYBuckets": false, "targets": [ { + "datasource": "iguazio", "rawQuery": true, "refId": "A", "target": 
"backend=tsdb;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/events;\nfields=latency_avg_1h;", "type": "timeserie" } ], - "timeFrom": null, - "timeShift": null, "title": "Average Latency (1 Hour)", "tooltip": { "show": true, @@ -772,20 +856,12 @@ "xAxis": { "show": true }, - "xBucketNumber": null, - "xBucketSize": null, "yAxis": { - "decimals": null, "format": "short", "logBase": 1, - "max": null, - "min": null, - "show": true, - "splitFactor": null + "show": true }, - "yBucketBound": "auto", - "yBucketNumber": null, - "yBucketSize": null + "yBucketBound": "auto" }, { "aliasColors": {}, @@ -793,12 +869,6 @@ "dashLength": 10, "dashes": false, "datasource": "iguazio", - "fieldConfig": { - "defaults": { - "custom": {} - }, - "overrides": [] - }, "fill": 1, "fillGradient": 0, "gridPos": { @@ -825,7 +895,7 @@ "alertThreshold": true }, "percentage": false, - "pluginVersion": "7.2.0", + "pluginVersion": "9.2.2", "pointradius": 2, "points": false, "renderer": "flot", @@ -835,15 +905,14 @@ "steppedLine": false, "targets": [ { + "datasource": "iguazio", "refId": "A", "target": "select metric", "type": "timeserie" } ], "thresholds": [], - "timeFrom": null, "timeRegions": [], - "timeShift": null, "title": "Errors", "tooltip": { "shared": true, @@ -853,44 +922,34 @@ "transparent": true, "type": "graph", "xaxis": { - "buckets": null, "mode": "time", - "name": null, "show": true, "values": [] }, "yaxes": [ { "format": "short", - "label": null, "logBase": 1, - "max": null, - "min": null, "show": true }, { "format": "short", - "label": null, "logBase": 1, - "max": null, - "min": null, "show": true } ], "yaxis": { - "align": false, - "alignLevel": null + "align": false } } ], - "refresh": "5s", - "schemaVersion": 26, + "refresh": "30s", + "schemaVersion": 37, "style": "dark", "tags": [], "templating": { "list": [ { - "allValue": null, "current": {}, "datasource": "model-monitoring", "definition": "target_endpoint=list_projects", @@ -906,7 +965,6 @@ "skipUrlSync": 
false, "sort": 0, "tagValuesQuery": "", - "tags": [], "tagsQuery": "", "type": "query", "useTags": false @@ -921,5 +979,6 @@ "timezone": "", "title": "Model Monitoring - Overview", "uid": "g0M4uh0Mz", - "version": 2 -} + "version": 2, + "weekStart": "" +} \ No newline at end of file diff --git a/docs/monitoring/dashboards/model-monitoring-performance.json b/docs/monitoring/dashboards/model-monitoring-performance.json index ab343c5055dc..1956e3fcafc9 100644 --- a/docs/monitoring/dashboards/model-monitoring-performance.json +++ b/docs/monitoring/dashboards/model-monitoring-performance.json @@ -3,25 +3,33 @@ "list": [ { "builtIn": 1, - "datasource": "-- Grafana --", + "datasource": { + "type": "datasource", + "uid": "grafana" + }, "enable": true, "hide": true, "iconColor": "rgba(0, 211, 255, 1)", "name": "Annotations & Alerts", + "target": { + "limit": 100, + "matchAny": false, + "tags": [], + "type": "dashboard" + }, "type": "dashboard" } ] }, "editable": true, - "gnetId": null, + "fiscalYearStartMonth": 0, "graphTooltip": 0, - "id": 9, - "iteration": 1627466092078, + "id": 32, "links": [ { "asDropdown": true, "icon": "external link", - "includeVars": false, + "includeVars": true, "keepTime": true, "tags": [], "title": "Model Monitoring - Overview", @@ -34,11 +42,12 @@ "keepTime": true, "tags": [], "targetBlank": false, - "title": "Model Monitoring Details", + "title": "Model Monitoring - Details", "type": "link", "url": "d/AohIXhAMk/model-monitoring-details" } ], + "liveNow": false, "panels": [ { "aliasColors": {}, @@ -46,12 +55,6 @@ "dashLength": 10, "dashes": false, "datasource": "iguazio", - "fieldConfig": { - "defaults": { - "custom": {} - }, - "overrides": [] - }, "fill": 1, "fillGradient": 1, "gridPos": { @@ -81,7 +84,7 @@ "alertThreshold": true }, "percentage": false, - "pluginVersion": "7.2.0", + "pluginVersion": "9.2.2", "pointradius": 2, "points": false, "renderer": "flot", @@ -91,16 +94,15 @@ "steppedLine": false, "targets": [ { + "datasource": "iguazio", 
"rawQuery": true, "refId": "A", - "target": "backend=tsdb;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/events;\nfilter=endpoint_id=='$MODEL' AND record_type=='drift_measures';", + "target": "backend=tsdb;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/events;\nfilter=endpoint_id=='$MODELENDPOINT' AND record_type=='drift_measures';", "type": "timeserie" } ], "thresholds": [], - "timeFrom": null, "timeRegions": [], - "timeShift": null, "title": "Drift Measures", "tooltip": { "shared": true, @@ -111,33 +113,24 @@ "transparent": true, "type": "graph", "xaxis": { - "buckets": null, "mode": "time", - "name": null, "show": true, "values": [] }, "yaxes": [ { "format": "short", - "label": null, "logBase": 1, - "max": null, - "min": null, "show": true }, { "format": "short", - "label": null, "logBase": 1, - "max": null, - "min": null, "show": true } ], "yaxis": { - "align": false, - "alignLevel": null + "align": false } }, { @@ -146,12 +139,6 @@ "dashLength": 10, "dashes": false, "datasource": "iguazio", - "fieldConfig": { - "defaults": { - "custom": {} - }, - "overrides": [] - }, "fill": 1, "fillGradient": 0, "gridPos": { @@ -178,7 +165,7 @@ "alertThreshold": true }, "percentage": false, - "pluginVersion": "7.2.0", + "pluginVersion": "9.2.2", "pointradius": 2, "points": false, "renderer": "flot", @@ -188,16 +175,15 @@ "steppedLine": false, "targets": [ { + "datasource": "iguazio", "rawQuery": true, "refId": "A", - "target": "backend=tsdb;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/events;\nfields=latency_avg_5m,latency_avg_1h;\nfilter=endpoint_id=='$MODEL';", + "target": "backend=tsdb;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/events;\nfields=latency_avg_5m,latency_avg_1h;\nfilter=endpoint_id=='$MODELENDPOINT';", "type": "timeserie" } ], "thresholds": [], - "timeFrom": null, "timeRegions": [], - "timeShift": null, "title": "Average Latency", "tooltip": { "shared": true, @@ -207,33 +193,24 @@ "transparent": 
true, "type": "graph", "xaxis": { - "buckets": null, "mode": "time", - "name": null, "show": true, "values": [] }, "yaxes": [ { "format": "short", - "label": null, "logBase": 1, - "max": null, - "min": null, "show": true }, { "format": "short", - "label": null, "logBase": 1, - "max": null, - "min": null, "show": true } ], "yaxis": { - "align": false, - "alignLevel": null + "align": false } }, { @@ -242,12 +219,6 @@ "dashLength": 10, "dashes": false, "datasource": "iguazio", - "fieldConfig": { - "defaults": { - "custom": {} - }, - "overrides": [] - }, "fill": 1, "fillGradient": 0, "gridPos": { @@ -274,7 +245,7 @@ "alertThreshold": true }, "percentage": false, - "pluginVersion": "7.2.0", + "pluginVersion": "9.2.2", "pointradius": 2, "points": false, "renderer": "flot", @@ -284,16 +255,15 @@ "steppedLine": false, "targets": [ { + "datasource": "iguazio", "rawQuery": true, "refId": "A", - "target": "backend=tsdb;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/events;\nfields=predictions_per_second;\nfilter=endpoint_id=='$MODEL';", + "target": "backend=tsdb;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/events;\nfields=predictions_per_second;\nfilter=endpoint_id=='$MODELENDPOINT';", "type": "timeserie" } ], "thresholds": [], - "timeFrom": null, "timeRegions": [], - "timeShift": null, "title": "Predictions/s (5 minute average)", "tooltip": { "shared": true, @@ -303,33 +273,24 @@ "transparent": true, "type": "graph", "xaxis": { - "buckets": null, "mode": "time", - "name": null, "show": true, "values": [] }, "yaxes": [ { "format": "short", - "label": null, "logBase": 1, - "max": null, - "min": null, "show": true }, { "format": "short", - "label": null, "logBase": 1, - "max": null, - "min": null, "show": true } ], "yaxis": { - "align": false, - "alignLevel": null + "align": false } }, { @@ -338,12 +299,6 @@ "dashLength": 10, "dashes": false, "datasource": "iguazio", - "fieldConfig": { - "defaults": { - "custom": {} - }, - "overrides": [] - }, 
"fill": 1, "fillGradient": 0, "gridPos": { @@ -370,7 +325,7 @@ "alertThreshold": true }, "percentage": false, - "pluginVersion": "7.2.0", + "pluginVersion": "9.2.2", "pointradius": 2, "points": false, "renderer": "flot", @@ -380,16 +335,15 @@ "steppedLine": false, "targets": [ { + "datasource": "iguazio", "rawQuery": true, "refId": "A", - "target": "backend=tsdb;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/events;\nfields=predictions_count_5m,predictions_count_1h;\nfilter=endpoint_id=='$MODEL';", + "target": "backend=tsdb;\ncontainer=users;\ntable=pipelines/$PROJECT/model-endpoints/events;\nfields=predictions_count_5m,predictions_count_1h;\nfilter=endpoint_id=='$MODELENDPOINT';", "type": "timeserie" } ], "thresholds": [], - "timeFrom": null, "timeRegions": [], - "timeShift": null, "title": "Predictions Count", "tooltip": { "shared": true, @@ -399,33 +353,24 @@ "transparent": true, "type": "graph", "xaxis": { - "buckets": null, "mode": "time", - "name": null, "show": true, "values": [] }, "yaxes": [ { "format": "short", - "label": null, "logBase": 1, - "max": null, - "min": null, "show": true }, { "format": "short", - "label": null, "logBase": 1, - "max": null, - "min": null, "show": true } ], "yaxis": { - "align": false, - "alignLevel": null + "align": false } }, { @@ -434,12 +379,6 @@ "dashLength": 10, "dashes": false, "datasource": "iguazio", - "fieldConfig": { - "defaults": { - "custom": {} - }, - "overrides": [] - }, "fill": 1, "fillGradient": 1, "gridPos": { @@ -469,7 +408,7 @@ "alertThreshold": true }, "percentage": false, - "pluginVersion": "7.2.0", + "pluginVersion": "9.2.2", "pointradius": 2, "points": false, "renderer": "flot", @@ -479,16 +418,15 @@ "steppedLine": false, "targets": [ { + "datasource": "iguazio", "rawQuery": true, "refId": "A", - "target": "backend=tsdb; container=users; table=pipelines/$PROJECT/model-endpoints/events; filter=endpoint_id=='$MODEL' AND record_type=='custom_metrics';", + "target": "backend=tsdb; 
container=users; table=pipelines/$PROJECT/model-endpoints/events; filter=endpoint_id=='$MODELENDPOINT' AND record_type=='custom_metrics';", "type": "timeserie" } ], "thresholds": [], - "timeFrom": null, "timeRegions": [], - "timeShift": null, "title": "Custom Metrics", "tooltip": { "shared": true, @@ -499,44 +437,34 @@ "transparent": true, "type": "graph", "xaxis": { - "buckets": null, "mode": "time", - "name": null, "show": true, "values": [] }, "yaxes": [ { "format": "short", - "label": null, "logBase": 1, - "max": null, - "min": null, "show": true }, { "format": "short", - "label": null, "logBase": 1, - "max": null, - "min": null, "show": true } ], "yaxis": { - "align": false, - "alignLevel": null + "align": false } } ], - "refresh": "1m", - "schemaVersion": 26, + "refresh": "30s", + "schemaVersion": 37, "style": "dark", "tags": [], "templating": { "list": [ { - "allValue": null, "current": {}, "datasource": "model-monitoring", "definition": "target_endpoint=list_projects", @@ -552,29 +480,26 @@ "skipUrlSync": false, "sort": 0, "tagValuesQuery": "", - "tags": [], "tagsQuery": "", "type": "query", "useTags": false }, { - "allValue": null, "current": {}, "datasource": "iguazio", - "definition": "backend=kv;container=users;table=pipelines/$PROJECT/model-endpoints/endpoints;fields=endpoint_id;", + "definition": "backend=kv;container=users;table=pipelines/$PROJECT/model-endpoints/endpoints;fields=uid;", "hide": 0, "includeAll": false, - "label": "Model", + "label": "Model Endpoint", "multi": false, - "name": "MODEL", + "name": "MODELENDPOINT", "options": [], - "query": "backend=kv;container=users;table=pipelines/$PROJECT/model-endpoints/endpoints;fields=endpoint_id;", - "refresh": 0, + "query": "backend=kv;container=users;table=pipelines/$PROJECT/model-endpoints/endpoints;fields=uid;", + "refresh": 1, "regex": "", "skipUrlSync": false, "sort": 0, "tagValuesQuery": "", - "tags": [], "tagsQuery": "", "type": "query", "useTags": false @@ -589,5 +514,6 @@ "timezone": "", 
"title": "Model Monitoring - Performance", "uid": "9CazA-UGz", - "version": 2 -} + "version": 2, + "weekStart": "" +} \ No newline at end of file diff --git a/docs/monitoring/initial-setup-configuration.ipynb b/docs/monitoring/initial-setup-configuration.ipynb index b25fd9dcab6c..125ee219224d 100644 --- a/docs/monitoring/initial-setup-configuration.ipynb +++ b/docs/monitoring/initial-setup-configuration.ipynb @@ -33,7 +33,9 @@ " \n", " `fn.set_tracking(stream_path, batch, sample)`\n", " \n", - "- **stream_path** — the v3io stream path (e.g. `v3io:///users/..`)\n", + "- **stream_path**\n", + " - Enterprise: the v3io stream path (e.g. `v3io:///users/..`)\n", + " - CE: a valid Kafka stream (e.g. `kafka://kafka.default.svc.cluster.local:9092`)\n", "- **sample** — optional, sample every N requests\n", "- **batch** — optional, send micro-batches every N requests\n", " \n", @@ -90,15 +92,20 @@ "project.set_model_monitoring_credentials(os.environ.get(\"V3IO_ACCESS_KEY\"))\n", "\n", "# Download the pre-trained Iris model\n", - "get_dataitem(\"https://s3.wasabisys.com/iguazio/models/iris/model.pkl\").download(\"model.pkl\")\n", + "get_dataitem(\"https://s3.wasabisys.com/iguazio/models/iris/model.pkl\").download(\n", + " \"model.pkl\"\n", + ")\n", "\n", "iris = load_iris()\n", - "train_set = pd.DataFrame(iris['data'],\n", - " columns=['sepal_length_cm', 'sepal_width_cm',\n", - " 'petal_length_cm', 'petal_width_cm'])\n", + "train_set = pd.DataFrame(\n", + " iris[\"data\"],\n", + " columns=[\"sepal_length_cm\", \"sepal_width_cm\", \"petal_length_cm\", \"petal_width_cm\"],\n", + ")\n", "\n", "# Import the serving function from the Function Hub\n", - "serving_fn = import_function('hub://v2_model_server', project=project_name).apply(auto_mount())\n", + "serving_fn = import_function(\"hub://v2_model_server\", project=project_name).apply(\n", + " auto_mount()\n", + ")\n", "\n", "model_name = \"RandomForestClassifier\"\n", "\n", @@ -106,7 +113,9 @@ "project.log_model(model_name, 
model_file=\"model.pkl\", training_set=train_set)\n", "\n", "# Add the model to the serving function's routing spec\n", - "serving_fn.add_model(model_name, model_path=f\"store://models/{project_name}/{model_name}:latest\")\n", + "serving_fn.add_model(\n", + " model_name, model_path=f\"store://models/{project_name}/{model_name}:latest\"\n", + ")\n", "\n", "# Enable model monitoring\n", "serving_fn.set_tracking()\n", @@ -140,12 +149,14 @@ "from time import sleep\n", "from random import choice, uniform\n", "\n", - "iris_data = iris['data'].tolist()\n", + "iris_data = iris[\"data\"].tolist()\n", "\n", "while True:\n", " data_point = choice(iris_data)\n", - " serving_fn.invoke(f'v2/models/{model_name}/infer', json.dumps({'inputs': [data_point]}))\n", - " sleep(uniform(0.2, 1.7))\n" + " serving_fn.invoke(\n", + " f\"v2/models/{model_name}/infer\", json.dumps({\"inputs\": [data_point]})\n", + " )\n", + " sleep(uniform(0.2, 1.7))" ] } ], diff --git a/docs/monitoring/model-monitoring-deployment.ipynb b/docs/monitoring/model-monitoring-deployment.ipynb index 9c453034094b..ebffbe53071b 100644 --- a/docs/monitoring/model-monitoring-deployment.ipynb +++ b/docs/monitoring/model-monitoring-deployment.ipynb @@ -64,7 +64,7 @@ "* [Model features analysis](#model-features-analysis)\n", "\n", "1. Select a project from the project tiles screen.\n", - "2. From the project dashboard, press the **Models** tile to view the models currently deployed .\n", + "2. From the project dashboard, press the **Models** tile to view the models currently deployed.\n", "2. Press **Model Endpoints** from the menu to display a list of monitored endpoints.
\n", " If the Model Monitoring feature is not enabled, the endpoints list is empty.\n", "\n", diff --git a/docs/projects/create-project.md b/docs/projects/create-project.md index b7fbf733a337..cff1f6693663 100644 --- a/docs/projects/create-project.md +++ b/docs/projects/create-project.md @@ -207,7 +207,7 @@ Use standard Git commands to push the current project tree into a git archive. M git commit -m "Commit message" git push origin master -Alternatively you can use MLRun SDK calls: +Alternatively, you can use MLRun SDK calls: - {py:meth}`~mlrun.projects.MlrunProject.create_remote` - to register the remote Git path - {py:meth}`~mlrun.projects.MlrunProject.push` - save project spec (`project.yaml`) and commit/push updates to remote repo diff --git a/docs/projects/git-best-practices.ipynb b/docs/projects/git-best-practices.ipynb new file mode 100644 index 000000000000..6f0c77ba2c49 --- /dev/null +++ b/docs/projects/git-best-practices.ipynb @@ -0,0 +1,344 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "54d41da3", + "metadata": {}, + "source": [ + "# Git best practices" + ] + }, + { + "cell_type": "markdown", + "id": "29a0e0da", + "metadata": {}, + "source": [ + "This section provides an overview of developing and deploying ML applications using MLRun and Git. 
It covers the following:\n", + "- [MLRun and Git Overview](#mlrun-and-git-overview)\n", + " - [Loading the code from container vs. loading the code at runtime](#loading-the-code-from-container-vs-loading-the-code-at-runtime)\n", + "- [Common tasks](#common-tasks)\n", + " - [Setting up a new MLRun project repo](#setting-up-a-new-mlrun-project-repo)\n", + " - [Running an existing MLRun project repo](#running-an-existing-mlrun-project-repo)\n", + " - [Pushing Changes to MLRun Project Repo](#pushing-changes-to-mlrun-project-repo)\n", + " - [Utilizing Different Branches](#utilizing-different-branches)" + ] + }, + { + "cell_type": "markdown", + "id": "e298490d-b0ce-4cc1-af66-2c4b00f09270", + "metadata": {}, + "source": [ + "```{admonition} Note\n", + "This section assumes basic familiarity with version control software such as GitHub, GitLab, etc. If you're new to Git and version control, see the [GitHub Hello World documentation](https://docs.github.com/en/get-started/quickstart/hello-world).\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "6a3a0b29", + "metadata": {}, + "source": [ + "## MLRun and Git Overview" + ] + }, + { + "cell_type": "markdown", + "id": "43164106", + "metadata": {}, + "source": [ + "As a best practice, your MLRun project **should be backed by a Git repo**. This allows you to keep track of your code in source control as well as utilize your entire code library within your MLRun functions." + ] + }, + { + "cell_type": "markdown", + "id": "d88ad2d5", + "metadata": {}, + "source": [ + "The typical lifecycle of a project is as follows:" + ] + }, + { + "cell_type": "markdown", + "id": "60ca000e", + "metadata": {}, + "source": [ + "![](https://docs.mlrun.org/en/latest/_static/images/project-lifecycle.png)" + ] + }, + { + "cell_type": "markdown", + "id": "bc981d84", + "metadata": {}, + "source": [ + "Many people like to develop locally on their laptops, Jupyter environments, or local IDE before submitting the code to Git and running on the larger cluster. 
See [Set up your client environment](https://docs.mlrun.org/en/latest/install/remote.html) for more details." + ] + }, + { + "cell_type": "markdown", + "id": "d4f36927-b688-406f-9555-1d6e90abcb50", + "metadata": {}, + "source": [ + "### Loading the code from container vs. loading the code at runtime" + ] + }, + { + "cell_type": "markdown", + "id": "dc5cd2ab-bd08-44a7-812e-47f252666ec7", + "metadata": {}, + "source": [ + "MLRun supports two approaches to loading the code from Git:\n", + "\n", + "- Loading the code from container (default behavior)
\n", + "The image for the MLRun function is built once, and consumes the code in the repo. **This is the preferred approach for production workloads**. For example:\n", + "\n", + "```python\n", + "project.set_source(source=\"git://github.com/mlrun/project-archive.git\")\n", + "\n", + "fn = project.set_function(\n", + " name=\"myjob\", handler=\"job_func.job_handler\",\n", + " image=\"mlrun/mlrun\", kind=\"job\", with_repo=True,\n", + ")\n", + "\n", + "project.build_function(fn)\n", + "```\n", + "\n", + "- Loading the code at runtime
\n", + "The MLRun function pulls the source code directly from Git at runtime. **This is a simpler approach during development that allows for making code changes without re-building the image each time.** For example:\n", + "\n", + "```python\n", + "project.set_source(source=\"git://github.com/mlrun/project-archive.git\", pull_at_runtime=True)\n", + "\n", + "fn = project.set_function(\n", + " name=\"nuclio\", handler=\"nuclio_func:nuclio_handler\",\n", + " image=\"mlrun/mlrun\", kind=\"nuclio\", with_repo=True,\n", + ")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "6cd96715-f85b-4ad1-82c1-1d063d45b3c9", + "metadata": {}, + "source": [ + "## Common tasks" + ] + }, + { + "cell_type": "markdown", + "id": "7641829b", + "metadata": {}, + "source": [ + "### Setting up a new MLRun project repo" + ] + }, + { + "cell_type": "markdown", + "id": "b994758b-6cf5-4c91-aa00-e4f1641471a1", + "metadata": {}, + "source": [ + "1. Initialize your repo using the command line as per [this guide](https://dev.to/bowmanjd/create-and-initialize-a-new-github-repository-from-the-command-line-85e) or using your version control software of choice (e.g. GitHub, GitLab, etc).\n", + "\n", + "```bash\n", + "git init ...\n", + "git add ...\n", + "git commit -m ...\n", + "git remote add origin ...\n", + "git branch -M \n", + "git push -u origin \n", + "\n", + "```\n", + "\n", + "2. Clone the repo to the local environment where the MLRun client is installed (e.g. Jupyter, VSCode, etc.) and navigate to the repo.\n", + "\n", + "```{admonition} Note\n", + "It is assumed that your local environment has the required access to pull a private repo.\n", + "```\n", + "```bash\n", + "git clone \n", + "cd \n", + "```\n", + "\n", + "3. Initialize a new MLRun project with the context pointing to your newly cloned repo.\n", + "\n", + "```python\n", + "import mlrun\n", + "\n", + "project = mlrun.get_or_create_project(name=\"my-super-cool-project\", context=\"./\")\n", + "```\n", + "\n", + "4. 
Set the MLRun project source with the desired `pull_at_runtime` behavior (see [Loading the code from container vs. loading the code at runtime](#loading-the-code-from-container-vs-loading-the-code-at-runtime) for more info). Also set `GIT_TOKEN` in MLRun project secrets for working with private repos.\n", + "\n", + "```python\n", + "# Notice the prefix has been changed to git://\n", + "project.set_source(source=\"git://github.com/mlrun/project-archive.git\", pull_at_runtime=True)\n", + "project.set_secrets(secrets={\"GIT_TOKEN\" : \"XXXXXXXXXXXXXXX\"}, provider=\"kubernetes\")\n", + "```\n", + "\n", + "5. Register any MLRun functions or workflows and save. Make sure `with_repo` is `True` in order to add source code to the function.\n", + "\n", + "```python\n", + "project.set_function(name='train_model', func='train_model.py', kind='job', image='mlrun/mlrun', with_repo=True)\n", + "project.set_workflow(name='training_pipeline', workflow_path='training_pipeline.py')\n", + "project.save()\n", + "```\n", + "\n", + "6. Push additions to Git.\n", + "\n", + "```bash\n", + "git add ...\n", + "git commit -m ...\n", + "git push ...\n", + "```\n", + "\n", + "7. Run the MLRun function/workflow. The source code is added to the function and is available via imports as expected.\n", + "\n", + "```python\n", + "project.run_function(function=\"train_model\")\n", + "project.run(name=\"training_pipeline\")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "547733d0", + "metadata": {}, + "source": [ + "### Running an existing MLRun project repo" + ] + }, + { + "cell_type": "markdown", + "id": "8bbf162a-7348-424b-a72d-d64c90dd4db2", + "metadata": {}, + "source": [ + "1. Clone an existing MLRun project repo to your local environment where the MLRun client is installed (e.g. Jupyter, VSCode, etc.) and navigate to the repo.\n", + "\n", + "```bash\n", + "git clone \n", + "cd \n", + "```\n", + "\n", + "2. Load the MLRun project with the context pointing to your newly cloned repo. 
**MLRun is looking for a `project.yaml` file in the root of the repo**.\n", + "\n", + "```python\n", + "project = mlrun.load_project(context=\"./\")\n", + "```\n", + "\n", + "3. Optionally enable `pull_at_runtime` for easier development. Also set `GIT_TOKEN` in the MLRun Project secrets for working with private repos.\n", + "\n", + "```python\n", + "# source=None will use current Git source\n", + "project.set_source(source=None, pull_at_runtime=True)\n", + "project.set_secrets(secrets={\"GIT_TOKEN\" : \"XXXXXXXXXXXXXXX\"}, provider=\"kubernetes\")\n", + "```\n", + "\n", + "4. Run the MLRun function/workflow. The source code is added to the function and is available via imports as expected.\n", + "\n", + "```python\n", + "project.run_function(function=\"train_model\")\n", + "project.run(name=\"training_pipeline\")\n", + "```\n", + "\n", + "```{admonition} Note\n", + "If another user previously ran the project in your MLRun environment, ensure that your user has project permissions (otherwise you may not be able to view or run the project).\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "aea0970c", + "metadata": {}, + "source": [ + "### Pushing changes to the MLRun project repo" + ] + }, + { + "cell_type": "markdown", + "id": "ef0d8e9a-f5b0-4675-99e1-7764b054c0ba", + "metadata": {}, + "source": [ + "1. Edit the source code/functions/workflows in some way.\n", + "2. Check-in changes to Git.\n", + "\n", + "```bash\n", + "git add ...\n", + "git commit -m ...\n", + "git push ...\n", + "```\n", + "\n", + "3. If `pull_at_runtime=False`, re-build the Docker image. If `pull_at_runtime=True`, skip this step.\n", + "\n", + "```python\n", + "import mlrun\n", + "\n", + "project = mlrun.load_project(context=\"./\")\n", + "project.build_function(\"my_updated_function\")\n", + "```\n", + "\n", + "4. Run the MLRun function/workflow. 
The source code with changes is added to the function and is available via imports as expected.\n", + "\n", + "```python\n", + "project.run_function(function=\"train_model\")\n", + "project.run(name=\"training_pipeline\")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "7d0a5e97", + "metadata": {}, + "source": [ + "### Utilizing different branches" + ] + }, + { + "cell_type": "markdown", + "id": "c5a1878c-a565-478d-a9b6-96a876a7f3ff", + "metadata": {}, + "source": [ + "1. Check out the desired branch in the local environment.\n", + "\n", + "```bash\n", + "git checkout \n", + "```\n", + "\n", + "2. Update the desired branch in MLRun project. Optionally, save if the branch should be used for future runs.\n", + "\n", + "```python\n", + "project.set_source(\n", + " source=\"git://github.com/igz-us-sales/mlrun-git-example.git#spanish\",\n", + " pull_at_runtime=True\n", + ")\n", + "project.save()\n", + "```\n", + "\n", + "3. Run the MLRun function/workflow. The source code from desired branch is added to the function and is available via imports as expected.\n", + "\n", + "```python\n", + "project.run_function(\"greetings\")\n", + "```" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/projects/project.md b/docs/projects/project.md index a059718979d9..9de2cdf8e519 100644 --- a/docs/projects/project.md +++ b/docs/projects/project.md @@ -7,7 +7,7 @@ MLRun **Project** is a container for all your work on a particular ML applicatio

mlrun-project


Projects are stored in a GIT or archive and map to IDE projects (in PyCharm, VSCode, etc.), which enables versioning, collaboration, and [CI/CD](../projects/ci-integration.html). -Projects simplify how you process data, [submit jobs](../concepts/submitting-tasks-jobs-to-functions.html), run [multi-stage workflows](../concepts/workflow-overview.html), and deploy [real-time pipelines](../serving/serving-graph.html) in continious development or production environments. +Projects simplify how you process data, [submit jobs](../concepts/submitting-tasks-jobs-to-functions.html), run [multi-stage workflows](../concepts/workflow-overview.html), and deploy [real-time pipelines](../serving/serving-graph.html) in continuous development or production environments.

project-lifecycle


@@ -17,9 +17,10 @@ Projects simplify how you process data, [submit jobs](../concepts/submitting-tas :maxdepth: 1 create-project +git-best-practices load-project run-build-deploy build-run-workflows-pipelines ci-integration ../secrets -``` \ No newline at end of file +``` diff --git a/docs/projects/run-build-deploy.md b/docs/projects/run-build-deploy.md index 52bf1b88b311..9585646c59c4 100644 --- a/docs/projects/run-build-deploy.md +++ b/docs/projects/run-build-deploy.md @@ -7,6 +7,8 @@ - [build_function](#build) - [deploy_function](#deploy) - [Default image](#default_image) +- [Image build configuration](#build_config) +- [build_image](#build_image) ## Overview @@ -20,29 +22,34 @@ When used inside a pipeline, each method is automatically mapped to the relevant You can use those methods as `project` methods, or as global (`mlrun.`) methods. For example: - # run the "train" function in myproject - run = myproject.run_function("train", inputs={"data": data_url}) - - # run the "train" function in the current/active project (or in a pipeline) - run = mlrun.run_function("train", inputs={"data": data_url}) +```python +# run the "train" function in myproject +run = myproject.run_function("train", inputs={"data": data_url}) + +# run the "train" function in the current/active project (or in a pipeline) +run = mlrun.run_function("train", inputs={"data": data_url}) +``` The first parameter in all three methods is either the function name (in the project), or a function object, used if you want to specify functions that you imported/created ad hoc, or to modify a function spec. 
For example: - # import a serving function from the Function Hub and deploy a trained model over it - serving = import_function("hub://v2_model_server", new_name="serving") - serving.spec.replicas = 2 - deploy = deploy_function( - serving, - models=[{"key": "mymodel", "model_path": train.outputs["model"]}], - ) +```python +# import a serving function from the Function Hub and deploy a trained model over it +serving = import_function("hub://v2_model_server", new_name="serving") +serving.spec.replicas = 2 +deploy = deploy_function( + serving, + models=[{"key": "mymodel", "model_path": train.outputs["model"]}], +) +``` You can use the {py:meth}`~mlrun.projects.MlrunProject.get_function` method to get the function object and manipulate it, for example: - trainer = project.get_function("train") - trainer.with_limits(mem="2G", cpu=2, gpus=1) - run = project.run_function("train", inputs={"data": data_url}) - +```python +trainer = project.get_function("train") +trainer.with_limits(mem="2G", cpu=2, gpus=1) +run = project.run_function("train", inputs={"data": data_url}) +``` ## run_function @@ -65,16 +72,17 @@ Read further details on [**running tasks and getting their results**](../concept Usage examples: - # create a project with two functions (local and from Function Hub) - project = mlrun.new_project(project_name, "./proj") - project.set_function("mycode.py", "prep", image="mlrun/mlrun") - project.set_function("hub://auto_trainer", "train") - - # run functions (refer to them by name) - run1 = project.run_function("prep", params={"x": 7}, inputs={'data': data_url}) - run2 = project.run_function("train", inputs={"dataset": run1.outputs["data"]}) - run2.artifact('confusion-matrix').show() +```python +# create a project with two functions (local and from Function Hub) +project = mlrun.new_project(project_name, "./proj") +project.set_function("mycode.py", "prep", image="mlrun/mlrun") +project.set_function("hub://auto_trainer", "train") +# run functions (refer to them by name) 
+run1 = project.run_function("prep", params={"x": 7}, inputs={'data': data_url}) +run2 = project.run_function("train", inputs={"dataset": run1.outputs["data"]}) +run2.artifact('confusion-matrix').show() +``` ```{admonition} Run/simulate functions locally: Functions can also run and be debugged locally by using the `local` runtime or by setting the `local=True` @@ -88,8 +96,10 @@ The {py:meth}`~mlrun.projects.build_function` method is used to deploy an ML fun Example: - # build the "trainer" function image (based on the specified requirements and code repo) - project.build_function("trainer") +```python +# build the "trainer" function image (based on the specified requirements and code repo) +project.build_function("trainer") +``` The {py:meth}`~mlrun.projects.build_function` method accepts different parameters that can add to, or override, the function build spec. You can specify the target or base `image` extra docker `commands`, builder environment, and source credentials (`builder_env`), etc. @@ -105,29 +115,33 @@ Read more about [**Real-time serving pipelines**](../serving/serving-graph.html) Basic example: - # Deploy a real-time nuclio function ("myapi") - deployment = project.deploy_function("myapi") - - # invoke the deployed function (using HTTP request) - resp = deployment.function.invoke("/do") +```python +# Deploy a real-time nuclio function ("myapi") +deployment = project.deploy_function("myapi") + +# invoke the deployed function (using HTTP request) +resp = deployment.function.invoke("/do") +``` You can provide the `env` dict with: extra environment variables; `models` list to specify specific models and their attributes (in the case of serving functions); builder environment; and source credentials (`builder_env`). 
Example of using `deploy_function` inside a pipeline, after the `train` step, to generate a model: - # Deploy the trained model (from the "train" step) as a serverless serving function - serving_fn = mlrun.new_function("serving", image="mlrun/mlrun", kind="serving") - mlrun.deploy_function( - serving_fn, - models=[ - { - "key": model_name, - "model_path": train.outputs["model"], - "class_name": 'mlrun.frameworks.sklearn.SklearnModelServer', - } - ], - ) +```python +# Deploy the trained model (from the "train" step) as a serverless serving function +serving_fn = mlrun.new_function("serving", image="mlrun/mlrun", kind="serving") +mlrun.deploy_function( + serving_fn, + models=[ + { + "key": model_name, + "model_path": train.outputs["model"], + "class_name": 'mlrun.frameworks.sklearn.SklearnModelServer', + } + ], +) +``` ```{admonition} Note @@ -147,20 +161,92 @@ image that was set when the function was added to the project. For example: - project = mlrun.new_project(project_name, "./proj") - # use v1 of a pre-built image as default - project.set_default_image("myrepo/my-prebuilt-image:v1") - # set function without an image, will use the project's default image - project.set_function("mycode.py", "prep") +```python + project = mlrun.new_project(project_name, "./proj") + # use v1 of a pre-built image as default + project.set_default_image("myrepo/my-prebuilt-image:v1") + # set function without an image, will use the project's default image + project.set_function("mycode.py", "prep") + + # function will run with the "myrepo/my-prebuilt-image:v1" image + run1 = project.run_function("prep", params={"x": 7}, inputs={'data': data_url}) + + ... 
+ + # replace the default image with a newer v2 + project.set_default_image("myrepo/my-prebuilt-image:v2") + # function will now run using the v2 version of the image + run2 = project.run_function("prep", params={"x": 7}, inputs={'data': data_url}) +``` - # function will run with the "myrepo/my-prebuilt-image:v1" image - run1 = project.run_function("prep", params={"x": 7}, inputs={'data': data_url}) + +## Image build configuration - ... +Use the {py:meth}`~mlrun.projects.MlrunProject.set_default_image` function to configure a project to use an existing +image. The configuration for building this default image can be contained within the project, by using the +{py:meth}`~mlrun.projects.MlrunProject.build_config` and {py:meth}`~mlrun.projects.MlrunProject.build_image` +functions. - # replace the default image with a newer v2 - project.set_default_image("myrepo/my-prebuilt-image:v2") - # function will now run using the v2 version of the image - run2 = project.run_function("prep", params={"x": 7}, inputs={'data': data_url}) +The project build configuration is maintained in the project object. When saving, exporting and importing the project +these configurations are carried over with it. This makes it simple to transport a project between systems while +ensuring that the needed runtime images are built and are ready for execution. +When using {py:meth}`~mlrun.projects.MlrunProject.build_config`, build configurations can be passed along with the +resulting image name, and these are used to build the image. The image name is assigned following these rules, +based on the project configuration and provided parameters: +1. If provided, the name passed in the `image` parameter of {py:meth}`~mlrun.projects.MlrunProject.build_config`. +2. The project's default image name, if configured using {py:meth}`~mlrun.projects.MlrunProject.set_default_image`. +3. 
The value set in MLRun's `default_project_image_name` config parameter - by default this value is + `.mlrun-project-image-{name}` with the project name as template parameter. + +For example: + +```python + # Set image config for current project object, using base mlrun image with additional requirements. + image_name = ".my-project-image" + project.build_config( + image=image_name, + set_as_default=True, + with_mlrun=False, + base_image="mlrun/mlrun", + requirements=["vaderSentiment"], + ) + + # Export the project configuration. The yaml file will contain the build configuration + proj_file_path = "~/mlrun/my-project/project.yaml" + project.export(proj_file_path) +``` + +This project can then be imported and the default image can be built: + +```python + # Import the project as a new project with a different name + new_project = mlrun.load_project("~/mlrun/my-project", name="my-other-project") + # Build the default image for the project, based on project build config + new_project.build_image() + + # Set a new function and run it (new function uses the my-project-image image built previously) + new_project.set_function("sentiment.py", name="scores", kind="job", handler="handler") + new_project.run_function("scores") +``` + + +## build_image + +The {py:meth}`~mlrun.projects.MlrunProject.build_image` function builds an image using the existing build configuration. +This method can also be used to set the build configuration and build the image based on it - in a single step. + +When using `set_as_default=False` any build config provided is still kept in the project object but the generated +image name is not set as the default image for this project. 
+ +For example: + +```python +image_name = ".temporary-image" +project.build_image(image=image_name, set_as_default=False) + +# Create a function using the temp image name +project.set_function("sentiment.py", name="scores", kind="job", handler="handler", image=image_name) +``` + \ No newline at end of file diff --git a/docs/runtimes/configuring-job-resources.md b/docs/runtimes/configuring-job-resources.md index 1eecffd27b1e..e91ed0b8950d 100644 --- a/docs/runtimes/configuring-job-resources.md +++ b/docs/runtimes/configuring-job-resources.md @@ -17,11 +17,20 @@ Configuration of job resources is relevant for all supported cloud platforms. Some runtimes can scale horizontally, configured either as a number of replicas:
`spec.replicas`
-or a range (for auto scaling in Dask or Nuclio:
+or a range (for auto scaling in Dask or Nuclio):
``` spec.min_replicas = 1 spec.max_replicas = 4 ``` + +```{admonition} Note +Scaling (replication) algorithm: if a `target utilization` +(Target CPU%) value is set, the replication controller calculates the utilization +value as a percentage of the equivalent `resource request` (CPU request) on +the replicas and, based on that, provides horizontal scaling. +See also [Kubernetes horizontal autoscale](https://kubernetes.io/docs/tasks/run-application/horizontal-pod-autoscale/#how-does-a-horizontalpodautoscaler-work). +``` + See more details in [Dask](../runtimes/dask-overview.html), [MPIJob and Horovod](../runtimes/horovod.html), [Spark](../runtimes/spark-operator.html), [Nuclio](../concepts/nuclio-real-time-functions.html). ## CPU, GPU, and memory limits for user jobs @@ -205,7 +214,7 @@ Pods (services, or jobs created by those services) can have priorities, which in scheduling: a lower priority pod can be evicted to allow scheduling of a higher priority pod. Pod priority is relevant for all pods created by the service. For MLRun, it applies to the jobs created by MLRun. For Nuclio it applies to the pods of the Nuclio-created functions. -Eviction uses these values in conjuction with pod priority to determine what to evict [Pod Priority and Preemption](https://kubernetes.io/docs/concepts/configuration/pod-priority-preemption). +Eviction uses these values in conjunction with pod priority to determine what to evict [Pod Priority and Preemption](https://kubernetes.io/docs/concepts/configuration/pod-priority-preemption). Pod priority is specified through Priority classes, which map to a priority value. The priority values are: High, Medium, Low. The default is Medium. Pod priority is supported for: - MLRun jobs: the default priority class for the jobs that MLRun creates. 
diff --git a/docs/runtimes/create-and-use-functions.ipynb b/docs/runtimes/create-and-use-functions.ipynb index 87f1d1567ae2..800becb932b1 100644 --- a/docs/runtimes/create-and-use-functions.ipynb +++ b/docs/runtimes/create-and-use-functions.ipynb @@ -46,7 +46,7 @@ "\n", "![MLRun Function](../_static/images/mlrun_function_diagram.png)\n", "\n", - "You can read more about MLRun Functions [**here**](./functions.html). Each parameter and capability is explained in more detail in the following sections [**Creating functions**](#creating-functions) and [**Customizing functions**](#customizing-functions)." + "You can read more about MLRun Functions [**here**](./functions.html). Each parameter and capability are explained in more detail in the following sections [**Creating functions**](#creating-functions) and [**Customizing functions**](#customizing-functions)." ] }, { @@ -181,7 +181,7 @@ "A good place to start is one of the default MLRun images:\n", "- `mlrun/mlrun`: Suits most lightweight components (includes `sklearn`, `pandas`, `numpy` and more)\n", "- `mlrun/ml-models`: Suits most CPU ML/DL workloads (includes `Tensorflow`, `Keras`, `PyTorch` and more)\n", - "- `mlrun/ml-models-gpu`: Suits most GPU ML/DL workloads (includes GPU `Tensorflow`, `Keras`, `PyTorch` and more )\n", + "- `mlrun/ml-models-gpu`: Suits most GPU ML/DL workloads (includes GPU `Tensorflow`, `Keras`, `PyTorch` and more)\n", "\n", "Dockerfiles for the MLRun images can be found [**here**](https://github.com/mlrun/mlrun/tree/development/dockerfiles)." 
] @@ -381,7 +381,7 @@ "id": "8a65c196", "metadata": {}, "source": [ - "Functions can also be imported from the [**MLRun Function Hub**](https://www.mlrun.org/marketplace): simply import using the name of the function and the `hub://` prefix:\n", + "Functions can also be imported from the [**MLRun Function Hub**](https://www.mlrun.org/hub): simply import using the name of the function and the `hub://` prefix:\n", "``` {admonition} Note\n", "By default, the `hub://` prefix points to the official Function Hub. You can, however, also substitute your own repo to create your own hub.\n", "```" @@ -506,7 +506,7 @@ "# Nuclio/serving scaling\n", "fn.spec.replicas = 2\n", "fn.spec.min_replicas = 1\n", - "fn.spec.min_replicas = 4\n", + "fn.spec.max_replicas = 4\n", "```\n", "\n", "### Mount persistent storage\n", @@ -532,7 +532,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -546,7 +546,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.7" + "version": "3.10.6 (main, Nov 14 2022, 16:10:14) [GCC 11.3.0]" + }, + "vscode": { + "interpreter": { + "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" + } } }, "nbformat": 4, diff --git a/docs/runtimes/dask-mlrun.ipynb b/docs/runtimes/dask-mlrun.ipynb index 5fd15e38318e..65951cd4a122 100644 --- a/docs/runtimes/dask-mlrun.ipynb +++ b/docs/runtimes/dask-mlrun.ipynb @@ -42,6 +42,7 @@ "source": [ "# set mlrun api path and artifact path for logging\n", "import mlrun\n", + "\n", "project = mlrun.get_or_create_project(\"dask-demo\", \"./\")" ] }, @@ -89,7 +90,9 @@ "source": [ "# create an mlrun function that will init the dask cluster\n", "dask_cluster_name = \"dask-cluster\"\n", - "dask_cluster = mlrun.new_function(dask_cluster_name, kind='dask', image='mlrun/ml-models')\n", + "dask_cluster = mlrun.new_function(\n", + " dask_cluster_name, kind=\"dask\", 
image=\"mlrun/ml-models\"\n", + ")\n", "dask_cluster.apply(mlrun.mount_v3io())" ] }, @@ -103,12 +106,12 @@ "dask_cluster.spec.min_replicas = 1\n", "dask_cluster.spec.max_replicas = 4\n", "\n", - "# set the use of dask remote cluster (distributed) \n", + "# set the use of dask remote cluster (distributed)\n", "dask_cluster.spec.remote = True\n", "dask_cluster.spec.service_type = \"NodePort\"\n", "\n", "# set dask memory and cpu limits\n", - "dask_cluster.with_worker_requests(mem='2G', cpu='2')" + "dask_cluster.with_worker_requests(mem=\"2G\", cpu=\"2\")" ] }, { @@ -285,7 +288,7 @@ "metadata": {}, "outputs": [], "source": [ - "import mlrun " + "import mlrun" ] }, { @@ -345,29 +348,26 @@ "metadata": {}, "outputs": [], "source": [ - "def test_dask(context,\n", - " dataset: mlrun.DataItem,\n", - " client=None,\n", - " dask_function: str=None) -> None:\n", - " \n", + "def test_dask(\n", + " context, dataset: mlrun.DataItem, client=None, dask_function: str = None\n", + ") -> None:\n", + "\n", " # setup dask client from the MLRun dask cluster function\n", " if dask_function:\n", " client = mlrun.import_function(dask_function).client\n", " elif not client:\n", " client = Client()\n", - " \n", + "\n", " # load the dataitem as dask dataframe (dd)\n", " df = dataset.as_df(df_module=dd)\n", - " \n", + "\n", " # run describe (get statistics for the dataframe) with dask\n", " df_describe = df.describe().compute()\n", - " \n", - " # run groupby and count using dask \n", + "\n", + " # run groupby and count using dask\n", " df_grpby = df.groupby(\"VendorID\").count().compute()\n", - " \n", - " context.log_dataset(\"describe\", \n", - " df=df_grpby,\n", - " format='csv', index=True)\n", + "\n", + " context.log_dataset(\"describe\", df=df_grpby, format=\"csv\", index=True)\n", " return" ] }, @@ -400,7 +400,7 @@ "metadata": {}, "outputs": [], "source": [ - "DATA_URL=\"/User/examples/ytrip.csv\"" + "DATA_URL = \"/User/examples/ytrip.csv\"" ] }, { @@ -444,9 +444,11 @@ "metadata": {}, 
"outputs": [], "source": [ - "# mlrun transforms the code above (up to nuclio: end-code cell) into serverless function \n", + "# mlrun transforms the code above (up to nuclio: end-code cell) into serverless function\n", "# which runs in k8s pods\n", - "fn = mlrun.code_to_function(\"test_dask\", kind='job', handler=\"test_dask\").apply(mlrun.mount_v3io())" + "fn = mlrun.code_to_function(\"test_dask\", kind=\"job\", handler=\"test_dask\").apply(\n", + " mlrun.mount_v3io()\n", + ")" ] }, { @@ -470,7 +472,7 @@ "outputs": [], "source": [ "# function URI is db:///\n", - "dask_uri = f'db://{project.name}/{dask_cluster_name}'" + "dask_uri = f\"db://{project.name}/{dask_cluster_name}\"" ] }, { @@ -723,9 +725,9 @@ } ], "source": [ - "r = fn.run(handler = test_dask,\n", - " inputs={\"dataset\": DATA_URL},\n", - " params={\"dask_function\": dask_uri})" + "r = fn.run(\n", + " handler=test_dask, inputs={\"dataset\": DATA_URL}, params={\"dask_function\": dask_uri}\n", + ")" ] }, { diff --git a/docs/runtimes/dask-overview.ipynb b/docs/runtimes/dask-overview.ipynb index 95b5b635ed56..913095625833 100644 --- a/docs/runtimes/dask-overview.ipynb +++ b/docs/runtimes/dask-overview.ipynb @@ -148,8 +148,9 @@ "from collections import Counter\n", "from dask.distributed import Client\n", "\n", - "import warnings \n", - "warnings.filterwarnings('ignore')" + "import warnings\n", + "\n", + "warnings.filterwarnings(\"ignore\")" ] }, { @@ -165,9 +166,9 @@ " :param size: the size in bytes\n", " :return: void\n", " \"\"\"\n", - " chars = ''.join([random.choice(string.ascii_letters) for i in range(size)]) #1\n", + " chars = \"\".join([random.choice(string.ascii_letters) for i in range(size)]) # 1\n", "\n", - " with open(filename, 'w') as f:\n", + " with open(filename, \"w\") as f:\n", " f.write(chars)\n", " pass" ] @@ -178,12 +179,11 @@ "metadata": {}, "outputs": [], "source": [ - "PATH = '/User/howto/dask/random_files'\n", + "PATH = \"/User/howto/dask/random_files\"\n", "SIZE = 10000000\n", "\n", 
"for i in range(100):\n", - " generate_big_random_letters(filename = PATH + '/file_' + str(i) + '.txt', \n", - " size = SIZE)" + " generate_big_random_letters(filename=PATH + \"/file_\" + str(i) + \".txt\", size=SIZE)" ] }, { @@ -212,10 +212,10 @@ "\n", " # sort file\n", " sorted_file = sorted(data)\n", - " \n", + "\n", " # count file\n", " number_of_characters = len(sorted_file)\n", - " \n", + "\n", " return number_of_characters" ] }, @@ -232,12 +232,12 @@ " \"\"\"\n", " num_list = []\n", " files = os.listdir(path)\n", - " \n", + "\n", " for file in files:\n", " cnt = count_letters(os.path.join(path, file))\n", " num_list.append(cnt)\n", - " \n", - " l = num_list \n", + "\n", + " l = num_list\n", " return print(\"done!\")" ] }, @@ -265,7 +265,7 @@ ], "source": [ "%%time\n", - "PATH = '/User/howto/dask/random_files/'\n", + "PATH = \"/User/howto/dask/random_files/\"\n", "process_files(PATH)" ] }, @@ -282,7 +282,7 @@ "metadata": {}, "outputs": [], "source": [ - "# get the dask client address \n", + "# get the dask client address\n", "client = Client()" ] }, @@ -332,7 +332,7 @@ ], "source": [ "%%time\n", - "# gather results \n", + "# gather results\n", "l = client.gather(a)" ] }, diff --git a/docs/runtimes/dask-pipeline.ipynb b/docs/runtimes/dask-pipeline.ipynb index 6ace6a67bd4a..8654ada46a83 100644 --- a/docs/runtimes/dask-pipeline.ipynb +++ b/docs/runtimes/dask-pipeline.ipynb @@ -47,11 +47,12 @@ "import os\n", "import mlrun\n", "import warnings\n", + "\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "# set project name and dir\n", - "project_name = 'sk-project-dask'\n", - "project_dir = './'\n", + "project_name = \"sk-project-dask\"\n", + "project_dir = \"./\"\n", "\n", "# specify artifacts target location\n", "_, artifact_path = mlrun.set_environment(artifact_path=path)\n", @@ -82,13 +83,14 @@ "outputs": [], "source": [ "import mlrun\n", + "\n", "# set up function from local file\n", "dsf = mlrun.new_function(name=\"mydask\", kind=\"dask\", 
image=\"mlrun/ml-models\")\n", "\n", "# set up function specs for dask\n", "dsf.spec.remote = True\n", "dsf.spec.replicas = 5\n", - "dsf.spec.service_type = 'NodePort'\n", + "dsf.spec.service_type = \"NodePort\"\n", "dsf.with_limits(mem=\"6G\")\n", "dsf.spec.nthreads = 5" ] @@ -417,7 +419,7 @@ "outputs": [], "source": [ "# register the workflow file as \"main\", embed the workflow code into the project YAML\n", - "sk_dask_proj.set_workflow('main', 'workflow.py', embed=False)" + "sk_dask_proj.set_workflow(\"main\", \"workflow.py\", embed=False)" ] }, { @@ -581,13 +583,9 @@ } ], "source": [ - "artifact_path = os.path.abspath('./pipe/{{workflow.uid}}')\n", + "artifact_path = os.path.abspath(\"./pipe/{{workflow.uid}}\")\n", "run_id = sk_dask_proj.run(\n", - " 'main',\n", - " arguments={}, \n", - " artifact_path=artifact_path, \n", - " dirty=False, \n", - " watch=True\n", + " \"main\", arguments={}, artifact_path=artifact_path, dirty=False, watch=True\n", ")" ] }, diff --git a/docs/runtimes/functions-architecture.md b/docs/runtimes/functions-architecture.md index dc71ece53a01..2af97ba3a594 100644 --- a/docs/runtimes/functions-architecture.md +++ b/docs/runtimes/functions-architecture.md @@ -7,7 +7,7 @@ MLRun supports: - Iterative tasks for automatic and distributed execution of many tasks with variable parameters (hyperparams). See [Hyperparam and iterative jobs](../hyper-params.html). - Horizontal scaling of functions across multiple containers. See [Distributed and Parallel Jobs](./distributed.html). -MLRun has an open [public Function Hub](https://www.mlrun.org/marketplace/functions/) that stores many pre-developed functions for +MLRun has an open [public Function Hub](https://www.mlrun.org/hub/functions/) that stores many pre-developed functions for use in your projects. mlrun-architecture
diff --git a/docs/runtimes/job-function.md b/docs/runtimes/job-function.md new file mode 100644 index 000000000000..da94f71c3c37 --- /dev/null +++ b/docs/runtimes/job-function.md @@ -0,0 +1,38 @@ +(job-function)= +# Function of type `job` + +You can deploy a model using a `job` type function, which runs the code in a Kubernetes Pod. + +You can create (register) a `job` function with basic attributes such as code, requirements, image, etc. using the +{py:meth}`~mlrun.projects.MlrunProject.set_function` method. +You can also import an existing job function/template from the {ref}`function-hub`. + +Functions can be created from a single code, notebook file, or have access to the entire project context directory. +(By adding the `with_repo=True` flag, the project context is cloned into the function runtime environment.) + +Examples: + + +```python +# register a (single) python file as a function +project.set_function('src/data_prep.py', name='data-prep', image='mlrun/mlrun', handler='prep', kind="job") + +# register a notebook file as a function, specify custom image and extra requirements +project.set_function('src/mynb.ipynb', name='test-function', image="my-org/my-image", + handler="run_test", requirements=["scikit-learn"], kind="job") + +# register a module.handler as a function (requires defining the default sources/work dir, if it's not root) +project.spec.workdir = "src" +project.set_function(name="train", handler="training.train", image="mlrun/mlrun", kind="job", with_repo=True) +``` + +To run the job: +``` +project.run_function("train") +``` + +**See also** +- [Create and register functions](../runtimes/create-and-use-functions.html) +- [How to annotate notebooks (to be used as functions)](../runtimes/mlrun_code_annotations.html) +- [How to run, build, or deploy functions](./run-build-deploy.html) +- [Using functions in workflows](./build-run-workflows-pipelines.html) \ No newline at end of file diff --git a/docs/runtimes/load-from-hub.md 
b/docs/runtimes/load-from-hub.md index 4499efdb2ad6..8e829f4104cf 100644 --- a/docs/runtimes/load-from-hub.md +++ b/docs/runtimes/load-from-hub.md @@ -22,7 +22,7 @@ Functions can be easily imported into your project and therefore help you to spe ## Searching for functions -The Function Hub is located [here](https://www.mlrun.org/marketplace/).
+The Function Hub is located [here](https://www.mlrun.org/hub/).
You can search and filter the categories and kinds to find a function that meets your needs. ![Hub](../_static/images/marketplace-ui.png) @@ -59,7 +59,7 @@ print(f'Artifacts path: {artifact_path}\nMLRun DB path: {mlconf.dbpath}') ## Loading functions from the Hub -Run `project.set_function` to load a functions.
+Run `project.set_function` to load a function.
`set_function` updates or adds a function object to the project. `set_function(func, name='', kind='', image=None, with_repo=None)` @@ -74,17 +74,17 @@ Parameters: Returns: project object -For more information see the {py:meth}`~mlrun.projects.MlrunProject.set_function`API documentation. +For more information see the {py:meth}`~mlrun.projects.MlrunProject.set_function` API documentation. ### Load function example -This example loads the describe function. This function analyzes a csv or parquet file for data analysis. +This example loads the `describe` function. This function analyzes a csv or parquet file for data analysis. ```python project.set_function('hub://describe', 'describe') ``` -Create a function object called my_describe: +Create a function object called `my_describe`: ```python my_describe = project.func('describe') @@ -115,7 +115,7 @@ my_describe.doc() ## Running the function -Use the `run` method to to run the function. +Use the `run` method to run the function. When working with functions pay attention to the following: @@ -143,7 +143,7 @@ my_describe.run(name='describe', ### Viewing the jobs & the artifacts -There are few options to view the outputs of the jobs we ran: +There are few options to view the outputs of the jobs you ran: - In Jupyter the result of the job is displayed in the Jupyter notebook. When you click on the artifacts it displays its content in Jupyter. - In the MLRun UI, under the project name, you can view the job that was running as well as the artifacts it generated. 
diff --git a/docs/runtimes/mlrun_code_annotations.ipynb b/docs/runtimes/mlrun_code_annotations.ipynb index 528ace5254bc..6fc751c80e58 100644 --- a/docs/runtimes/mlrun_code_annotations.ipynb +++ b/docs/runtimes/mlrun_code_annotations.ipynb @@ -28,6 +28,7 @@ "source": [ "# mlrun: start-code\n", "\n", + "\n", "def sub_handler():\n", " return \"hello world\"" ] @@ -61,6 +62,7 @@ "def handler(context, event):\n", " return sub_handler()\n", "\n", + "\n", "# mlrun: end-code" ] }, @@ -333,8 +335,8 @@ "source": [ "from mlrun import code_to_function\n", "\n", - "some_function = code_to_function('some-function-name', kind='job', code_output='.')\n", - "some_function.run(name='some-function-name', handler='handler', local=True)" + "some_function = code_to_function(\"some-function-name\", kind=\"job\", code_output=\".\")\n", + "some_function.run(name=\"some-function-name\", handler=\"handler\", local=True)" ] }, { @@ -365,9 +367,11 @@ "source": [ "# mlrun: start-code my-function-name\n", "\n", + "\n", "def handler(context, event):\n", " return \"hello from my-function\"\n", "\n", + "\n", "# mlrun: end-code my-function-name" ] }, @@ -629,8 +633,8 @@ } ], "source": [ - "my_function = code_to_function('my-function-name', kind='job')\n", - "my_function.run(name='my-function-name', handler='handler', local=True)" + "my_function = code_to_function(\"my-function-name\", kind=\"job\")\n", + "my_function.run(name=\"my-function-name\", handler=\"handler\", local=True)" ] }, { @@ -964,8 +968,10 @@ } ], "source": [ - "my_multi_section_function = code_to_function('multi-section-function-name', kind='job')\n", - "my_multi_section_function.run(name='multi-section-function-name', handler='handler', local=True)" + "my_multi_section_function = code_to_function(\"multi-section-function-name\", kind=\"job\")\n", + "my_multi_section_function.run(\n", + " name=\"multi-section-function-name\", handler=\"handler\", local=True\n", + ")" ] }, { @@ -985,9 +991,11 @@ "source": [ "# mlrun: start-code 
part-cell-function\n", "\n", + "\n", "def handler(context, event):\n", " return f\"hello from {function_name}\"\n", "\n", + "\n", "function_name = \"part-cell-function\"\n", "\n", "# mlrun: end-code part-cell-function\n", @@ -1246,8 +1254,8 @@ } ], "source": [ - "my_multi_section_function = code_to_function('part-cell-function', kind='job')\n", - "my_multi_section_function.run(name='part-cell-function', handler='handler', local=True)" + "my_multi_section_function = code_to_function(\"part-cell-function\", kind=\"job\")\n", + "my_multi_section_function.run(name=\"part-cell-function\", handler=\"handler\", local=True)" ] }, { diff --git a/docs/runtimes/serving-function.md b/docs/runtimes/serving-function.md new file mode 100644 index 000000000000..e159354b96fb --- /dev/null +++ b/docs/runtimes/serving-function.md @@ -0,0 +1,21 @@ +(serving-function)= +# Function of type `serving` + +Deploying models in MLRun uses the function type `serving`. You can create a serving function using the `set_function()` call from a notebook. +You can also import an existing serving function/template from the {ref}`function-hub`. 
+ +This example converts a notebook to a serving function, adds a model to it, and deploys it: + +```python +serving = project.set_function(name="my-serving", func="my_serving.ipynb", kind="serving", image="mlrun/mlrun", handler="handler") +serving.add_model(key="iris", model_path="https://s3.wasabisys.com/iguazio/models/iris/model.pkl", model_class="ClassifierModel") +project.deploy_function(serving) +``` + + +**See also** +- {ref}`Real-time serving pipelines (graphs) `: higher level real-time graphs (DAG) over one or more Nuclio functions +- {ref}`Serving graphs demos and tutorials ` +- {ref}`Real-time serving ` +- {ref}`Serving pre-trained ML/DL models ` + diff --git a/docs/runtimes/spark-operator.ipynb b/docs/runtimes/spark-operator.ipynb index 4bf99d0c4eb3..440fa6b9f9c6 100644 --- a/docs/runtimes/spark-operator.ipynb +++ b/docs/runtimes/spark-operator.ipynb @@ -36,22 +36,22 @@ "# set up new spark function with spark operator\n", "# command will use our spark code which needs to be located on our file system\n", "# the name param can have only non capital letters (k8s convention)\n", - "read_csv_filepath = os.path.join(os.path.abspath('.'), 'spark_read_csv.py')\n", - "sj = mlrun.new_function(kind='spark', command=read_csv_filepath, name='sparkreadcsv') \n", + "read_csv_filepath = os.path.join(os.path.abspath(\".\"), \"spark_read_csv.py\")\n", + "sj = mlrun.new_function(kind=\"spark\", command=read_csv_filepath, name=\"sparkreadcsv\")\n", "\n", "# set spark driver config (gpu_type & gpus= supported too)\n", "sj.with_driver_limits(cpu=\"1300m\")\n", - "sj.with_driver_requests(cpu=1, mem=\"512m\") \n", + "sj.with_driver_requests(cpu=1, mem=\"512m\")\n", "\n", "# set spark executor config (gpu_type & gpus= are supported too)\n", "sj.with_executor_limits(cpu=\"1400m\")\n", "sj.with_executor_requests(cpu=1, mem=\"512m\")\n", "\n", "# adds fuse, daemon & iguazio's jars support\n", - "sj.with_igz_spark() \n", + "sj.with_igz_spark()\n", "\n", - "# Alternately, move 
volume_mounts to driver and executor-specific fields and leave \n", - "# v3io mounts out of executor mounts if mount_v3io_to_executor=False \n", + "# Alternately, move volume_mounts to driver and executor-specific fields and leave\n", + "# v3io mounts out of executor mounts if mount_v3io_to_executor=False\n", "# sj.with_igz_spark(mount_v3io_to_executor=False)\n", "\n", "# set spark driver volume mount\n", @@ -61,13 +61,13 @@ "# sj.function.with_executor_host_path_volume(\"/host/path\", \"/mount/path\")\n", "\n", "# confs are also supported\n", - "sj.spec.spark_conf['spark.eventLog.enabled'] = True\n", + "sj.spec.spark_conf[\"spark.eventLog.enabled\"] = True\n", "\n", "# add python module\n", - "sj.spec.build.commands = ['pip install matplotlib']\n", + "sj.with_requirements([\"matplotlib\"])\n", "\n", "# Number of executors\n", - "sj.spec.replicas = 2 " + "sj.spec.replicas = 2" ] }, { @@ -77,7 +77,7 @@ "outputs": [], "source": [ "# Rebuilds the image with MLRun - needed in order to support artifactlogging etc\n", - "sj.deploy() " + "sj.deploy()" ] }, { @@ -87,7 +87,7 @@ "outputs": [], "source": [ "# Run task while setting the artifact path on which our run artifact (in any) will be saved\n", - "sj.run(artifact_path='/User')" + "sj.run(artifact_path=\"/User\")" ] }, { diff --git a/docs/secrets.md b/docs/secrets.md index 715218c79689..080519dac9b6 100644 --- a/docs/secrets.md +++ b/docs/secrets.md @@ -10,11 +10,11 @@ and how much exposure they create for your secrets. 
**In this section** - [Overview](#overview) - [MLRun-managed secrets](#mlrun-managed-secrets) - - [Using tasks with secrets](#using-tasks-with-secrets) + - [Using tasks with secrets](#using-tasks-with-secrets) - [Secret providers](#secret-providers) - - [Kubernetes project secrets](#kubernetes-project-secrets) - - [Azure Vault](#azure-vault) - - [Demo/Development secret providers](#demo-development-secret-providers) + - [Kubernetes project secrets](#kubernetes-project-secrets) + - [Azure Vault](#azure-vault) + - [Demo/Development secret providers](#demo-development-secret-providers) - [Externally managed secrets](#externally-managed-secrets) - [Mapping secrets to environment](#mapping-secrets-to-environment) - [Mapping secrets as files](#mapping-secrets-as-files) @@ -317,9 +317,11 @@ MLRun provides facilities to map k8s secrets that were created externally to job the spec of the runtime that is created should be modified by mounting secrets to it - either as files or as environment variables containing specific keys from the secret. +In the following examples, assume a k8s secret called `my-secret` was created in the same k8s namespace where MLRun is running, with two +keys in it - `secret1` and `secret2`. + ### Mapping secrets to environment -Let's assume a k8s secret called `my-secret` was created in the same k8s namespace where MLRun is running, with two -keys in it - `secret1` and `secret2`. The following example adds these two secret keys as environment variables + The following example adds these two secret keys as environment variables to an MLRun job: ```{code-block} python @@ -359,9 +361,9 @@ function: ```python # Mount all keys in the secret as files under /mnt/secrets -function.mount_secret("my-secret", "/mnt/secrets/") +function.apply(mlrun.platforms.mount_secret("my-secret", "/mnt/secrets/")) ``` -This creates two files in the function pod, called `/mnt/secrets/secret1` and `/mnt/secrets/secret2`. 
Reading these +In our example, the two keys in `my-secret` are created as two files in the function pod, called `/mnt/secrets/secret1` and `/mnt/secrets/secret2`. Reading these files provide the values. It is possible to limit the keys mounted to the function - see the documentation of {py:func}`~mlrun.platforms.mount_secret` for more details. diff --git a/docs/serving/available-steps.md b/docs/serving/available-steps.md index 511e65680c22..ea8d7a883acd 100644 --- a/docs/serving/available-steps.md +++ b/docs/serving/available-steps.md @@ -57,7 +57,7 @@ The following table lists the available data-transformation steps. The next tabl | [mlrun.datastore.DataItem](../api/mlrun.datastore.html#mlrun.datastore.DataItem) | Data input/output class abstracting access to various local/remote data sources. | | [storey.transformations.JoinWithTable](https://storey.readthedocs.io/en/latest/api.html#storey.transformations.JoinWithTable) | Joins each event with data from the given table. | | JoinWithV3IOTable | Joins each event with a V3IO table. Used for event augmentation. | -| [QueryByKey](https://storey.readthedocs.io/en/latest/api.html#storey.aggregations.QueryByKey) | Similar to to AggregateByKey, but this step is for serving only and does not aggregate the event. | +| [QueryByKey](https://storey.readthedocs.io/en/latest/api.html#storey.aggregations.QueryByKey) | Similar to AggregateByKey, but this step is for serving only and does not aggregate the event. | | [RemoteStep](../api/mlrun.serving.html#mlrun.serving.remote.RemoteStep) | Class for calling remote endpoints. | | [storey.transformations.SendToHttp](https://storey.readthedocs.io/en/latest/api.html#storey.transformations.SendToHttp) | Joins each event with data from any HTTP source. Used for event augmentation. | @@ -84,7 +84,6 @@ The following table lists the available data-transformation steps. The next tabl | mlrun.datastore.SqlTarget | Persists the data in SQL table to its associated storage by key. 
| Y | N | Y | | [mlrun.datastore.ParquetTarget](https://storey.readthedocs.io/en/latest/api.html#storey.targets.ParquetTarget) | The Parquet target storage driver, used to materialize feature set/vector data into parquet files. | Y | Y | Y | | [mlrun.datastore.StreamTarget](https://storey.readthedocs.io/en/latest/api.html#storey.targets.StreamTarget) | Writes all incoming events into a V3IO stream. | Y | N | N | -| [storey.transformations.ToDataFrame](https://storey.readthedocs.io/en/latest/api.html#storey.transformations.ToDataFrame) | Create pandas data frame from events. Can appear in the middle of the flow. | Y | N | N | ## Models | Class name | Description | diff --git a/docs/serving/custom-model-serving-class.md b/docs/serving/custom-model-serving-class.md index 3fe8033e3787..6147409ca1e2 100644 --- a/docs/serving/custom-model-serving-class.md +++ b/docs/serving/custom-model-serving-class.md @@ -172,6 +172,7 @@ To set the tracking stream options, specify the following function spec attribut fn.set_tracking(stream_path, batch, sample) -* **stream_path** — the v3io stream path (e.g. `v3io:///users/..`) +* **stream_path** — Enterprise: the v3io stream path (e.g. `v3io:///users/..`); CE: a valid Kafka stream +(e.g. 
kafka://kafka.default.svc.cluster.local:9092) * **sample** — optional, sample every N requests * **batch** — optional, send micro-batches every N requests diff --git a/docs/serving/distributed-graph.ipynb b/docs/serving/distributed-graph.ipynb index 5d2e09ffaf1b..09d602bafac7 100644 --- a/docs/serving/distributed-graph.ipynb +++ b/docs/serving/distributed-graph.ipynb @@ -83,6 +83,7 @@ "source": [ "# set up the environment\n", "import mlrun\n", + "\n", "project = mlrun.get_or_create_project(\"pipe\")" ] }, @@ -260,20 +261,29 @@ ], "source": [ "# define a new real-time serving function (from code) with an async graph\n", - "fn = mlrun.code_to_function(\"multi-func\", filename=\"./data_prep.py\", kind=\"serving\", image='mlrun/mlrun')\n", + "fn = mlrun.code_to_function(\n", + " \"multi-func\", filename=\"./data_prep.py\", kind=\"serving\", image=\"mlrun/mlrun\"\n", + ")\n", "graph = fn.set_topology(\"flow\", engine=\"async\")\n", "\n", "# define the graph steps (DAG)\n", - "graph.to(name=\"load_url\", handler=\"load_url\")\\\n", - " .to(name=\"to_paragraphs\", handler=\"to_paragraphs\")\\\n", - " .to(\"storey.FlatMap\", \"flatten_paragraphs\", _fn=\"(event)\")\\\n", - " .to(\">>\", \"q1\", path=internal_stream)\\\n", - " .to(name=\"nlp\", class_name=\"ApplyNLP\", function=\"enrich\")\\\n", - " .to(name=\"extract_entities\", handler=\"extract_entities\", function=\"enrich\")\\\n", - " .to(name=\"enrich_entities\", handler=\"enrich_entities\", function=\"enrich\")\\\n", - " .to(\"storey.FlatMap\", \"flatten_entities\", _fn=\"(event)\", function=\"enrich\")\\\n", - " .to(name=\"printer\", handler=\"myprint\", function=\"enrich\")\\\n", - " .to(\">>\", \"output_stream\", path=out_stream)" + "graph.to(name=\"load_url\", handler=\"load_url\").to(\n", + " name=\"to_paragraphs\", handler=\"to_paragraphs\"\n", + ").to(\"storey.FlatMap\", \"flatten_paragraphs\", _fn=\"(event)\").to(\n", + " \">>\", \"q1\", path=internal_stream\n", + ").to(\n", + " name=\"nlp\", 
class_name=\"ApplyNLP\", function=\"enrich\"\n", + ").to(\n", + " name=\"extract_entities\", handler=\"extract_entities\", function=\"enrich\"\n", + ").to(\n", + " name=\"enrich_entities\", handler=\"enrich_entities\", function=\"enrich\"\n", + ").to(\n", + " \"storey.FlatMap\", \"flatten_entities\", _fn=\"(event)\", function=\"enrich\"\n", + ").to(\n", + " name=\"printer\", handler=\"myprint\", function=\"enrich\"\n", + ").to(\n", + " \">>\", \"output_stream\", path=out_stream\n", + ")" ] }, { @@ -435,10 +445,12 @@ ], "source": [ "# specify the \"enrich\" child function, add extra package requirements\n", - "child = fn.add_child_function('enrich', './nlp.py', 'mlrun/mlrun')\n", - "child.spec.build.commands = [\"python -m pip install spacy\",\n", - " \"python -m spacy download en_core_web_sm\"]\n", - "graph.plot(rankdir='LR')" + "child = fn.add_child_function(\"enrich\", \"./nlp.py\", \"mlrun/mlrun\")\n", + "child.spec.build.commands = [\n", + " \"python -m pip install spacy\",\n", + " \"python -m spacy download en_core_web_sm\",\n", + "]\n", + "graph.plot(rankdir=\"LR\")" ] }, { @@ -650,7 +662,7 @@ } ], "source": [ - "fn.invoke('', body={\"url\": \"v3io:///users/admin/pipe/in.json\"})" + "fn.invoke(\"\", body={\"url\": \"v3io:///users/admin/pipe/in.json\"})" ] }, { diff --git a/docs/serving/getting-started.ipynb b/docs/serving/getting-started.ipynb index 2298bb072b43..cd08dec0f7bc 100644 --- a/docs/serving/getting-started.ipynb +++ b/docs/serving/getting-started.ipynb @@ -39,12 +39,15 @@ "source": [ "# mlrun: start-code\n", "\n", + "\n", "def inc(x):\n", " return x + 1\n", "\n", + "\n", "def mul(x):\n", " return x * 2\n", "\n", + "\n", "class WithState:\n", " def __init__(self, name, context, init_val=0):\n", " self.name = name\n", @@ -55,7 +58,8 @@ " self.counter += 1\n", " print(f\"Echo: {self.name}, x: {x}, counter: {self.counter}\")\n", " return x + self.counter\n", - " \n", + "\n", + "\n", "# mlrun: end-code" ] }, @@ -75,6 +79,7 @@ "outputs": [], "source": [ 
"import mlrun\n", + "\n", "fn = mlrun.code_to_function(\"simple-graph\", kind=\"serving\", image=\"mlrun/mlrun\")\n", "graph = fn.set_topology(\"flow\")" ] @@ -113,9 +118,9 @@ } ], "source": [ - "graph.to(name=\"+1\", handler='inc')\\\n", - " .to(name=\"*2\", handler='mul')\\\n", - " .to(name=\"(X+counter)\", class_name='WithState').respond()" + "graph.to(name=\"+1\", handler=\"inc\").to(name=\"*2\", handler=\"mul\").to(\n", + " name=\"(X+counter)\", class_name=\"WithState\"\n", + ").respond()" ] }, { @@ -201,7 +206,7 @@ } ], "source": [ - "graph.plot(rankdir='LR')" + "graph.plot(rankdir=\"LR\")" ] }, { @@ -316,7 +321,7 @@ } ], "source": [ - "fn.deploy(project='basic-graph-demo')" + "fn.deploy(project=\"basic-graph-demo\")" ] }, { @@ -352,7 +357,7 @@ } ], "source": [ - "fn.invoke('', body=5)" + "fn.invoke(\"\", body=5)" ] }, { @@ -379,7 +384,7 @@ } ], "source": [ - "fn.invoke('', body=5)" + "fn.invoke(\"\", body=5)" ] } ], diff --git a/docs/serving/graph-example.ipynb b/docs/serving/graph-example.ipynb index 2ad39ac2f0c5..c152b35697ad 100644 --- a/docs/serving/graph-example.ipynb +++ b/docs/serving/graph-example.ipynb @@ -43,31 +43,33 @@ "class ClassifierModel(mlrun.serving.V2ModelServer):\n", " def load(self):\n", " \"\"\"load and initialize the model and/or other elements\"\"\"\n", - " model_file, extra_data = self.get_model('.pkl')\n", - " self.model = load(open(model_file, 'rb'))\n", + " model_file, extra_data = self.get_model(\".pkl\")\n", + " self.model = load(open(model_file, \"rb\"))\n", "\n", " def predict(self, body: dict) -> List:\n", " \"\"\"Generate model predictions from sample.\"\"\"\n", - " feats = np.asarray(body['inputs'])\n", + " feats = np.asarray(body[\"inputs\"])\n", " result: np.ndarray = self.model.predict(feats)\n", " return result.tolist()\n", "\n", + "\n", "# echo class, custom class example\n", "class Echo:\n", " def __init__(self, context, name=None, **kw):\n", " self.context = context\n", " self.name = name\n", " self.kw = kw\n", - " 
\n", + "\n", " def do(self, x):\n", " print(\"Echo:\", self.name, x)\n", " return x\n", "\n", + "\n", "# error echo function, demo catching error and using custom function\n", "def error_catcher(x):\n", - " x.body = {\"body\": x.body, \"origin_state\": x.origin_state, \"error\": x.error}\n", - " print(\"EchoError:\", x)\n", - " return None" + " x.body = {\"body\": x.body, \"origin_state\": x.origin_state, \"error\": x.error}\n", + " print(\"EchoError:\", x)\n", + " return None" ] }, { @@ -94,11 +96,11 @@ "metadata": {}, "outputs": [], "source": [ - "function = mlrun.code_to_function(\"advanced\", kind=\"serving\", \n", - " image=\"mlrun/mlrun\",\n", - " requirements=['storey'])\n", + "function = mlrun.code_to_function(\n", + " \"advanced\", kind=\"serving\", image=\"mlrun/mlrun\", requirements=[\"storey\"]\n", + ")\n", "graph = function.set_topology(\"flow\", engine=\"async\")\n", - "#function.verbose = True" + "# function.verbose = True" ] }, { @@ -119,7 +121,7 @@ "metadata": {}, "outputs": [], "source": [ - "models_path = 'https://s3.wasabisys.com/iguazio/models/iris/model.pkl'\n", + "models_path = \"https://s3.wasabisys.com/iguazio/models/iris/model.pkl\"\n", "path1 = models_path\n", "path2 = models_path" ] @@ -266,22 +268,25 @@ ], "source": [ "# use built-in storey class or our custom Echo class to create and link Task states\n", - "graph.to(\"storey.Extend\", name=\"enrich\", _fn='({\"tag\": \"something\"})') \\\n", - " .to(class_name=\"Echo\", name=\"pre-process\", some_arg='abc').error_handler(\"catcher\")\n", + "graph.to(\"storey.Extend\", name=\"enrich\", _fn='({\"tag\": \"something\"})').to(\n", + " class_name=\"Echo\", name=\"pre-process\", some_arg=\"abc\"\n", + ").error_handler(\"catcher\")\n", "\n", "# add an Ensemble router with two child models (routes). 
The \"*\" prefix mark it is a router class\n", - "router = graph.add_step(\"*mlrun.serving.VotingEnsemble\", name=\"ensemble\", after=\"pre-process\")\n", + "router = graph.add_step(\n", + " \"*mlrun.serving.VotingEnsemble\", name=\"ensemble\", after=\"pre-process\"\n", + ")\n", "router.add_route(\"m1\", class_name=\"ClassifierModel\", model_path=path1)\n", "router.add_route(\"m2\", class_name=\"ClassifierModel\", model_path=path2)\n", "\n", "# add the final step (after the router) that handles post processing and responds to the client\n", "graph.add_step(class_name=\"Echo\", name=\"final\", after=\"ensemble\").respond()\n", "\n", - "# add error handling state, run only when/if the \"pre-process\" state fails (keep after=\"\") \n", + "# add error handling state, run only when/if the \"pre-process\" state fails (keep after=\"\")\n", "graph.add_step(handler=\"error_catcher\", name=\"catcher\", full_event=True, after=\"\")\n", "\n", "# plot the graph (using Graphviz) and run a test\n", - "graph.plot(rankdir='LR')" + "graph.plot(rankdir=\"LR\")" ] }, { @@ -299,8 +304,9 @@ "outputs": [], "source": [ "import random\n", + "\n", "iris = load_iris()\n", - "x = random.sample(iris['data'].tolist(), 5)" + "x = random.sample(iris[\"data\"].tolist(), 5)" ] }, { diff --git a/docs/serving/model-serving-get-started.ipynb b/docs/serving/model-serving-get-started.ipynb index 02013318c1d5..4093fbf888ea 100644 --- a/docs/serving/model-serving-get-started.ipynb +++ b/docs/serving/model-serving-get-started.ipynb @@ -50,6 +50,7 @@ "\n", "import mlrun\n", "\n", + "\n", "class ClassifierModel(mlrun.serving.V2ModelServer):\n", " def load(self):\n", " \"\"\"load and initialize the model and/or other elements\"\"\"\n", @@ -85,9 +86,7 @@ "metadata": {}, "outputs": [], "source": [ - "fn = mlrun.code_to_function(\"serving_example\",\n", - " kind=\"serving\", \n", - " image=\"mlrun/mlrun\")" + "fn = mlrun.code_to_function(\"serving_example\", kind=\"serving\", image=\"mlrun/mlrun\")" ] }, { @@ 
-128,14 +127,18 @@ "graph = fn.set_topology(\"router\")\n", "\n", "# Add the model\n", - "fn.add_model(\"model1\", class_name=\"ClassifierModel\", model_path=\"https://s3.wasabisys.com/iguazio/models/iris/model.pkl\")\n", + "fn.add_model(\n", + " \"model1\",\n", + " class_name=\"ClassifierModel\",\n", + " model_path=\"https://s3.wasabisys.com/iguazio/models/iris/model.pkl\",\n", + ")\n", "\n", "# Add additional models\n", - "#fn.add_model(\"model2\", class_name=\"ClassifierModel\", model_path=\"\")\n", + "# fn.add_model(\"model2\", class_name=\"ClassifierModel\", model_path=\"\")\n", "\n", "# create and use the graph simulator\n", "server = fn.to_mock_server()\n", - "x = load_iris()['data'].tolist()\n", + "x = load_iris()[\"data\"].tolist()\n", "result = server.test(\"/v2/models/model1/infer\", {\"inputs\": x})\n", "\n", "print(result)" @@ -247,24 +250,28 @@ } ], "source": [ - "fn2 = mlrun.code_to_function(\"serving_example_flow\",\n", - " kind=\"serving\", \n", - " image=\"mlrun/mlrun\")\n", + "fn2 = mlrun.code_to_function(\n", + " \"serving_example_flow\", kind=\"serving\", image=\"mlrun/mlrun\"\n", + ")\n", "\n", - "graph2 = fn2.set_topology(\"flow\") \n", + "graph2 = fn2.set_topology(\"flow\")\n", "\n", "graph2_enrich = graph2.to(\"storey.Extend\", name=\"enrich\", _fn='({\"tag\": \"something\"})')\n", "\n", "# add an Ensemble router with two child models (routes)\n", "router = graph2.add_step(mlrun.serving.ModelRouter(), name=\"router\", after=\"enrich\")\n", - "router.add_route(\"m1\", class_name=\"ClassifierModel\", model_path='https://s3.wasabisys.com/iguazio/models/iris/model.pkl')\n", + "router.add_route(\n", + " \"m1\",\n", + " class_name=\"ClassifierModel\",\n", + " model_path=\"https://s3.wasabisys.com/iguazio/models/iris/model.pkl\",\n", + ")\n", "router.respond()\n", "\n", "# Add additional models\n", - "#router.add_route(\"m2\", class_name=\"ClassifierModel\", model_path=path2)\n", + "# router.add_route(\"m2\", class_name=\"ClassifierModel\", 
model_path=path2)\n", "\n", "# plot the graph (using Graphviz)\n", - "graph2.plot(rankdir='LR')" + "graph2.plot(rankdir=\"LR\")" ] }, { @@ -336,12 +343,15 @@ "source": [ "remote_func_name = \"serving-example-flow\"\n", "project_name = \"graph-basic-concepts\"\n", - "fn_remote = mlrun.code_to_function(remote_func_name,\n", - " project=project_name,\n", - " kind=\"serving\", \n", - " image=\"mlrun/mlrun\")\n", + "fn_remote = mlrun.code_to_function(\n", + " remote_func_name, project=project_name, kind=\"serving\", image=\"mlrun/mlrun\"\n", + ")\n", "\n", - "fn_remote.add_model(\"model1\", class_name=\"ClassifierModel\", model_path=\"https://s3.wasabisys.com/iguazio/models/iris/model.pkl\")\n", + "fn_remote.add_model(\n", + " \"model1\",\n", + " class_name=\"ClassifierModel\",\n", + " model_path=\"https://s3.wasabisys.com/iguazio/models/iris/model.pkl\",\n", + ")\n", "\n", "remote_addr = fn_remote.deploy()" ] @@ -419,9 +429,10 @@ "graph_preprocessing = fn_preprocess.set_topology(\"flow\")\n", "\n", "graph_preprocessing.to(\"storey.Extend\", name=\"enrich\", _fn='({\"tag\": \"something\"})').to(\n", - " \"$remote\", \"remote_func\", url=f'{remote_addr}v2/models/model1/infer', method='put').respond()\n", + " \"$remote\", \"remote_func\", url=f\"{remote_addr}v2/models/model1/infer\", method=\"put\"\n", + ").respond()\n", "\n", - "graph_preprocessing.plot(rankdir='LR')" + "graph_preprocessing.plot(rankdir=\"LR\")" ] }, { @@ -440,7 +451,7 @@ ], "source": [ "fn3_server = fn_preprocess.to_mock_server()\n", - "my_data = '''{\"inputs\":[[5.1, 3.5, 1.4, 0.2],[7.7, 3.8, 6.7, 2.2]]}'''\n", + "my_data = \"\"\"{\"inputs\":[[5.1, 3.5, 1.4, 0.2],[7.7, 3.8, 6.7, 2.2]]}\"\"\"\n", "result = fn3_server.test(\"/v2/models/my_model/infer\", body=my_data)\n", "print(result)" ] @@ -498,7 +509,10 @@ "outputs": [], "source": [ "import os\n", - "streams_prefix = f\"v3io:///users/{os.getenv('V3IO_USERNAME')}/examples/graph-basic-concepts\"\n", + "\n", + "streams_prefix = (\n", + " 
f\"v3io:///users/{os.getenv('V3IO_USERNAME')}/examples/graph-basic-concepts\"\n", + ")\n", "\n", "input_stream = streams_prefix + \"/in-stream\"\n", "out_stream = streams_prefix + \"/out-stream\"\n", @@ -618,16 +632,17 @@ ], "source": [ "fn_preprocess2 = mlrun.new_function(\"preprocess\", kind=\"serving\")\n", - "fn_preprocess2.add_child_function('echo_func', './echo.py', 'mlrun/mlrun')\n", + "fn_preprocess2.add_child_function(\"echo_func\", \"./echo.py\", \"mlrun/mlrun\")\n", "\n", "graph_preprocess2 = fn_preprocess2.set_topology(\"flow\")\n", "\n", - "graph_preprocess2.to(\"storey.Extend\", name=\"enrich\", _fn='({\"tag\": \"something\"})')\\\n", - " .to(\">>\", \"input_stream\", path=input_stream, group=\"mygroup\")\\\n", - " .to(name=\"echo\", handler=\"echo_handler\", function=\"echo_func\")\\\n", - " .to(\">>\", \"output_stream\", path=out_stream, sharding_func=\"partition\")\n", + "graph_preprocess2.to(\"storey.Extend\", name=\"enrich\", _fn='({\"tag\": \"something\"})').to(\n", + " \">>\", \"input_stream\", path=input_stream, group=\"mygroup\"\n", + ").to(name=\"echo\", handler=\"echo_handler\", function=\"echo_func\").to(\n", + " \">>\", \"output_stream\", path=out_stream, sharding_func=\"partition\"\n", + ")\n", "\n", - "graph_preprocess2.plot(rankdir='LR')" + "graph_preprocess2.plot(rankdir=\"LR\")" ] }, { @@ -650,7 +665,7 @@ "\n", "fn4_server = fn_preprocess2.to_mock_server(current_function=\"*\")\n", "\n", - "my_data = '''{\"inputs\": [[5.1, 3.5, 1.4, 0.2], [7.7, 3.8, 6.7, 2.2]], \"partition\": 0}'''\n", + "my_data = \"\"\"{\"inputs\": [[5.1, 3.5, 1.4, 0.2], [7.7, 3.8, 6.7, 2.2]], \"partition\": 0}\"\"\"\n", "\n", "result = fn4_server.test(\"/v2/models/my_model/infer\", body=my_data)\n", "\n", @@ -724,16 +739,21 @@ "import mlrun\n", "\n", "fn_preprocess2 = mlrun.new_function(\"preprocess\", kind=\"serving\")\n", - "fn_preprocess2.add_child_function('echo_func', './echo.py', 'mlrun/mlrun')\n", + "fn_preprocess2.add_child_function(\"echo_func\", 
\"./echo.py\", \"mlrun/mlrun\")\n", "\n", "graph_preprocess2 = fn_preprocess2.set_topology(\"flow\")\n", "\n", - "graph_preprocess2.to(\"storey.Extend\", name=\"enrich\", _fn='({\"tag\": \"something\"})')\\\n", - " .to(\">>\", \"input_stream\", path=input_topic, group=\"mygroup\", kafka_bootstrap_servers=brokers)\\\n", - " .to(name=\"echo\", handler=\"echo_handler\", function=\"echo_func\")\\\n", - " .to(\">>\", \"output_stream\", path=out_topic, kafka_bootstrap_servers=brokers)\n", + "graph_preprocess2.to(\"storey.Extend\", name=\"enrich\", _fn='({\"tag\": \"something\"})').to(\n", + " \">>\",\n", + " \"input_stream\",\n", + " path=input_topic,\n", + " group=\"mygroup\",\n", + " kafka_bootstrap_servers=brokers,\n", + ").to(name=\"echo\", handler=\"echo_handler\", function=\"echo_func\").to(\n", + " \">>\", \"output_stream\", path=out_topic, kafka_bootstrap_servers=brokers\n", + ")\n", "\n", - "graph_preprocess2.plot(rankdir='LR')\n", + "graph_preprocess2.plot(rankdir=\"LR\")\n", "\n", "from echo import *\n", "\n", @@ -741,7 +761,7 @@ "\n", "fn4_server.set_error_stream(f\"kafka://{brokers}/{err_topic}\")\n", "\n", - "my_data = '''{\"inputs\":[[5.1, 3.5, 1.4, 0.2],[7.7, 3.8, 6.7, 2.2]]}'''\n", + "my_data = \"\"\"{\"inputs\":[[5.1, 3.5, 1.4, 0.2],[7.7, 3.8, 6.7, 2.2]]}\"\"\"\n", "\n", "result = fn4_server.test(\"/v2/models/my_model/infer\", body=my_data)\n", "\n", diff --git a/docs/serving/realtime-pipelines.ipynb b/docs/serving/realtime-pipelines.ipynb index c7e111e96411..bccbbac7e104 100644 --- a/docs/serving/realtime-pipelines.ipynb +++ b/docs/serving/realtime-pipelines.ipynb @@ -96,8 +96,8 @@ "outputs": [], "source": [ "if self.context.verbose:\n", - " self.context.logger.info('my message', some_arg='text')\n", - " x = self.context.get_param('x', 0)" + " self.context.logger.info(\"my message\", some_arg=\"text\")\n", + " x = self.context.get_param(\"x\", 0)" ] }, { @@ -138,14 +138,18 @@ "graph = fn.set_topology(\"router\")\n", "\n", "# Add the model\n", - 
"fn.add_model(\"model1\", class_name=\"ClassifierModel\", model_path=\"https://s3.wasabisys.com/iguazio/models/iris/model.pkl\")\n", + "fn.add_model(\n", + " \"model1\",\n", + " class_name=\"ClassifierModel\",\n", + " model_path=\"https://s3.wasabisys.com/iguazio/models/iris/model.pkl\",\n", + ")\n", "\n", "# Add additional models\n", - "#fn.add_model(\"model2\", class_name=\"ClassifierModel\", model_path=\"\")\n", + "# fn.add_model(\"model2\", class_name=\"ClassifierModel\", model_path=\"\")\n", "\n", "# create and use the graph simulator\n", "server = fn.to_mock_server()\n", - "x = load_iris()['data'].tolist()\n", + "x = load_iris()[\"data\"].tolist()\n", "result = server.test(\"/v2/models/model1/infer\", {\"inputs\": x})\n", "\n", "print(result)" @@ -196,24 +200,28 @@ } ], "source": [ - "fn2 = mlrun.code_to_function(\"serving_example_flow\",\n", - " kind=\"serving\", \n", - " image=\"mlrun/mlrun\")\n", + "fn2 = mlrun.code_to_function(\n", + " \"serving_example_flow\", kind=\"serving\", image=\"mlrun/mlrun\"\n", + ")\n", "\n", - "graph2 = fn2.set_topology(\"flow\") \n", + "graph2 = fn2.set_topology(\"flow\")\n", "\n", "graph2_enrich = graph2.to(\"storey.Extend\", name=\"enrich\", _fn='({\"tag\": \"something\"})')\n", "\n", "# add an Ensemble router with two child models (routes)\n", "router = graph2.add_step(mlrun.serving.ModelRouter(), name=\"router\", after=\"enrich\")\n", - "router.add_route(\"m1\", class_name=\"ClassifierModel\", model_path='https://s3.wasabisys.com/iguazio/models/iris/model.pkl')\n", + "router.add_route(\n", + " \"m1\",\n", + " class_name=\"ClassifierModel\",\n", + " model_path=\"https://s3.wasabisys.com/iguazio/models/iris/model.pkl\",\n", + ")\n", "router.respond()\n", "\n", "# Add additional models\n", - "#router.add_route(\"m2\", class_name=\"ClassifierModel\", model_path=path2)\n", + "# router.add_route(\"m2\", class_name=\"ClassifierModel\", model_path=path2)\n", "\n", "# plot the graph (using Graphviz)\n", - 
"graph2.plot(rankdir='LR')" + "graph2.plot(rankdir=\"LR\")" ] }, { @@ -266,10 +274,12 @@ "metadata": {}, "outputs": [], "source": [ - "fn.add_child_function('enrich', \n", - " './entity_extraction.ipynb', \n", - " image='mlrun/mlrun',\n", - " requirements=[\"storey\", \"sklearn\"])" + "fn.add_child_function(\n", + " \"enrich\",\n", + " \"./entity_extraction.ipynb\",\n", + " image=\"mlrun/mlrun\",\n", + " requirements=[\"storey\", \"sklearn\"],\n", + ")" ] }, { @@ -347,7 +357,7 @@ "metadata": {}, "outputs": [], "source": [ - "graph2.plot(rankdir='LR')" + "graph2.plot(rankdir=\"LR\")" ] }, { diff --git a/docs/serving/writing-custom-steps.ipynb b/docs/serving/writing-custom-steps.ipynb index 1e736d86a929..05129323011d 100644 --- a/docs/serving/writing-custom-steps.ipynb +++ b/docs/serving/writing-custom-steps.ipynb @@ -53,7 +53,7 @@ " self.context = context\n", " self.name = name\n", " self.kw = kw\n", - " \n", + "\n", " def do(self, x):\n", " print(\"Echo:\", self.name, x)\n", " return x" @@ -129,9 +129,9 @@ "\n", "graph_echo = fn_echo.set_topology(\"flow\")\n", "\n", - "graph_echo.to(class_name=\"Echo\", name=\"pre-process\", some_arg='abc')\n", + "graph_echo.to(class_name=\"Echo\", name=\"pre-process\", some_arg=\"abc\")\n", "\n", - "graph_echo.plot(rankdir='LR')" + "graph_echo.plot(rankdir=\"LR\")" ] }, { diff --git a/docs/store/datastore.md b/docs/store/datastore.md index 3e091c198be3..7798901c0094 100644 --- a/docs/store/datastore.md +++ b/docs/store/datastore.md @@ -112,6 +112,7 @@ authentication methods that use the `fsspec` mechanism. ### Google cloud storage * `GOOGLE_APPLICATION_CREDENTIALS` — path to the application credentials to use (in the form of a JSON file). This can be used if this file is located in a location on shared storage, accessible to pods executing MLRun jobs. -* `GCP_CREDENTIALS` — when the credentials file cannot be mounted to the pod, this environment variable may contain -the contents of this file. 
If configured in the function pod, MLRun dumps its contents to a temporary file -and points `GOOGLE_APPLICATION_CREDENTIALS` at it. \ No newline at end of file +* `GCP_CREDENTIALS` — when the credentials file cannot be mounted to the pod, this secret or environment variable +may contain the contents of this file. If configured in the function pod, MLRun dumps its contents to a temporary file +and points `GOOGLE_APPLICATION_CREDENTIALS` at it. An exception is `BigQuerySource`, which passes `GCP_CREDENTIALS`'s +contents directly to the query engine. \ No newline at end of file diff --git a/docs/training/built-in-training-function.ipynb b/docs/training/built-in-training-function.ipynb index 1816ed6bb397..460fa32f5b5c 100644 --- a/docs/training/built-in-training-function.ipynb +++ b/docs/training/built-in-training-function.ipynb @@ -14,7 +14,7 @@ "id": "0e900797", "metadata": {}, "source": [ - "The MLRun [Function Hub](https://www.mlrun.org/marketplace/) includes, among other things, training functions. The most commonly used function for training is [`auto_trainer`](https://github.com/mlrun/functions/tree/development/auto_trainer), which includes the following handlers:\n", + "The MLRun [Function Hub](https://www.mlrun.org/hub/) includes, among other things, training functions. 
The most commonly used function for training is [`auto_trainer`](https://github.com/mlrun/functions/tree/development/auto_trainer), which includes the following handlers:\n", "\n", "- [Train](#train)\n", "- [Evaluate](#evaluate)" @@ -58,11 +58,14 @@ "outputs": [], "source": [ "import mlrun\n", + "\n", "# Set the base project name\n", - "project_name_base = 'training-test'\n", + "project_name_base = \"training-test\"\n", "\n", "# Initialize the MLRun project object\n", - "project = mlrun.get_or_create_project(project_name_base, context=\"./\", user_project=True)" + "project = mlrun.get_or_create_project(\n", + " project_name_base, context=\"./\", user_project=True\n", + ")" ] }, { @@ -106,14 +109,16 @@ " params={\n", " # Model parameters:\n", " \"model_class\": \"sklearn.ensemble.RandomForestClassifier\",\n", - " \"model_kwargs\": {\"max_depth\": 8}, # Could be also passed as \"MODEL_max_depth\": 8\n", + " \"model_kwargs\": {\n", + " \"max_depth\": 8\n", + " }, # Could be also passed as \"MODEL_max_depth\": 8\n", " \"model_name\": \"MyModel\",\n", " # Dataset parameters:\n", " \"drop_columns\": [\"feat_0\", \"feat_2\"],\n", " \"train_test_split_size\": 0.2,\n", " \"random_state\": 42,\n", " \"label_columns\": \"labels\",\n", - " }\n", + " },\n", ")" ] }, @@ -229,7 +234,7 @@ " \n", " ```{admonition} Note\n", " The custom objects are imported in the order they came in this dictionary (or json). 
If a custom \n", - " object is dependant on another, make sure to put it below the one it relies on.\n", + " object is dependent on another, make sure to put it below the one it relies on.\n", " ``` \n", " \n", " \n", @@ -283,9 +288,9 @@ "source": [ "evaluate_run = auto_trainer.run(\n", " handler=\"evaluate\",\n", - " inputs={\"dataset\": train_run.outputs['test_set']},\n", + " inputs={\"dataset\": train_run.outputs[\"test_set\"]},\n", " params={\n", - " \"model\": train_run.outputs['model'],\n", + " \"model\": train_run.outputs[\"model\"],\n", " \"label_columns\": \"labels\",\n", " },\n", ")" diff --git a/docs/training/create-a-basic-training-job.ipynb b/docs/training/create-a-basic-training-job.ipynb index 961717127eac..80149d08129e 100644 --- a/docs/training/create-a-basic-training-job.ipynb +++ b/docs/training/create-a-basic-training-job.ipynb @@ -126,7 +126,7 @@ " filename=\"trainer.py\",\n", " kind=\"job\",\n", " image=\"mlrun/mlrun\",\n", - " handler=\"train\"\n", + " handler=\"train\",\n", ")" ] }, @@ -392,8 +392,10 @@ ], "source": [ "run = training_job.run(\n", - " inputs={\"dataset\": \"https://igz-demo-datasets.s3.us-east-2.amazonaws.com/cancer-dataset.csv\"}, \n", - " params = {\"n_estimators\": 100, \"learning_rate\": 1e-1, \"max_depth\": 3}\n", + " inputs={\n", + " \"dataset\": \"https://igz-demo-datasets.s3.us-east-2.amazonaws.com/cancer-dataset.csv\"\n", + " },\n", + " params={\"n_estimators\": 100, \"learning_rate\": 1e-1, \"max_depth\": 3},\n", ")" ] }, diff --git a/docs/tutorial/01-mlrun-basics.ipynb b/docs/tutorial/01-mlrun-basics.ipynb index ea9d70280a20..e39cccc1cfa2 100644 --- a/docs/tutorial/01-mlrun-basics.ipynb +++ b/docs/tutorial/01-mlrun-basics.ipynb @@ -16,12 +16,12 @@ "cell_type": "markdown", "id": "d4cbf4a8-7e92-49f8-be36-c48a99fb4527", "metadata": { - "tags": [ - "docs-only" - ], "pycharm": { "name": "#%% md\n" - } + }, + "tags": [ + "docs-only" + ] }, "source": [ "[![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mlrun/mlrun/blob/development/docs/tutorial/colab/01-mlrun-basics-colab.ipynb)" @@ -48,7 +48,7 @@ "- [**Use the MLRun built-in Function Hub functions for training**](#use-hub)\n", "- [**Build, test, and deploy model serving functions**](#model-serving)\n", "\n", - "{octicon}`video` [**Watch the video tutorial**](https://youtu.be/xI8KVGLlj7Q)." + "" ] }, { @@ -67,7 +67,7 @@ "\n", "**Before you start, make sure the MLRun client package is installed and configured properly:**\n", "\n", - "This notebook uses sklearn. If it is not installed in your environment run `!pip install scikit-learn~=1.0`." + "This notebook uses sklearn. If it is not installed in your environment run `!pip install scikit-learn~=1.2`." ] }, { @@ -82,7 +82,7 @@ "outputs": [], "source": [ "# Install MLRun and sklearn, run this only once (restart the notebook after the install !!!)\n", - "%pip install mlrun scikit-learn~=1.0" + "%pip install mlrun scikit-learn~=1.2" ] }, { @@ -271,7 +271,13 @@ } ], "source": [ - "data_gen_fn = project.set_function(\"data-prep.py\", name=\"data-prep\", kind=\"job\", image=\"mlrun/mlrun\", handler=\"breast_cancer_generator\")\n", + "data_gen_fn = project.set_function(\n", + " \"data-prep.py\",\n", + " name=\"data-prep\",\n", + " kind=\"job\",\n", + " image=\"mlrun/mlrun\",\n", + " handler=\"breast_cancer_generator\",\n", + ")\n", "project.save() # save the project with the latest config" ] }, @@ -878,7 +884,7 @@ "\n", "## Train a model using an MLRun built-in Function Hub\n", "\n", - "MLRun provides a [**Function Hub**](https://www.mlrun.org/marketplace/) that hosts a set of pre-implemented and\n", + "MLRun provides a [**Function Hub**](https://www.mlrun.org/hub/) that hosts a set of pre-implemented and\n", "validated ML, DL, and data processing functions.\n", "\n", "You can import the `auto-trainer` hub function that can: train an ML model using a variety of ML 
frameworks; generate\n", @@ -897,7 +903,7 @@ "outputs": [], "source": [ "# Import the function\n", - "trainer = mlrun.import_function('hub://auto_trainer')" + "trainer = mlrun.import_function(\"hub://auto_trainer\")" ] }, { @@ -910,7 +916,7 @@ }, "source": [ "\n", - "See the `auto_trainer` function usage instructions in [the Function Hub](https://www.mlrun.org/marketplace/functions/master/auto_trainer/) or by typing `trainer.doc()`\n", + "See the `auto_trainer` function usage instructions in [the Function Hub](https://www.mlrun.org/hub/functions/master/auto_trainer/) or by typing `trainer.doc()`\n", "\n", "**Run the function on the cluster (if there is)**" ] @@ -1167,15 +1173,16 @@ } ], "source": [ - "trainer_run = project.run_function(trainer,\n", + "trainer_run = project.run_function(\n", + " trainer,\n", " inputs={\"dataset\": gen_data_run.outputs[\"dataset\"]},\n", - " params = {\n", + " params={\n", " \"model_class\": \"sklearn.ensemble.RandomForestClassifier\",\n", " \"train_test_split_size\": 0.2,\n", " \"label_columns\": \"label\",\n", - " \"model_name\": 'cancer',\n", - " }, \n", - " handler='train',\n", + " \"model_name\": \"cancer\",\n", + " },\n", + " handler=\"train\",\n", ")" ] }, @@ -1338,7 +1345,7 @@ ], "source": [ "# Display HTML output artifacts\n", - "trainer_run.artifact('confusion-matrix').show()" + "trainer_run.artifact(\"confusion-matrix\").show()" ] }, { @@ -1365,7 +1372,12 @@ "metadata": {}, "outputs": [], "source": [ - "serving_fn = mlrun.new_function(\"serving\", image=\"python:3.9\", kind=\"serving\", requirements=[\"mlrun[complete]\", \"scikit-learn~=1.2.0\"])" + "serving_fn = mlrun.new_function(\n", + " \"serving\",\n", + " image=\"python:3.9\",\n", + " kind=\"serving\",\n", + " requirements=[\"mlrun[complete]\", \"scikit-learn~=1.2.0\"],\n", + ")" ] }, { @@ -1405,7 +1417,11 @@ } ], "source": [ - "serving_fn.add_model('cancer-classifier',model_path=trainer_run.outputs[\"model\"], 
class_name='mlrun.frameworks.sklearn.SklearnModelServer')" + "serving_fn.add_model(\n", + " \"cancer-classifier\",\n", + " model_path=trainer_run.outputs[\"model\"],\n", + " class_name=\"mlrun.frameworks.sklearn.SklearnModelServer\",\n", + ")" ] }, { @@ -1599,14 +1615,41 @@ } ], "source": [ - "my_data = {\"inputs\"\n", - " :[[\n", - " 1.371e+01, 2.083e+01, 9.020e+01, 5.779e+02, 1.189e-01, 1.645e-01,\n", - " 9.366e-02, 5.985e-02, 2.196e-01, 7.451e-02, 5.835e-01, 1.377e+00,\n", - " 3.856e+00, 5.096e+01, 8.805e-03, 3.029e-02, 2.488e-02, 1.448e-02,\n", - " 1.486e-02, 5.412e-03, 1.706e+01, 2.814e+01, 1.106e+02, 8.970e+02,\n", - " 1.654e-01, 3.682e-01, 2.678e-01, 1.556e-01, 3.196e-01, 1.151e-01]\n", - " ]\n", + "my_data = {\n", + " \"inputs\": [\n", + " [\n", + " 1.371e01,\n", + " 2.083e01,\n", + " 9.020e01,\n", + " 5.779e02,\n", + " 1.189e-01,\n", + " 1.645e-01,\n", + " 9.366e-02,\n", + " 5.985e-02,\n", + " 2.196e-01,\n", + " 7.451e-02,\n", + " 5.835e-01,\n", + " 1.377e00,\n", + " 3.856e00,\n", + " 5.096e01,\n", + " 8.805e-03,\n", + " 3.029e-02,\n", + " 2.488e-02,\n", + " 1.448e-02,\n", + " 1.486e-02,\n", + " 5.412e-03,\n", + " 1.706e01,\n", + " 2.814e01,\n", + " 1.106e02,\n", + " 8.970e02,\n", + " 1.654e-01,\n", + " 3.682e-01,\n", + " 2.678e-01,\n", + " 1.556e-01,\n", + " 3.196e-01,\n", + " 1.151e-01,\n", + " ]\n", + " ]\n", "}\n", "server.test(\"/v2/models/cancer-classifier/infer\", body=my_data)" ] diff --git a/docs/tutorial/02-model-training.ipynb b/docs/tutorial/02-model-training.ipynb index d9facb1a94e9..e43608260ba9 100644 --- a/docs/tutorial/02-model-training.ipynb +++ b/docs/tutorial/02-model-training.ipynb @@ -17,11 +17,12 @@ "- [**Hyper-parameter tuning and model/experiment comparison**](#hyper-param)\n", "- [**Build and test the model serving functions**](#model-serving)\n", "\n", - "{octicon}`video` [**Watch the video tutorial**](https://youtu.be/bZgBsmLMdQo).\n", + "%%HTML\n", + "\n", "\n", "## MLRun installation and configuration\n", "\n", - "Before 
running this notebook make sure `mlrun` and `sklearn` packages are installed (`pip install mlrun scikit-learn~=1.0`) and that you have configured the access to the MLRun service. " + "Before running this notebook make sure `mlrun` and `sklearn` packages are installed (`pip install mlrun scikit-learn~=1.2`) and that you have configured the access to the MLRun service. " ] }, { @@ -66,6 +67,7 @@ ], "source": [ "import mlrun\n", + "\n", "project = mlrun.get_or_create_project(\"tutorial\", context=\"./\", user_project=True)" ] }, @@ -167,7 +169,9 @@ "metadata": {}, "outputs": [], "source": [ - "trainer = project.set_function(\"trainer.py\", name=\"trainer\", kind=\"job\", image=\"mlrun/mlrun\", handler=\"train\")" + "trainer = project.set_function(\n", + " \"trainer.py\", name=\"trainer\", kind=\"job\", image=\"mlrun/mlrun\", handler=\"train\"\n", + ")" ] }, { @@ -190,8 +194,11 @@ "source": [ "import pandas as pd\n", "from sklearn.datasets import load_breast_cancer\n", + "\n", "breast_cancer = load_breast_cancer()\n", - "breast_cancer_dataset = pd.DataFrame(data=breast_cancer.data, columns=breast_cancer.feature_names)\n", + "breast_cancer_dataset = pd.DataFrame(\n", + " data=breast_cancer.data, columns=breast_cancer.feature_names\n", + ")\n", "breast_cancer_labels = pd.DataFrame(data=breast_cancer.target, columns=[\"label\"])\n", "breast_cancer_dataset = pd.concat([breast_cancer_dataset, breast_cancer_labels], axis=1)\n", "\n", @@ -449,10 +456,10 @@ ], "source": [ "trainer_run = project.run_function(\n", - " \"trainer\", \n", - " inputs={\"dataset\": \"cancer-dataset.csv\"}, \n", - " params = {\"n_estimators\": 100, \"learning_rate\": 1e-1, \"max_depth\": 3},\n", - " local=True\n", + " \"trainer\",\n", + " inputs={\"dataset\": \"cancer-dataset.csv\"},\n", + " params={\"n_estimators\": 100, \"learning_rate\": 1e-1, \"max_depth\": 3},\n", + " local=True,\n", ")" ] }, @@ -586,7 +593,7 @@ } ], "source": [ - "trainer_run.artifact('feature-importance').show()" + 
"trainer_run.artifact(\"feature-importance\").show()" ] }, { @@ -606,7 +613,7 @@ "metadata": {}, "outputs": [], "source": [ - "trainer_run.artifact('model').meta.export(\"model.zip\")" + "trainer_run.artifact(\"model\").meta.export(\"model.zip\")" ] }, { @@ -638,7 +645,9 @@ "metadata": {}, "outputs": [], "source": [ - "dataset_artifact = project.log_dataset(\"cancer-dataset\", df=breast_cancer_dataset, index=False)" + "dataset_artifact = project.log_dataset(\n", + " \"cancer-dataset\", df=breast_cancer_dataset, index=False\n", + ")" ] }, { @@ -897,14 +906,14 @@ ], "source": [ "hp_tuning_run = project.run_function(\n", - " \"trainer\", \n", - " inputs={\"dataset\": dataset_artifact.uri}, \n", + " \"trainer\",\n", + " inputs={\"dataset\": dataset_artifact.uri},\n", " hyperparams={\n", - " \"n_estimators\": [10, 100, 1000], \n", - " \"learning_rate\": [1e-1, 1e-3], \n", - " \"max_depth\": [2, 8]\n", - " }, \n", - " selector=\"max.accuracy\", \n", + " \"n_estimators\": [10, 100, 1000],\n", + " \"learning_rate\": [1e-1, 1e-3],\n", + " \"max_depth\": [2, 8],\n", + " },\n", + " selector=\"max.accuracy\",\n", ")" ] }, @@ -1419,7 +1428,11 @@ ], "source": [ "serving_fn = mlrun.new_function(\"serving\", image=\"mlrun/mlrun\", kind=\"serving\")\n", - "serving_fn.add_model('cancer-classifier',model_path=hp_tuning_run.outputs[\"model\"], class_name='mlrun.frameworks.sklearn.SklearnModelServer')" + "serving_fn.add_model(\n", + " \"cancer-classifier\",\n", + " model_path=hp_tuning_run.outputs[\"model\"],\n", + " class_name=\"mlrun.frameworks.sklearn.SklearnModelServer\",\n", + ")" ] }, { @@ -1464,14 +1477,41 @@ "# Create a mock (simulator of the real-time function)\n", "server = serving_fn.to_mock_server()\n", "\n", - "my_data = {\"inputs\"\n", - " :[[\n", - " 1.371e+01, 2.083e+01, 9.020e+01, 5.779e+02, 1.189e-01, 1.645e-01,\n", - " 9.366e-02, 5.985e-02, 2.196e-01, 7.451e-02, 5.835e-01, 1.377e+00,\n", - " 3.856e+00, 5.096e+01, 8.805e-03, 3.029e-02, 2.488e-02, 1.448e-02,\n", - " 
1.486e-02, 5.412e-03, 1.706e+01, 2.814e+01, 1.106e+02, 8.970e+02,\n", - " 1.654e-01, 3.682e-01, 2.678e-01, 1.556e-01, 3.196e-01, 1.151e-01]\n", - " ]\n", + "my_data = {\n", + " \"inputs\": [\n", + " [\n", + " 1.371e01,\n", + " 2.083e01,\n", + " 9.020e01,\n", + " 5.779e02,\n", + " 1.189e-01,\n", + " 1.645e-01,\n", + " 9.366e-02,\n", + " 5.985e-02,\n", + " 2.196e-01,\n", + " 7.451e-02,\n", + " 5.835e-01,\n", + " 1.377e00,\n", + " 3.856e00,\n", + " 5.096e01,\n", + " 8.805e-03,\n", + " 3.029e-02,\n", + " 2.488e-02,\n", + " 1.448e-02,\n", + " 1.486e-02,\n", + " 5.412e-03,\n", + " 1.706e01,\n", + " 2.814e01,\n", + " 1.106e02,\n", + " 8.970e02,\n", + " 1.654e-01,\n", + " 3.682e-01,\n", + " 2.678e-01,\n", + " 1.556e-01,\n", + " 3.196e-01,\n", + " 1.151e-01,\n", + " ]\n", + " ]\n", "}\n", "server.test(\"/v2/models/cancer-classifier/infer\", body=my_data)" ] diff --git a/docs/tutorial/03-model-serving.ipynb b/docs/tutorial/03-model-serving.ipynb index b0b56875b227..3432a497d63e 100644 --- a/docs/tutorial/03-model-serving.ipynb +++ b/docs/tutorial/03-model-serving.ipynb @@ -4,6 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ + "(serving-ml-dl-models)=\n", "# Serving pre-trained ML/DL models\n", "\n", "This notebook demonstrate how to serve standard ML/DL models using **MLRun Serving**.\n", @@ -26,7 +27,7 @@ "- [**Build a custom serving class**](#custom-class)\n", "- [**Building advanced model serving graph**](#serving=graph)\n", "\n", - "{octicon}`video` [**Watch the video tutorial**](https://youtu.be/OUjOus4dZfw).\n", + "\n", "\n", "## MLRun installation and configuration\n", "\n", @@ -59,6 +60,7 @@ "outputs": [], "source": [ "import mlrun\n", + "\n", "project = mlrun.get_or_create_project(\"tutorial\", context=\"./\", user_project=True)" ] }, @@ -107,23 +109,28 @@ "metadata": {}, "outputs": [], "source": [ - "models_dir = mlrun.get_sample_path('models/serving/')\n", + "models_dir = mlrun.get_sample_path(\"models/serving/\")\n", "\n", "# We choose the correct 
model to avoid pickle warnings\n", "import sys\n", - "suffix = mlrun.__version__.split(\"-\")[0].replace(\".\", \"_\") if sys.version_info[1] > 7 else \"3.7\"\n", "\n", - "framework = 'sklearn' # change to 'keras' to try the 2nd option \n", + "suffix = (\n", + " mlrun.__version__.split(\"-\")[0].replace(\".\", \"_\")\n", + " if sys.version_info[1] > 7\n", + " else \"3.7\"\n", + ")\n", + "\n", + "framework = \"sklearn\" # change to 'keras' to try the 2nd option\n", "kwargs = {}\n", "if framework == \"sklearn\":\n", - " serving_class = 'mlrun.frameworks.sklearn.SklearnModelServer'\n", - " model_path = models_dir + f'sklearn-{suffix}.pkl'\n", - " image = 'mlrun/mlrun'\n", + " serving_class = \"mlrun.frameworks.sklearn.SklearnModelServer\"\n", + " model_path = models_dir + f\"sklearn-{suffix}.pkl\"\n", + " image = \"mlrun/mlrun\"\n", "else:\n", - " serving_class = 'mlrun.frameworks.tf_keras.TFKerasModelServer'\n", - " model_path = models_dir + 'keras.h5'\n", - " image = 'mlrun/ml-models' # or mlrun/ml-models-gpu when using GPUs\n", - " kwargs['labels'] = {'model-format': 'h5'}" + " serving_class = \"mlrun.frameworks.tf_keras.TFKerasModelServer\"\n", + " model_path = models_dir + \"keras.h5\"\n", + " image = \"mlrun/ml-models\" # or mlrun/ml-models-gpu when using GPUs\n", + " kwargs[\"labels\"] = {\"model-format\": \"h5\"}" ] }, { @@ -141,7 +148,7 @@ "metadata": {}, "outputs": [], "source": [ - "model_object = project.log_model(f'{framework}-model', model_file=model_path, **kwargs)" + "model_object = project.log_model(f\"{framework}-model\", model_file=model_path, **kwargs)" ] }, { @@ -224,7 +231,9 @@ ], "source": [ "serving_fn = mlrun.new_function(\"serving\", image=image, kind=\"serving\", requirements={})\n", - "serving_fn.add_model(framework ,model_path=model_object.uri, class_name=serving_class, to_list=True)\n", + "serving_fn.add_model(\n", + " framework, model_path=model_object.uri, class_name=serving_class, to_list=True\n", + ")\n", "\n", "# Plot the serving 
topology input -> router -> model\n", "serving_fn.plot(rankdir=\"LR\")" @@ -311,8 +320,8 @@ } ], "source": [ - "sample = {\"inputs\":[[5.1, 3.5, 1.4, 0.2],[7.7, 3.8, 6.7, 2.2]]}\n", - "server.test(path=f'/v2/models/{framework}/infer',body=sample)" + "sample = {\"inputs\": [[5.1, 3.5, 1.4, 0.2], [7.7, 3.8, 6.7, 2.2]]}\n", + "server.test(path=f\"/v2/models/{framework}/infer\", body=sample)" ] }, { @@ -392,7 +401,7 @@ } ], "source": [ - "serving_fn.invoke(path=f'/v2/models/{framework}/infer',body=sample)" + "serving_fn.invoke(path=f\"/v2/models/{framework}/infer\", body=sample)" ] }, { @@ -471,9 +480,9 @@ ], "metadata": { "kernelspec": { - "display_name": "conda", + "display_name": "Python 3 (ipykernel)", "language": "python", - "name": "conda-root-py" + "name": "python3" }, "language_info": { "codemirror_mode": { @@ -485,7 +494,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.13" + "version": "3.7.7" } }, "nbformat": 4, diff --git a/docs/tutorial/04-pipeline.ipynb b/docs/tutorial/04-pipeline.ipynb index 7e81d39239cd..e8a8776a0f65 100644 --- a/docs/tutorial/04-pipeline.ipynb +++ b/docs/tutorial/04-pipeline.ipynb @@ -76,6 +76,7 @@ ], "source": [ "import mlrun\n", + "\n", "project = mlrun.get_or_create_project(\"tutorial\", context=\"./\", user_project=True)" ] }, @@ -184,7 +185,13 @@ } ], "source": [ - "project.set_function(\"data-prep.py\", name=\"data-prep\", kind=\"job\", image=\"mlrun/mlrun\", handler=\"breast_cancer_generator\")" + "project.set_function(\n", + " \"data-prep.py\",\n", + " name=\"data-prep\",\n", + " kind=\"job\",\n", + " image=\"mlrun/mlrun\",\n", + " handler=\"breast_cancer_generator\",\n", + ")" ] }, { @@ -463,8 +470,9 @@ "# Run the workflow\n", "run_id = project.run(\n", " workflow_path=\"./workflow.py\",\n", - " arguments={\"model_name\": \"cancer-classifier\"}, \n", - " watch=True)" + " arguments={\"model_name\": \"cancer-classifier\"},\n", + " watch=True,\n", + ")" ] }, { @@ -535,14 
+543,41 @@ ], "source": [ "# Create a mock (simulator of the real-time function)\n", - "my_data = {\"inputs\"\n", - " :[[\n", - " 1.371e+01, 2.083e+01, 9.020e+01, 5.779e+02, 1.189e-01, 1.645e-01,\n", - " 9.366e-02, 5.985e-02, 2.196e-01, 7.451e-02, 5.835e-01, 1.377e+00,\n", - " 3.856e+00, 5.096e+01, 8.805e-03, 3.029e-02, 2.488e-02, 1.448e-02,\n", - " 1.486e-02, 5.412e-03, 1.706e+01, 2.814e+01, 1.106e+02, 8.970e+02,\n", - " 1.654e-01, 3.682e-01, 2.678e-01, 1.556e-01, 3.196e-01, 1.151e-01]\n", - " ]\n", + "my_data = {\n", + " \"inputs\": [\n", + " [\n", + " 1.371e01,\n", + " 2.083e01,\n", + " 9.020e01,\n", + " 5.779e02,\n", + " 1.189e-01,\n", + " 1.645e-01,\n", + " 9.366e-02,\n", + " 5.985e-02,\n", + " 2.196e-01,\n", + " 7.451e-02,\n", + " 5.835e-01,\n", + " 1.377e00,\n", + " 3.856e00,\n", + " 5.096e01,\n", + " 8.805e-03,\n", + " 3.029e-02,\n", + " 2.488e-02,\n", + " 1.448e-02,\n", + " 1.486e-02,\n", + " 5.412e-03,\n", + " 1.706e01,\n", + " 2.814e01,\n", + " 1.106e02,\n", + " 8.970e02,\n", + " 1.654e-01,\n", + " 3.682e-01,\n", + " 2.678e-01,\n", + " 1.556e-01,\n", + " 3.196e-01,\n", + " 1.151e-01,\n", + " ]\n", + " ]\n", "}\n", "serving_fn.invoke(\"/v2/models/cancer-classifier/infer\", body=my_data)" ] diff --git a/docs/tutorial/05-model-monitoring.ipynb b/docs/tutorial/05-model-monitoring.ipynb index 22a293136fa9..667aff5aa99f 100644 --- a/docs/tutorial/05-model-monitoring.ipynb +++ b/docs/tutorial/05-model-monitoring.ipynb @@ -103,10 +103,15 @@ "source": [ "# We choose the correct model to avoid pickle warnings\n", "import sys\n", - "suffix = mlrun.__version__.split(\"-\")[0].replace(\".\", \"_\") if sys.version_info[1] > 7 else \"3.7\"\n", "\n", - "model_path = mlrun.get_sample_path(f'models/model-monitoring/model-{suffix}.pkl')\n", - "training_set_path = mlrun.get_sample_path('data/model-monitoring/iris_dataset.csv')" + "suffix = (\n", + " mlrun.__version__.split(\"-\")[0].replace(\".\", \"_\")\n", + " if sys.version_info[1] > 7\n", + " else \"3.7\"\n", + ")\n", + 
"\n", + "model_path = mlrun.get_sample_path(f\"models/model-monitoring/model-{suffix}.pkl\")\n", + "training_set_path = mlrun.get_sample_path(\"data/model-monitoring/iris_dataset.csv\")" ] }, { @@ -139,7 +144,7 @@ " model_file=model_path,\n", " framework=\"sklearn\",\n", " training_set=pd.read_csv(training_set_path),\n", - " label_column=\"label\"\n", + " label_column=\"label\",\n", ")" ] }, @@ -171,7 +176,7 @@ "\n", "## Import and deploy the serving function\n", "\n", - "Import the [model server](https://github.com/mlrun/functions/tree/master/v2_model_server) function from the [MLRun Function Hub](https://www.mlrun.org/marketplace/). Additionally, mount the filesytem, add the model that was logged via experiment tracking, and enable drift detection.\n", + "Import the [model server](https://github.com/mlrun/functions/tree/master/v2_model_server) function from the [MLRun Function Hub](https://www.mlrun.org/hub/). Additionally, mount the filesytem, add the model that was logged via experiment tracking, and enable drift detection.\n", "\n", "The core line here is `serving_fn.set_tracking()` that creates the required infrastructure behind the scenes to perform drift detection. See the [Model monitoring overview](https://docs.mlrun.org/en/latest/monitoring/model-monitoring-deployment.html) for more info on what is deployed." 
] @@ -188,7 +193,7 @@ "outputs": [], "source": [ "# Import the serving function from the Function Hub and mount filesystem\n", - "serving_fn = mlrun.import_function('hub://v2_model_server', new_name=\"serving\")\n", + "serving_fn = mlrun.import_function(\"hub://v2_model_server\", new_name=\"serving\")\n", "\n", "# Add the model to the serving function's routing spec\n", "serving_fn.add_model(model_name, model_path=model_artifact.uri)\n", @@ -320,13 +325,17 @@ "logging.getLogger(name=\"mlrun\").setLevel(logging.WARNING)\n", "\n", "# Get training set as list\n", - "iris_data = pd.read_csv(training_set_path).drop(\"label\", axis=1).to_dict(orient=\"split\")[\"data\"]\n", + "iris_data = (\n", + " pd.read_csv(training_set_path).drop(\"label\", axis=1).to_dict(orient=\"split\")[\"data\"]\n", + ")\n", "\n", "# Simulate traffic using random elements from training set\n", "for i in tqdm(range(12_000)):\n", " data_point = choice(iris_data)\n", - " serving_fn.invoke(f'v2/models/{model_name}/infer', json.dumps({'inputs': [data_point]}))\n", - " \n", + " serving_fn.invoke(\n", + " f\"v2/models/{model_name}/infer\", json.dumps({\"inputs\": [data_point]})\n", + " )\n", + "\n", "# Resume normal logging\n", "logging.getLogger(name=\"mlrun\").setLevel(logging.INFO)" ] diff --git a/docs/tutorial/06-add-mlops-to-code.ipynb b/docs/tutorial/06-add-mlops-to-code.ipynb index bd6ba752c232..1460abbbbc94 100644 --- a/docs/tutorial/06-add-mlops-to-code.ipynb +++ b/docs/tutorial/06-add-mlops-to-code.ipynb @@ -2,22 +2,14 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "# Add MLOps to existing code" ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "This tutorial showcases how easy it is to apply MLRun on your existing code. 
With only 7 lines of code, you get:\n", "* Experiment tracking — Track every single run of your experiment to learn what yielded the best results.\n", @@ -36,11 +28,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "\n", "## Get the data\n", @@ -50,11 +38,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "\n", "## Code review\n", @@ -94,11 +78,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "#### MLRun context\n", "\n", @@ -112,11 +92,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "#### Get Training Set\n", "\n", @@ -132,17 +108,13 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "### Apply MLRun\n", "\n", "Now use the `apply_mlrun` function from MLRun's LightGBM framework integration. 
MLRun automatically wraps the LightGBM module and enables automatic logging and evaluation.\n", "\n", - "Line 219:\n", + "Line 209:\n", "```python\n", "apply_mlrun(context=context)\n", "```" @@ -150,17 +122,13 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "### Logging the dataset\n", "\n", "Similar to the way you got the training set, you get the test dataset as an input from the MLRun content.\n", "\n", - "Line 235:\n", + "Line 226:\n", "```python\n", "test_df = context.get_input(\"test_set\", \"./test.csv\").as_df()\n", "# Instead of: `test_df = pd.read_csv('./test.csv')`\n", @@ -169,17 +137,13 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "#### Save the submission\n", "\n", "Finally, instead of saving the result locally, log the submission to MLRun.\n", "\n", - "Line 267:\n", + "Line 258:\n", "```python\n", "context.log_dataset(key=\"taxi_fare_submission\", df=submission, format=\"csv\") \n", "# Instead of: `submission.to_csv('taxi_fare_submission.csv',index=False)`\n", @@ -188,11 +152,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "\n", "## Run the script with MLRun\n", @@ -203,11 +163,7 @@ { "cell_type": "code", "execution_count": 2, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [], "source": [ "import mlrun" @@ -215,11 +171,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "### Create a project\n", "\n", @@ -229,11 +181,7 @@ { "cell_type": "code", "execution_count": 3, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -244,16 +192,14 @@ } ], "source": [ - "project = mlrun.get_or_create_project(name=\"apply-mlrun-tutorial\", context=\"./\", 
user_project=True)" + "project = mlrun.get_or_create_project(\n", + " name=\"apply-mlrun-tutorial\", context=\"./\", user_project=True\n", + ")" ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "### Create a function\n", "\n", @@ -263,11 +209,7 @@ { "cell_type": "code", "execution_count": 4, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [ { "data": { @@ -285,17 +227,13 @@ " filename=\"./src/script.py\",\n", " name=\"apply-mlrun-tutorial-function\",\n", " kind=\"job\",\n", - " image=\"mlrun/ml-models\"\n", + " image=\"mlrun/ml-models\",\n", ")" ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "### Run the function\n", "\n", @@ -305,11 +243,7 @@ { "cell_type": "code", "execution_count": 5, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -558,18 +492,14 @@ "script_run = script_function.run(\n", " inputs={\n", " \"train_set\": \"https://s3.us-east-1.wasabisys.com/iguazio/data/nyc-taxi/train.csv\",\n", - " \"test_set\": \"https://s3.us-east-1.wasabisys.com/iguazio/data/nyc-taxi/test.csv\"\n", + " \"test_set\": \"https://s3.us-east-1.wasabisys.com/iguazio/data/nyc-taxi/test.csv\",\n", " },\n", ")" ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "\n", "## Review outputs\n", @@ -580,11 +510,7 @@ { "cell_type": "code", "execution_count": 6, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [ { "data": { @@ -608,11 +534,7 @@ }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "MLRun **automatically detects all the metrics calculated** and collects the data along with the training. 
Here there was one validation set named `valid_0` and the RMSE metric was calculated on it. You can see the RMSE values per iteration plot and the final score including the features importance plot.\n", "\n", @@ -622,11 +544,7 @@ { "cell_type": "code", "execution_count": 11, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [ { "data": { @@ -712,17 +630,13 @@ } ], "source": [ - "script_run.artifact('valid_0_rmse_plot').show()" + "script_run.artifact(\"valid_0_rmse_plot\").show()" ] }, { "cell_type": "code", "execution_count": 10, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [ { "data": { @@ -808,16 +722,12 @@ } ], "source": [ - "script_run.artifact('valid_0-feature-importance').show()" + "script_run.artifact(\"valid_0-feature-importance\").show()" ] }, { "cell_type": "markdown", - "metadata": { - "pycharm": { - "name": "#%% md\n" - } - }, + "metadata": {}, "source": [ "And of course, you can also see the submission that was logged:" ] @@ -825,11 +735,7 @@ { "cell_type": "code", "execution_count": 12, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "metadata": {}, "outputs": [ { "data": { @@ -939,7 +845,7 @@ } ], "source": [ - "script_run.artifact('taxi_fare_submission').show()" + "script_run.artifact(\"taxi_fare_submission\").show()" ] } ], diff --git a/docs/tutorial/07-batch-infer.ipynb b/docs/tutorial/07-batch-infer.ipynb index f121a1e42561..95f60259e60c 100644 --- a/docs/tutorial/07-batch-infer.ipynb +++ b/docs/tutorial/07-batch-infer.ipynb @@ -7,7 +7,7 @@ "source": [ "# Batch inference and drift detection\n", "\n", - "This tutorial leverages a function from the [MLRun Function Hub](https://www.mlrun.org/marketplace/) to perform [batch inference](https://www.mlrun.org/marketplace/functions/master/batch_inference/) using a logged model and a new prediction dataset. 
The function also calculates data drift by comparing the new prediction dataset with the original training set.\n", + "This tutorial leverages a function from the [MLRun Function Hub](https://www.mlrun.org/hub/) to perform [batch inference](https://www.mlrun.org/hub/functions/master/batch_inference/) using a logged model and a new prediction dataset. The function also calculates data drift by comparing the new prediction dataset with the original training set.\n", "\n", "Make sure you have reviewed the basics in MLRun [**Quick Start Tutorial**](../01-mlrun-basics.html)." ] @@ -108,11 +108,16 @@ "source": [ "# We choose the correct model to avoid pickle warnings\n", "import sys\n", - "suffix = mlrun.__version__.split(\"-\")[0].replace(\".\", \"_\") if sys.version_info[1] > 7 else \"3.7\"\n", "\n", - "model_path = mlrun.get_sample_path(f'models/batch-predict/model-{suffix}.pkl')\n", - "training_set_path = mlrun.get_sample_path('data/batch-predict/training_set.parquet')\n", - "prediction_set_path = mlrun.get_sample_path('data/batch-predict/prediction_set.parquet')" + "suffix = (\n", + " mlrun.__version__.split(\"-\")[0].replace(\".\", \"_\")\n", + " if sys.version_info[1] > 7\n", + " else \"3.7\"\n", + ")\n", + "\n", + "model_path = mlrun.get_sample_path(f\"models/batch-predict/model-{suffix}.pkl\")\n", + "training_set_path = mlrun.get_sample_path(\"data/batch-predict/training_set.parquet\")\n", + "prediction_set_path = mlrun.get_sample_path(\"data/batch-predict/prediction_set.parquet\")" ] }, { @@ -584,7 +589,7 @@ " model_file=model_path,\n", " framework=\"sklearn\",\n", " training_set=pd.read_parquet(training_set_path),\n", - " label_column=\"label\"\n", + " label_column=\"label\",\n", ")" ] }, @@ -607,7 +612,7 @@ "\n", "## Import and run the batch inference function\n", "\n", - "Next, import the [batch inference](https://www.mlrun.org/marketplace/functions/master/batch_inference/) function from the [MLRun Function Hub](https://www.mlrun.org/marketplace/):" + "Next, 
import the [batch inference](https://www.mlrun.org/hub/functions/master/batch_inference/) function from the [MLRun Function Hub](https://www.mlrun.org/hub/):" ] }, { @@ -879,11 +884,11 @@ " inputs={\n", " \"dataset\": prediction_set_path,\n", " # If you do not log a dataset with your model, you can pass it in here:\n", - "# \"sample_set\" : training_set_path\n", + " # \"sample_set\" : training_set_path\n", " },\n", " params={\n", " \"model\": model_artifact.uri,\n", - " \"perform_drift_analysis\" : True,\n", + " \"perform_drift_analysis\": True,\n", " },\n", ")" ] @@ -1219,6 +1224,7 @@ "source": [ "# Data/concept drift per feature\n", "import json\n", + "\n", "json.loads(run.artifact(\"features_drift_results\").get())" ] }, diff --git a/docs/tutorial/colab/01-mlrun-basics-colab.ipynb b/docs/tutorial/colab/01-mlrun-basics-colab.ipynb index 925d12d0fbda..c5a9e950d9f7 100644 --- a/docs/tutorial/colab/01-mlrun-basics-colab.ipynb +++ b/docs/tutorial/colab/01-mlrun-basics-colab.ipynb @@ -922,7 +922,7 @@ "\n", "## Train a model using an MLRun built-in function \n", "\n", - "MLRun provides a [**public Function Hub**](https://www.mlrun.org/marketplace/) which hosts a set of pre-implemented and\n", + "MLRun provides a [**public Function Hub**](https://www.mlrun.org/hub/) which hosts a set of pre-implemented and\n", "validated ML, DL, and data processing functions.\n", "\n", "You can import the `auto-trainer` hub function which can train an ML model using variety of ML frameworks, generate\n", @@ -954,7 +954,7 @@ }, "source": [ "\n", - "> See the `auto_trainer` function usage instructions in [the Function Hub](https://www.mlrun.org/marketplace/functions/master/auto_trainer/) or by typing `trainer.doc()`\n", + "> See the `auto_trainer` function usage instructions in [the Function Hub](https://www.mlrun.org/hub/functions/master/auto_trainer/) or by typing `trainer.doc()`\n", "\n", "**Run the function on the cluster (if exist):**" ] diff --git a/docs/tutorial/index.md 
b/docs/tutorial/index.md index a2bd5cedf72d..b5232ca2edf3 100644 --- a/docs/tutorial/index.md +++ b/docs/tutorial/index.md @@ -3,10 +3,12 @@ The following tutorials provide a hands-on introduction to using MLRun to implement a data science workflow and automate machine-learning operations (MLOps). -- [**Quick-start Tutorial**](./01-mlrun-basics.html) ({octicon}`video` [**watch video**](https://youtu.be/xI8KVGLlj7Q)) +- [**Quick-start Tutorial**](./01-mlrun-basics.html) - [**Targeted Tutorials**](#other-tutorial) - [**End to End Demos**](#e2e-demos) +

+ (quick-start-tutorial)= ````{card} Make sure you start with the Quick start tutorial to understand the basics diff --git a/docs/tutorial/src/script.py b/docs/tutorial/src/script.py index 28dba9c828b4..e1bc7b96d504 100644 --- a/docs/tutorial/src/script.py +++ b/docs/tutorial/src/script.py @@ -199,32 +199,23 @@ def add_datetime_info(dataset): "scale_pos_weight": 1, "zero_as_missing": True, "seed": 0, - "num_rounds": 50000, + # "categorical_feature": "name:year,month,day,weekday", } -train_set = lgbm.Dataset( - x_train, - y_train, - silent=False, - categorical_feature=["year", "month", "day", "weekday"], -) -valid_set = lgbm.Dataset( - x_test, - y_test, - silent=False, - categorical_feature=["year", "month", "day", "weekday"], -) +train_set = lgbm.Dataset(x_train, y_train) +valid_set = lgbm.Dataset(x_test, y_test) # [MLRun] Apply MLRun on the LightGBM module: apply_mlrun(context=context) model = lgbm.train( params, - train_set=train_set, num_boost_round=10000, - early_stopping_rounds=500, + train_set=train_set, valid_sets=[valid_set], + callbacks=[lgbm.early_stopping(stopping_rounds=500)], ) + del x_train del y_train del x_test diff --git a/docs/tutorial/src/workflow.py b/docs/tutorial/src/workflow.py index 3fadf9b20068..033051140119 100644 --- a/docs/tutorial/src/workflow.py +++ b/docs/tutorial/src/workflow.py @@ -16,7 +16,7 @@ def pipeline(model_name="cancer-classifier"): # Train a model using the auto_trainer hub function train = mlrun.run_function( - "hub://auto_trainer", + "hub://auto-trainer", inputs={"dataset": ingest.outputs["dataset"]}, params={ "model_class": "sklearn.ensemble.RandomForestClassifier", diff --git a/examples/load-project.ipynb b/examples/load-project.ipynb index 213330eaea51..e568e7f37cb6 100644 --- a/examples/load-project.ipynb +++ b/examples/load-project.ipynb @@ -64,13 +64,13 @@ "\n", "# source Git Repo\n", "# YOU SHOULD fork this to your account and use the fork if you plan on modifying the code\n", - "url = 
'git://github.com/mlrun/demo-xgb-project.git' # refs/tags/v0.4.7'\n", + "url = \"git://github.com/mlrun/demo-xgb-project.git\" # refs/tags/v0.4.7'\n", "\n", "# alternatively can use tar files, e.g.\n", - "#url = 'v3io:///users/admin/tars/src-project.tar.gz'\n", + "# url = 'v3io:///users/admin/tars/src-project.tar.gz'\n", "\n", "# change if you want to clone into a different dir, can use clone=True to override the dir content\n", - "project_dir = path.join(str(Path.home()), 'my_proj')\n", + "project_dir = path.join(str(Path.home()), \"my_proj\")\n", "proj = load_project(project_dir, url, clone=True)" ] }, @@ -181,11 +181,16 @@ "source": [ "# You can update the function .py and .yaml from a notebook (code + spec)\n", "# the \"code_output\" option will generate a .py file from our notebook which can be used for src control and local runs\n", - "xgbfn = code_to_function('xgb', filename='notebooks/train-xgboost.ipynb' ,kind='job', code_output='src/iris.py')\n", + "xgbfn = code_to_function(\n", + " \"xgb\",\n", + " filename=\"notebooks/train-xgboost.ipynb\",\n", + " kind=\"job\",\n", + " code_output=\"src/iris.py\",\n", + ")\n", "\n", - "# tell the builder to clone this repo into the function container \n", - "xgbfn.spec.build.source = './'\n", - "xgbfn.export('src/iris.yaml')" + "# tell the builder to clone this repo into the function container\n", + "xgbfn.spec.build.source = \"./\"\n", + "xgbfn.export(\"src/iris.yaml\")" ] }, { @@ -268,7 +273,7 @@ ], "source": [ "# read specific function spec\n", - "print(proj.func('xgb').to_yaml())" + "print(proj.func(\"xgb\").to_yaml())" ] }, { @@ -519,7 +524,8 @@ ], "source": [ "from mlrun import run_local, new_task\n", - "run_local(new_task(handler='iris_generator'), proj.func('xgb'), workdir='./')" + "\n", + "run_local(new_task(handler=\"iris_generator\"), proj.func(\"xgb\"), workdir=\"./\")" ] }, { @@ -738,7 +744,13 @@ } ], "source": [ - "proj.run('main', arguments={}, 
artifact_path='v3io:///users/admin/mlrun/kfp/{{workflow.uid}}/', dirty=True, watch=True)" + "proj.run(\n", + " \"main\",\n", + " arguments={},\n", + " artifact_path=\"v3io:///users/admin/mlrun/kfp/{{workflow.uid}}/\",\n", + " dirty=True,\n", + " watch=True,\n", + ")" ] }, { @@ -758,7 +770,7 @@ "metadata": {}, "outputs": [], "source": [ - "proj.source = 'v3io:///users/admin/my-proj'" + "proj.source = \"v3io:///users/admin/my-proj\"" ] }, { diff --git a/examples/mlrun_basics.ipynb b/examples/mlrun_basics.ipynb index ae524fe4d4b9..b14d3bd66c9d 100644 --- a/examples/mlrun_basics.ipynb +++ b/examples/mlrun_basics.ipynb @@ -99,7 +99,8 @@ "source": [ "from mlrun import run_local, new_task, mlconf\n", "from os import path\n", - "mlconf.dbpath = mlconf.dbpath or './'" + "\n", + "mlconf.dbpath = mlconf.dbpath or \"./\"" ] }, { @@ -158,9 +159,9 @@ "metadata": {}, "outputs": [], "source": [ - "out = mlconf.artifact_path or path.abspath('./data')\n", + "out = mlconf.artifact_path or path.abspath(\"./data\")\n", "# {{run.uid}} will be substituted with the run id, so output will be written to different directoried per run\n", - "artifact_path = path.join(out, '{{run.uid}}')" + "artifact_path = path.join(out, \"{{run.uid}}\")" ] }, { @@ -177,7 +178,11 @@ "metadata": {}, "outputs": [], "source": [ - "task = new_task(name='demo', params={'p1': 5}, artifact_path=artifact_path).with_secrets('file', 'secrets.txt').set_label('type', 'demo')" + "task = (\n", + " new_task(name=\"demo\", params={\"p1\": 5}, artifact_path=artifact_path)\n", + " .with_secrets(\"file\", \"secrets.txt\")\n", + " .set_label(\"type\", \"demo\")\n", + ")" ] }, { @@ -442,7 +447,7 @@ ], "source": [ "# run our task using our new function\n", - "run_object = run_local(task, command='training.py')" + "run_object = run_local(task, command=\"training.py\")" ] }, { @@ -934,7 +939,7 @@ } ], "source": [ - "run_object.artifact('dataset')" + "run_object.artifact(\"dataset\")" ] }, { @@ -1230,7 +1235,9 @@ } ], "source": [ - 
"run = run_local(task.with_hyper_params({'p2': [5, 2, 3]}, 'min.loss'), command='training.py')" + "run = run_local(\n", + " task.with_hyper_params({\"p2\": [5, 2, 3]}, \"min.loss\"), command=\"training.py\"\n", + ")" ] }, { @@ -1424,46 +1431,53 @@ "\n", "# define a function with spec as parameter\n", "import time\n", - "def handler(context, p1=1, p2='xx'):\n", + "\n", + "\n", + "def handler(context, p1=1, p2=\"xx\"):\n", " \"\"\"this is a simple function\n", - " \n", + "\n", " :param p1: first param\n", " :param p2: another param\n", " \"\"\"\n", " # access input metadata, values, and inputs\n", - " print(f'Run: {context.name} (uid={context.uid})')\n", - " print(f'Params: p1={p1}, p2={p2}')\n", - " \n", + " print(f\"Run: {context.name} (uid={context.uid})\")\n", + " print(f\"Params: p1={p1}, p2={p2}\")\n", + "\n", " time.sleep(1)\n", - " \n", + "\n", " # log the run results (scalar values)\n", - " context.log_result('accuracy', p1 * 2)\n", - " context.log_result('loss', p1 * 3)\n", - " \n", - " # add a lable/tag to this run \n", - " context.set_label('category', 'tests')\n", - " \n", - " # create a matplot figure and store as artifact \n", + " context.log_result(\"accuracy\", p1 * 2)\n", + " context.log_result(\"loss\", p1 * 3)\n", + "\n", + " # add a lable/tag to this run\n", + " context.set_label(\"category\", \"tests\")\n", + "\n", + " # create a matplot figure and store as artifact\n", " fig, ax = plt.subplots()\n", " np.random.seed(0)\n", " x, y = np.random.normal(size=(2, 200))\n", " color, size = np.random.random((2, 200))\n", " ax.scatter(x, y, c=color, s=500 * size, alpha=0.3)\n", - " ax.grid(color='lightgray', alpha=0.7)\n", - " \n", - " context.log_artifact(PlotArtifact('myfig', body=fig, title='my plot'))\n", - " \n", - " # create a dataframe artifact \n", - " df = pd.DataFrame([{'A':10, 'B':100}, {'A':11,'B':110}, {'A':12,'B':120}])\n", - " context.log_dataset('mydf', df=df)\n", - " \n", - " # Log an ML Model artifact \n", - " 
context.log_model('mymodel', body=b'abc is 123', \n", - " model_file='model.txt', model_dir='data', \n", - " metrics={'accuracy':0.85}, parameters={'xx':'abc'},\n", - " labels={'framework': 'xgboost'})\n", + " ax.grid(color=\"lightgray\", alpha=0.7)\n", + "\n", + " context.log_artifact(PlotArtifact(\"myfig\", body=fig, title=\"my plot\"))\n", + "\n", + " # create a dataframe artifact\n", + " df = pd.DataFrame([{\"A\": 10, \"B\": 100}, {\"A\": 11, \"B\": 110}, {\"A\": 12, \"B\": 120}])\n", + " context.log_dataset(\"mydf\", df=df)\n", + "\n", + " # Log an ML Model artifact\n", + " context.log_model(\n", + " \"mymodel\",\n", + " body=b\"abc is 123\",\n", + " model_file=\"model.txt\",\n", + " model_dir=\"data\",\n", + " metrics={\"accuracy\": 0.85},\n", + " parameters={\"xx\": \"abc\"},\n", + " labels={\"framework\": \"xgboost\"},\n", + " )\n", "\n", - " return 'my resp'" + " return \"my resp\"" ] }, { @@ -1723,7 +1737,9 @@ } ], "source": [ - "task = new_task(name='demo2', handler=handler, artifact_path=artifact_path).with_params(p1=7)\n", + "task = new_task(name=\"demo2\", handler=handler, artifact_path=artifact_path).with_params(\n", + " p1=7\n", + ")\n", "run = run_local(task)" ] }, @@ -2025,7 +2041,9 @@ } ], "source": [ - "task = new_task(name='demo2', handler=handler, artifact_path=artifact_path).with_param_file('params.csv', 'max.accuracy')\n", + "task = new_task(\n", + " name=\"demo2\", handler=handler, artifact_path=artifact_path\n", + ").with_param_file(\"params.csv\", \"max.accuracy\")\n", "run = run_local(task)" ] }, diff --git a/examples/mlrun_dask.ipynb b/examples/mlrun_dask.ipynb index 316e8c8a8506..efe615968896 100644 --- a/examples/mlrun_dask.ipynb +++ b/examples/mlrun_dask.ipynb @@ -14,7 +14,7 @@ "outputs": [], "source": [ "# recommended, installing the exact package versions as we use in the worker\n", - "#!pip install dask==2.12.0 distributed==2.14.0 " + "#!pip install dask==2.12.0 distributed==2.14.0" ] }, { @@ -30,9 +30,9 @@ "metadata": {}, 
"outputs": [], "source": [ - "# function that will be distributed \n", + "# function that will be distributed\n", "def inc(x):\n", - " return x+2" + " return x + 2" ] }, { @@ -50,13 +50,13 @@ "outputs": [], "source": [ "# wrapper function, uses the dask client object\n", - "def hndlr(context, x=1,y=2):\n", - " context.logger.info('params: x={},y={}'.format(x,y))\n", - " print('params: x={},y={}'.format(x,y))\n", + "def hndlr(context, x=1, y=2):\n", + " context.logger.info(\"params: x={},y={}\".format(x, y))\n", + " print(\"params: x={},y={}\".format(x, y))\n", " x = context.dask_client.submit(inc, x)\n", " print(x)\n", " print(x.result())\n", - " context.log_result('y', x.result())" + " context.log_result(\"y\", x.result())" ] }, { @@ -76,7 +76,8 @@ "outputs": [], "source": [ "from mlrun import new_function, mlconf, code_to_function, mount_v3io, new_task\n", - "#mlconf.dbpath = 'http://mlrun-api:8080'" + "\n", + "# mlconf.dbpath = 'http://mlrun-api:8080'" ] }, { @@ -107,7 +108,7 @@ "outputs": [], "source": [ "# create the function from the notebook code + annotations, add volumes\n", - "dsf = code_to_function('mydask', kind='dask').apply(mount_v3io())" + "dsf = code_to_function(\"mydask\", kind=\"dask\").apply(mount_v3io())" ] }, { @@ -116,10 +117,10 @@ "metadata": {}, "outputs": [], "source": [ - "dsf.spec.image = 'mlrun/ml-models'\n", + "dsf.spec.image = \"mlrun/ml-models\"\n", "dsf.spec.remote = True\n", "dsf.spec.replicas = 1\n", - "dsf.spec.service_type = 'NodePort'" + "dsf.spec.service_type = \"NodePort\"" ] }, { @@ -399,7 +400,7 @@ } ], "source": [ - "myrun = dsf.run(handler=hndlr, params={'x': 12})" + "myrun = dsf.run(handler=hndlr, params={\"x\": 12})" ] }, { @@ -520,8 +521,9 @@ ], "source": [ "from mlrun import import_function\n", + "\n", "# Functions url: db:///[:tag]\n", - "dsf_obj = import_function('db://default/mydask')\n", + "dsf_obj = import_function(\"db://default/mydask\")\n", "c = dsf_obj.client" ] }, @@ -550,12 +552,15 @@ "outputs": [], 
"source": [ "@dsl.pipeline(name=\"dask_pipeline\")\n", - "def dask_pipe(x=1,y=10):\n", + "def dask_pipe(x=1, y=10):\n", " # use_db option will use a function (DB) pointer instead of adding the function spec to the YAML\n", - " myrun = dsf.as_step(new_task(handler=hndlr, name=\"dask_pipeline\", params={'x': x, 'y': y}), use_db=True)\n", - " \n", + " myrun = dsf.as_step(\n", + " new_task(handler=hndlr, name=\"dask_pipeline\", params={\"x\": x, \"y\": y}),\n", + " use_db=True,\n", + " )\n", + "\n", " # if the step (dask client) need v3io access u should add: .apply(mount_v3io())\n", - " \n", + "\n", " # if its a new image we may want to tell Kubeflow to reload the image\n", " # myrun.container.set_image_pull_policy('Always')" ] @@ -578,7 +583,7 @@ ], "source": [ "# for pipeline debug\n", - "kfp.compiler.Compiler().compile(dask_pipe, 'daskpipe.yaml', type_check=False)" + "kfp.compiler.Compiler().compile(dask_pipe, \"daskpipe.yaml\", type_check=False)" ] }, { @@ -631,13 +636,15 @@ } ], "source": [ - "arguments={'x':4,'y':-5}\n", - "artifact_path = '/User/test'\n", - "run_id = run_pipeline(dask_pipe, \n", - " arguments, \n", - " artifact_path=artifact_path,\n", - " run=\"DaskExamplePipeline\", \n", - " experiment=\"dask pipe\")" + "arguments = {\"x\": 4, \"y\": -5}\n", + "artifact_path = \"/User/test\"\n", + "run_id = run_pipeline(\n", + " dask_pipe,\n", + " arguments,\n", + " artifact_path=artifact_path,\n", + " run=\"DaskExamplePipeline\",\n", + " experiment=\"dask pipe\",\n", + ")" ] }, { @@ -647,9 +654,10 @@ "outputs": [], "source": [ "from mlrun import wait_for_pipeline_completion, get_run_db\n", + "\n", "wait_for_pipeline_completion(run_id)\n", "db = get_run_db().connect()\n", - "db.list_runs(project='default', labels=f'workflow={run_id}').show()\n" + "db.list_runs(project=\"default\", labels=f\"workflow={run_id}\").show()" ] } ], diff --git a/examples/mlrun_db.ipynb b/examples/mlrun_db.ipynb index b6bb0a365ba0..a11f215473b1 100644 --- a/examples/mlrun_db.ipynb +++ 
b/examples/mlrun_db.ipynb @@ -16,7 +16,7 @@ "outputs": [], "source": [ "# specify the DB path (use 'http://mlrun-api:8080' for api service)\n", - "mlconf.dbpath = mlconf.dbpath or 'http://mlrun-api:8080'\n", + "mlconf.dbpath = mlconf.dbpath or \"http://mlrun-api:8080\"\n", "db = get_run_db().connect()" ] }, @@ -251,7 +251,7 @@ ], "source": [ "# list all runs\n", - "db.list_runs('download').show()" + "db.list_runs(\"download\").show()" ] }, { @@ -619,7 +619,7 @@ ], "source": [ "# list all artifact for version \"latest\"\n", - "db.list_artifacts('', tag='latest', project='iris').show()" + "db.list_artifacts(\"\", tag=\"latest\", project=\"iris\").show()" ] }, { @@ -1182,8 +1182,8 @@ } ], "source": [ - "# check different artifact versions \n", - "db.list_artifacts('ch', tag='*').show()" + "# check different artifact versions\n", + "db.list_artifacts(\"ch\", tag=\"*\").show()" ] }, { @@ -1192,7 +1192,7 @@ "metadata": {}, "outputs": [], "source": [ - "db.del_runs(state='completed')" + "db.del_runs(state=\"completed\")" ] }, { @@ -1201,7 +1201,7 @@ "metadata": {}, "outputs": [], "source": [ - "db.del_artifacts(tag='*')" + "db.del_artifacts(tag=\"*\")" ] }, { diff --git a/examples/mlrun_export_import.ipynb b/examples/mlrun_export_import.ipynb index 8f987c369952..f32636df7d14 100644 --- a/examples/mlrun_export_import.ipynb +++ b/examples/mlrun_export_import.ipynb @@ -24,31 +24,30 @@ "import zipfile\n", "from mlrun import DataItem\n", "\n", - "def open_archive(context, \n", - " target_dir: str,\n", - " archive_url: DataItem = None):\n", + "\n", + "def open_archive(context, target_dir: str, archive_url: DataItem = None):\n", " \"\"\"Open a file/object archive into a target directory\n", - " \n", + "\n", " :param target_dir: target directory\n", " :param archive_url: source archive path/url (MLRun DataItem object)\n", - " \n", + "\n", " :returns: content dir\n", " \"\"\"\n", - " \n", + "\n", " # Define locations\n", " archive_file = archive_url.local()\n", " 
os.makedirs(target_dir, exist_ok=True)\n", - " context.logger.info('Verified directories')\n", - " \n", + " context.logger.info(\"Verified directories\")\n", + "\n", " # Extract dataset from zip\n", - " context.logger.info('Extracting zip')\n", - " zip_ref = zipfile.ZipFile(archive_file, 'r')\n", + " context.logger.info(\"Extracting zip\")\n", + " zip_ref = zipfile.ZipFile(archive_file, \"r\")\n", " zip_ref.extractall(target_dir)\n", " zip_ref.close()\n", - " \n", - " context.logger.info(f'extracted archive to {target_dir}')\n", + "\n", + " context.logger.info(f\"extracted archive to {target_dir}\")\n", " # use target_path= to specify and absolute target path (vs artifact_path)\n", - " context.log_artifact('content', target_path=target_dir)\n" + " context.log_artifact(\"content\", target_path=target_dir)" ] }, { @@ -75,11 +74,16 @@ "source": [ "# create job function object from notebook code and add doc/metadata\n", "import mlrun\n", - "fn = mlrun.code_to_function('file_utils', kind='job',\n", - " handler='open_archive', image='mlrun/mlrun',\n", - " description = \"this function opens a zip archive into a local/mounted folder\",\n", - " categories = ['fileutils'],\n", - " labels = {'author': 'me'})\n" + "\n", + "fn = mlrun.code_to_function(\n", + " \"file_utils\",\n", + " kind=\"job\",\n", + " handler=\"open_archive\",\n", + " image=\"mlrun/mlrun\",\n", + " description=\"this function opens a zip archive into a local/mounted folder\",\n", + " categories=[\"fileutils\"],\n", + " labels={\"author\": \"me\"},\n", + ")" ] }, { @@ -160,7 +164,7 @@ ], "source": [ "# save to a file (and can be pushed to a git)\n", - "fn.export('function.yaml')" + "fn.export(\"function.yaml\")" ] }, { @@ -176,7 +180,7 @@ "metadata": {}, "outputs": [], "source": [ - "mlrun.mlconf.dbpath = mlrun.mlconf.dbpath or 'http://mlrun-api:8080'" + "mlrun.mlconf.dbpath = mlrun.mlconf.dbpath or \"http://mlrun-api:8080\"" ] }, { @@ -201,9 +205,9 @@ ], "source": [ "# load from local file\n", - "xfn = 
mlrun.import_function('./function.yaml')\n", + "xfn = mlrun.import_function(\"./function.yaml\")\n", "\n", - "# load function from MLRun functions hub \n", + "# load function from MLRun functions hub\n", "# xfn = mlrun.import_function('hub://open_archive')\n", "\n", "# get function doc\n", @@ -218,15 +222,18 @@ "source": [ "from os import path\n", "from mlrun.platforms import auto_mount\n", + "\n", "# for auto choice between Iguazio platform and k8s PVC\n", - "# should set the env var for PVC: MLRUN_PVC_MOUNT=:, or use mount_pvc() \n", + "# should set the env var for PVC: MLRUN_PVC_MOUNT=:, or use mount_pvc()\n", "xfn.apply(auto_mount())\n", "\n", "# create and run the task\n", - "images_path = path.abspath('images')\n", - "open_archive_task = mlrun.new_task('download',\n", - " params={'target_dir': images_path},\n", - " inputs={'archive_url': 'http://iguazio-sample-data.s3.amazonaws.com/catsndogs.zip'})" + "images_path = path.abspath(\"images\")\n", + "open_archive_task = mlrun.new_task(\n", + " \"download\",\n", + " params={\"target_dir\": images_path},\n", + " inputs={\"archive_url\": \"http://iguazio-sample-data.s3.amazonaws.com/catsndogs.zip\"},\n", + ")" ] }, { @@ -485,8 +492,9 @@ "outputs": [], "source": [ "from mlrun import mlconf\n", - "mlconf.dbpath = mlconf.dbpath or './'\n", - "artifact_path = mlconf.artifact_path or path.abspath('data')" + "\n", + "mlconf.dbpath = mlconf.dbpath or \"./\"\n", + "artifact_path = mlconf.artifact_path or path.abspath(\"data\")" ] }, { @@ -740,6 +748,7 @@ "outputs": [], "source": [ "from mlrun import function_to_module, get_or_create_ctx\n", + "\n", "mod = function_to_module(xfn)" ] }, @@ -750,9 +759,11 @@ "outputs": [], "source": [ "# create a context object and DataItem objects\n", - "# you can also use existing context and data objects (e.g. 
from parant function) \n", - "context = get_or_create_ctx('myfunc')\n", - "data = mlrun.run.get_dataitem('http://iguazio-sample-data.s3.amazonaws.com/catsndogs.zip')" + "# you can also use existing context and data objects (e.g. from parant function)\n", + "context = get_or_create_ctx(\"myfunc\")\n", + "data = mlrun.run.get_dataitem(\n", + " \"http://iguazio-sample-data.s3.amazonaws.com/catsndogs.zip\"\n", + ")" ] }, { diff --git a/examples/mlrun_jobs.ipynb b/examples/mlrun_jobs.ipynb index 4f473289c61a..32c3ee9653cb 100644 --- a/examples/mlrun_jobs.ipynb +++ b/examples/mlrun_jobs.ipynb @@ -82,8 +82,8 @@ "source": [ "# mlrun: ignore\n", "# do not remove the comment above (it is a directive to nuclio, ignore that cell during build)\n", - "# if the nuclio-jupyter package is not installed run !pip install nuclio-jupyter and restart the kernel \n", - "import nuclio " + "# if the nuclio-jupyter package is not installed run !pip install nuclio-jupyter and restart the kernel\n", + "import nuclio" ] }, { @@ -138,11 +138,8 @@ "import time\n", "import pandas as pd\n", "\n", - "def training(\n", - " context: MLClientCtx,\n", - " p1: int = 1,\n", - " p2: int = 2\n", - ") -> None:\n", + "\n", + "def training(context: MLClientCtx, p1: int = 1, p2: int = 2) -> None:\n", " \"\"\"Train a model.\n", "\n", " :param context: The runtime context object.\n", @@ -150,36 +147,38 @@ " :param p2: Another model parameter.\n", " \"\"\"\n", " # access input metadata, values, and inputs\n", - " print(f'Run: {context.name} (uid={context.uid})')\n", - " print(f'Params: p1={p1}, p2={p2}')\n", - " context.logger.info('started training')\n", - " \n", + " print(f\"Run: {context.name} (uid={context.uid})\")\n", + " print(f\"Params: p1={p1}, p2={p2}\")\n", + " context.logger.info(\"started training\")\n", + "\n", " # \n", - " \n", + "\n", " # log the run results (scalar values)\n", - " context.log_result('accuracy', p1 * 2)\n", - " context.log_result('loss', p1 * 3)\n", - " \n", - " # add a lable/tag 
to this run \n", - " context.set_label('category', 'tests')\n", - " \n", - " # log a simple artifact + label the artifact \n", + " context.log_result(\"accuracy\", p1 * 2)\n", + " context.log_result(\"loss\", p1 * 3)\n", + "\n", + " # add a lable/tag to this run\n", + " context.set_label(\"category\", \"tests\")\n", + "\n", + " # log a simple artifact + label the artifact\n", " # If you want to upload a local file to the artifact repo add src_path=\n", - " context.log_artifact('somefile', \n", - " body=b'abc is 123', \n", - " local_path='myfile.txt')\n", - " \n", - " # create a dataframe artifact \n", - " df = pd.DataFrame([{'A':10, 'B':100}, {'A':11,'B':110}, {'A':12,'B':120}])\n", - " context.log_dataset('mydf', df=df)\n", - " \n", + " context.log_artifact(\"somefile\", body=b\"abc is 123\", local_path=\"myfile.txt\")\n", + "\n", + " # create a dataframe artifact\n", + " df = pd.DataFrame([{\"A\": 10, \"B\": 100}, {\"A\": 11, \"B\": 110}, {\"A\": 12, \"B\": 120}])\n", + " context.log_dataset(\"mydf\", df=df)\n", + "\n", " # Log an ML Model artifact, add metrics, params, and labels to it\n", - " # and place it in a subdir ('models') under artifacts path \n", - " context.log_model('mymodel', body=b'abc is 123', \n", - " model_file='model.txt', \n", - " metrics={'accuracy':0.85}, parameters={'xx':'abc'},\n", - " labels={'framework': 'xgboost'},\n", - " artifact_path=context.artifact_subpath('models'))\n" + " # and place it in a subdir ('models') under artifacts path\n", + " context.log_model(\n", + " \"mymodel\",\n", + " body=b\"abc is 123\",\n", + " model_file=\"model.txt\",\n", + " metrics={\"accuracy\": 0.85},\n", + " parameters={\"xx\": \"abc\"},\n", + " labels={\"framework\": \"xgboost\"},\n", + " artifact_path=context.artifact_subpath(\"models\"),\n", + " )" ] }, { @@ -188,33 +187,28 @@ "metadata": {}, "outputs": [], "source": [ - "def validation(\n", - " context: MLClientCtx,\n", - " model: DataItem\n", - ") -> None:\n", + "def validation(context: 
MLClientCtx, model: DataItem) -> None:\n", " \"\"\"Model validation.\n", - " \n", + "\n", " Dummy validation function.\n", - " \n", + "\n", " :param context: The runtime context object.\n", " :param model: The extimated model object.\n", " \"\"\"\n", " # access input metadata, values, files, and secrets (passwords)\n", - " print(f'Run: {context.name} (uid={context.uid})')\n", - " context.logger.info('started validation')\n", - " \n", + " print(f\"Run: {context.name} (uid={context.uid})\")\n", + " context.logger.info(\"started validation\")\n", + "\n", " # get the model file, class (metadata), and extra_data (dict of key: DataItem)\n", " model_file, model_obj, _ = get_model(model)\n", "\n", " # update model object elements and data\n", - " update_model(model_obj, parameters={'one_more': 5})\n", + " update_model(model_obj, parameters={\"one_more\": 5})\n", "\n", - " print(f'path to local copy of model file - {model_file}')\n", - " print('parameters:', model_obj.parameters)\n", - " print('metrics:', model_obj.metrics)\n", - " context.log_artifact('validation', \n", - " body=b' validated ', \n", - " format='html')" + " print(f\"path to local copy of model file - {model_file}\")\n", + " print(\"parameters:\", model_obj.parameters)\n", + " print(\"metrics:\", model_obj.metrics)\n", + " context.log_artifact(\"validation\", body=b\" validated \", format=\"html\")" ] }, { @@ -263,7 +257,8 @@ "source": [ "from mlrun import run_local, code_to_function, mlconf, new_task\n", "from mlrun.platforms.other import auto_mount\n", - "mlconf.dbpath = mlconf.dbpath or 'http://mlrun-api:8080'" + "\n", + "mlconf.dbpath = mlconf.dbpath or \"http://mlrun-api:8080\"" ] }, { @@ -280,9 +275,10 @@ "outputs": [], "source": [ "from os import path\n", - "out = mlconf.artifact_path or path.abspath('./data')\n", + "\n", + "out = mlconf.artifact_path or path.abspath(\"./data\")\n", "# {{run.uid}} will be substituted with the run id, so output will be written to different directoried per run\n", - 
"artifact_path = path.join(out, '{{run.uid}}')" + "artifact_path = path.join(out, \"{{run.uid}}\")" ] }, { @@ -539,7 +535,7 @@ } ], "source": [ - "train_run = run_local(new_task(handler=training, params={'p1': 5}, artifact_path=out))" + "train_run = run_local(new_task(handler=training, params={\"p1\": 5}, artifact_path=out))" ] }, { @@ -811,9 +807,11 @@ } ], "source": [ - "model = train_run.outputs['mymodel']\n", + "model = train_run.outputs[\"mymodel\"]\n", "\n", - "validation_run = run_local(new_task(handler=validation, inputs={'model': model}, artifact_path=out))" + "validation_run = run_local(\n", + " new_task(handler=validation, inputs={\"model\": model}, artifact_path=out)\n", + ")" ] }, { @@ -842,7 +840,7 @@ "outputs": [], "source": [ "# create an ML function from the notebook, attache it to iguazio data fabric (v3io)\n", - "trainer = code_to_function(name='my-trainer', kind='job')" + "trainer = code_to_function(name=\"my-trainer\", kind=\"job\")" ] }, { @@ -913,7 +911,7 @@ ], "source": [ "# for auto choice between Iguazio platform and k8s PVC\n", - "# should set the env var for PVC: MLRUN_PVC_MOUNT=:, or use mount_pvc() \n", + "# should set the env var for PVC: MLRUN_PVC_MOUNT=:, or use mount_pvc()\n", "trainer.apply(auto_mount())" ] }, @@ -1057,7 +1055,7 @@ "outputs": [], "source": [ "# create the base task (common to both steps), and set the output path and experiment label\n", - "base_task = new_task(artifact_path=out).set_label('stage', 'dev')" + "base_task = new_task(artifact_path=out).set_label(\"stage\", \"dev\")" ] }, { @@ -1296,7 +1294,9 @@ ], "source": [ "# run our training task, with hyper params, and select the one with max accuracy\n", - "train_task = new_task(name='my-training', handler='training', params={'p1': 9}, base=base_task)\n", + "train_task = new_task(\n", + " name=\"my-training\", handler=\"training\", params={\"p1\": 9}, base=base_task\n", + ")\n", "train_run = trainer.run(train_task)" ] }, @@ -1545,9 +1545,9 @@ } ], "source": [ - 
"# running validation, use the model result from the previous step \n", - "model = train_run.outputs['mymodel']\n", - "trainer.run(base_task, handler='validation', inputs={'model': model}, watch=True)" + "# running validation, use the model result from the previous step\n", + "model = train_run.outputs[\"mymodel\"]\n", + "trainer.run(base_task, handler=\"validation\", inputs={\"model\": model}, watch=True)" ] }, { @@ -1586,26 +1586,20 @@ "metadata": {}, "outputs": [], "source": [ - "@dsl.pipeline(\n", - " name = 'job test',\n", - " description = 'demonstrating mlrun usage'\n", - ")\n", - "def job_pipeline(\n", - " p1: int = 9\n", - ") -> None:\n", + "@dsl.pipeline(name=\"job test\", description=\"demonstrating mlrun usage\")\n", + "def job_pipeline(p1: int = 9) -> None:\n", " \"\"\"Define our pipeline.\n", - " \n", + "\n", " :param p1: A model parameter.\n", " \"\"\"\n", "\n", - " train = trainer.as_step(handler='training',\n", - " params={'p1': p1},\n", - " outputs=['mymodel'])\n", - " \n", - " validate = trainer.as_step(handler='validation',\n", - " inputs={'model': train.outputs['mymodel']},\n", - " outputs=['validation'])\n", - " " + " train = trainer.as_step(handler=\"training\", params={\"p1\": p1}, outputs=[\"mymodel\"])\n", + "\n", + " validate = trainer.as_step(\n", + " handler=\"validation\",\n", + " inputs={\"model\": train.outputs[\"mymodel\"]},\n", + " outputs=[\"validation\"],\n", + " )" ] }, { @@ -1621,7 +1615,7 @@ "metadata": {}, "outputs": [], "source": [ - "kfp.compiler.Compiler().compile(job_pipeline, 'jobpipe.yaml')" + "kfp.compiler.Compiler().compile(job_pipeline, \"jobpipe.yaml\")" ] }, { @@ -1651,7 +1645,7 @@ "metadata": {}, "outputs": [], "source": [ - "artifact_path = 'v3io:///users/admin/kfp/{{workflow.uid}}/'" + "artifact_path = \"v3io:///users/admin/kfp/{{workflow.uid}}/\"" ] }, { @@ -1692,8 +1686,10 @@ } ], "source": [ - "arguments = {'p1': 8}\n", - "run_id = run_pipeline(job_pipeline, arguments, experiment='my-job', 
artifact_path=artifact_path)" + "arguments = {\"p1\": 8}\n", + "run_id = run_pipeline(\n", + " job_pipeline, arguments, experiment=\"my-job\", artifact_path=artifact_path\n", + ")" ] }, { @@ -1925,9 +1921,10 @@ ], "source": [ "from mlrun import wait_for_pipeline_completion, get_run_db\n", + "\n", "wait_for_pipeline_completion(run_id)\n", "db = get_run_db().connect()\n", - "db.list_runs(project='default', labels=f'workflow={run_id}').show()" + "db.list_runs(project=\"default\", labels=f\"workflow={run_id}\").show()" ] }, { diff --git a/examples/mlrun_sparkk8s.ipynb b/examples/mlrun_sparkk8s.ipynb index 095700541325..12de5596256f 100644 --- a/examples/mlrun_sparkk8s.ipynb +++ b/examples/mlrun_sparkk8s.ipynb @@ -31,12 +31,12 @@ "from os.path import isfile, join\n", "from mlrun import new_function, new_task, mlconf\n", "\n", - "#Set the mlrun database/api\n", - "mlconf.dbpath = 'http://mlrun-api:8080'\n", + "# Set the mlrun database/api\n", + "mlconf.dbpath = \"http://mlrun-api:8080\"\n", "\n", - "#Set the pyspark script path\n", - "V3IO_WORKING_DIR = os.getcwd().replace('/User','/v3io/'+os.getenv('V3IO_HOME'))\n", - "V3IO_SCRIPT_PATH = V3IO_WORKING_DIR+'/spark-function.py'" + "# Set the pyspark script path\n", + "V3IO_WORKING_DIR = os.getcwd().replace(\"/User\", \"/v3io/\" + os.getenv(\"V3IO_HOME\"))\n", + "V3IO_SCRIPT_PATH = V3IO_WORKING_DIR + \"/spark-function.py\"" ] }, { @@ -52,36 +52,39 @@ "metadata": {}, "outputs": [], "source": [ - "#Define a dict of input data sources\n", - "DATA_SOURCES = {'family' :\n", - " {'format': 'jdbc',\n", - " 'url': 'jdbc:mysql://mysql-rfam-public.ebi.ac.uk:4497/Rfam',\n", - " 'dbtable': 'Rfam.family',\n", - " 'user': 'rfamro',\n", - " 'password': '',\n", - " 'driver': 'com.mysql.jdbc.Driver'},\n", - " 'full_region':\n", - " {'format': 'jdbc',\n", - " 'url': 'jdbc:mysql://mysql-rfam-public.ebi.ac.uk:4497/Rfam',\n", - " 'dbtable': 'Rfam.full_region',\n", - " 'user': 'rfamro',\n", - " 'password': '',\n", - " 'driver': 
'com.mysql.jdbc.Driver'}\n", - " }\n", - "\n", - "#Define a query to execute on the input data sources\n", - "QUERY = 'SELECT family.*, full_region.evalue_score from family INNER JOIN full_region ON family.rfam_acc = full_region.rfam_acc LIMIT 10'\n", - "\n", - "#Define the output destination\n", - "WRITE_OPTIONS = {'format': 'io.iguaz.v3io.spark.sql.kv',\n", - " 'mode': 'overwrite',\n", - " 'key': 'rfam_id',\n", - " 'path': 'v3io://users/admin/frommysql'}\n", - "\n", - "#Create a task execution with parameters\n", - "PARAMS = {'data_sources': DATA_SOURCES,\n", - " 'query': QUERY,\n", - " 'write_options': WRITE_OPTIONS}\n", + "# Define a dict of input data sources\n", + "DATA_SOURCES = {\n", + " \"family\": {\n", + " \"format\": \"jdbc\",\n", + " \"url\": \"jdbc:mysql://mysql-rfam-public.ebi.ac.uk:4497/Rfam\",\n", + " \"dbtable\": \"Rfam.family\",\n", + " \"user\": \"rfamro\",\n", + " \"password\": \"\",\n", + " \"driver\": \"com.mysql.jdbc.Driver\",\n", + " },\n", + " \"full_region\": {\n", + " \"format\": \"jdbc\",\n", + " \"url\": \"jdbc:mysql://mysql-rfam-public.ebi.ac.uk:4497/Rfam\",\n", + " \"dbtable\": \"Rfam.full_region\",\n", + " \"user\": \"rfamro\",\n", + " \"password\": \"\",\n", + " \"driver\": \"com.mysql.jdbc.Driver\",\n", + " },\n", + "}\n", + "\n", + "# Define a query to execute on the input data sources\n", + "QUERY = \"SELECT family.*, full_region.evalue_score from family INNER JOIN full_region ON family.rfam_acc = full_region.rfam_acc LIMIT 10\"\n", + "\n", + "# Define the output destination\n", + "WRITE_OPTIONS = {\n", + " \"format\": \"io.iguaz.v3io.spark.sql.kv\",\n", + " \"mode\": \"overwrite\",\n", + " \"key\": \"rfam_id\",\n", + " \"path\": \"v3io://users/admin/frommysql\",\n", + "}\n", + "\n", + "# Create a task execution with parameters\n", + "PARAMS = {\"data_sources\": DATA_SOURCES, \"query\": QUERY, \"write_options\": WRITE_OPTIONS}\n", "\n", "SPARK_TASK = new_task(params=PARAMS)" ] @@ -115,12 +118,15 @@ "metadata": {}, "outputs": [], 
"source": [ - "#Get the list of the dpendency jars\n", - "V3IO_JARS_PATH = '/igz/java/libs/'\n", - "DEPS_JARS_LIST = [join(V3IO_JARS_PATH, f) for f in os.listdir(V3IO_JARS_PATH) \n", - " if isfile(join(V3IO_JARS_PATH, f)) and f.startswith('v3io-') and f.endswith('.jar')]\n", + "# Get the list of the dpendency jars\n", + "V3IO_JARS_PATH = \"/igz/java/libs/\"\n", + "DEPS_JARS_LIST = [\n", + " join(V3IO_JARS_PATH, f)\n", + " for f in os.listdir(V3IO_JARS_PATH)\n", + " if isfile(join(V3IO_JARS_PATH, f)) and f.startswith(\"v3io-\") and f.endswith(\".jar\")\n", + "]\n", "\n", - "DEPS_JARS_LIST.append(V3IO_WORKING_DIR + '/mysql-connector-java-8.0.19.jar')" + "DEPS_JARS_LIST.append(V3IO_WORKING_DIR + \"/mysql-connector-java-8.0.19.jar\")" ] }, { @@ -129,11 +135,14 @@ "metadata": {}, "outputs": [], "source": [ - "#Create MLRun function which runs locally in a passthrough mode (since we use spark-submit)\n", - "local_spark_fn = new_function(kind='local', mode = 'pass',\n", - " command= f\"spark-submit --jars {','.join(DEPS_JARS_LIST)} {V3IO_SCRIPT_PATH}\")\n", - "\n", - "#Run the function with a task\n", + "# Create MLRun function which runs locally in a passthrough mode (since we use spark-submit)\n", + "local_spark_fn = new_function(\n", + " kind=\"local\",\n", + " mode=\"pass\",\n", + " command=f\"spark-submit --jars {','.join(DEPS_JARS_LIST)} {V3IO_SCRIPT_PATH}\",\n", + ")\n", + "\n", + "# Run the function with a task\n", "local_spark_fn.run(SPARK_TASK)" ] }, @@ -150,13 +159,19 @@ "metadata": {}, "outputs": [], "source": [ - "#Create MLRun function to run the spark-job on the kubernetes cluster\n", - "serverless_spark_fn = new_function(kind='spark', command=V3IO_SCRIPT_PATH, name='my-spark-func')\n", + "# Create MLRun function to run the spark-job on the kubernetes cluster\n", + "serverless_spark_fn = new_function(\n", + " kind=\"spark\", command=V3IO_SCRIPT_PATH, name=\"my-spark-func\"\n", + ")\n", "\n", "serverless_spark_fn.with_driver_limits(cpu=\"1300m\")\n", - 
"serverless_spark_fn.with_driver_requests(cpu=1, mem=\"4G\") # gpu_type & gpus= are supported too\n", + "serverless_spark_fn.with_driver_requests(\n", + " cpu=1, mem=\"4G\"\n", + ") # gpu_type & gpus= are supported too\n", "serverless_spark_fn.with_executor_limits(cpu=\"1400m\")\n", - "serverless_spark_fn.with_executor_requests(cpu=1, mem=\"4G\") # gpu_type & gpus= are supported too\n", + "serverless_spark_fn.with_executor_requests(\n", + " cpu=1, mem=\"4G\"\n", + ") # gpu_type & gpus= are supported too\n", "\n", "serverless_spark_fn.with_igz_spark()\n", "\n", @@ -166,10 +181,10 @@ " \"-O /spark/jars/mysql-connector-java-8.0.19.jar\"\n", "]\n", "\n", - "#Set number of executors\n", + "# Set number of executors\n", "serverless_spark_fn.spec.replicas = 2\n", "\n", - "#Deploy function and install MLRun in the spark image\n", + "# Deploy function and install MLRun in the spark image\n", "serverless_spark_fn.deploy()\n", "\n", "run = serverless_spark_fn.run(SPARK_TASK, watch=False)" diff --git a/examples/mlrun_vault.ipynb b/examples/mlrun_vault.ipynb index 9bdad88e8b8d..4142731568e9 100644 --- a/examples/mlrun_vault.ipynb +++ b/examples/mlrun_vault.ipynb @@ -77,9 +77,7 @@ }, "outputs": [], "source": [ - "func = mlrun.code_to_function(name='vault-func', \n", - " kind='job',\n", - " image='mlrun/mlrun')" + "func = mlrun.code_to_function(name=\"vault-func\", kind=\"job\", image=\"mlrun/mlrun\")" ] }, { @@ -113,11 +111,11 @@ "metadata": {}, "outputs": [], "source": [ - "proj_name = 'vault-mlrun'\n", + "proj_name = \"vault-mlrun\"\n", "\n", "proj = mlrun.new_project(proj_name)\n", "\n", - "project_secrets = {'aws_key': '1234567890', 'github_key': 'proj1Key!!!'}\n", + "project_secrets = {\"aws_key\": \"1234567890\", \"github_key\": \"proj1Key!!!\"}\n", "proj.create_vault_secrets(project_secrets)\n", "\n", "proj.get_vault_secrets()" @@ -143,13 +141,15 @@ "metadata": {}, "outputs": [], "source": [ - "task = mlrun.new_task(project=proj_name,\n", - " name='vault_test_run',\n", - 
" handler='vault_func',\n", - " params={'secrets':['github_key', 'aws_key']})\n", + "task = mlrun.new_task(\n", + " project=proj_name,\n", + " name=\"vault_test_run\",\n", + " handler=\"vault_func\",\n", + " params={\"secrets\": [\"github_key\", \"aws_key\"]},\n", + ")\n", "\n", "# Add access to project-level secrets\n", - "task.with_secrets('vault', [\"aws_key\"])" + "task.with_secrets(\"vault\", [\"aws_key\"])" ] }, { @@ -183,7 +183,7 @@ "outputs": [], "source": [ "# Access to all project-level secrets can be obtained by passing an empty list of secret names\n", - "task.with_secrets('vault', [])\n", + "task.with_secrets(\"vault\", [])\n", "\n", "result = func.run(task)" ] @@ -207,9 +207,11 @@ "metadata": {}, "outputs": [], "source": [ - "proj_name_2 = 'vault-mlrun-2'\n", + "proj_name_2 = \"vault-mlrun-2\"\n", "proj2 = mlrun.new_project(proj_name_2)\n", - "proj2.create_vault_secrets({'aws_key': '0987654321', 'github_key': 'proj2Key???', 'password': 'myPassword'})" + "proj2.create_vault_secrets(\n", + " {\"aws_key\": \"0987654321\", \"github_key\": \"proj2Key???\", \"password\": \"myPassword\"}\n", + ")" ] }, { @@ -218,11 +220,13 @@ "metadata": {}, "outputs": [], "source": [ - "task2 = mlrun.new_task(project=proj_name_2,\n", - " name='vault_test_run_2',\n", - " handler='vault_func',\n", - " params={'secrets':['password', 'github_key', 'aws_key']})\n", - "task2.with_secrets('vault', [\"aws_key\", \"github_key\", \"password\"])\n", + "task2 = mlrun.new_task(\n", + " project=proj_name_2,\n", + " name=\"vault_test_run_2\",\n", + " handler=\"vault_func\",\n", + " params={\"secrets\": [\"password\", \"github_key\", \"aws_key\"]},\n", + ")\n", + "task2.with_secrets(\"vault\", [\"aws_key\", \"github_key\", \"password\"])\n", "\n", "result = func.run(task2)" ] @@ -244,8 +248,8 @@ }, "outputs": [], "source": [ - "proj.with_secrets('vault',['github_key'])\n", - "proj.get_secret('github_key')" + "proj.with_secrets(\"vault\", [\"github_key\"])\n", + 
"proj.get_secret(\"github_key\")" ] }, { diff --git a/examples/new-project.ipynb b/examples/new-project.ipynb index 0b9bf26be087..61a8127ed780 100644 --- a/examples/new-project.ipynb +++ b/examples/new-project.ipynb @@ -43,11 +43,11 @@ "metadata": {}, "outputs": [], "source": [ - "# update the dir and repo to reflect real locations \n", + "# update the dir and repo to reflect real locations\n", "# the remote git repo must be initialized in GitHub\n", - "project_dir = '/User/new-proj'\n", - "remote_git = 'https://github.com//.git'\n", - "newproj = new_project('new-project', project_dir, init_git=True)" + "project_dir = \"/User/new-proj\"\n", + "remote_git = \"https://github.com//.git\"\n", + "newproj = new_project(\"new-project\", project_dir, init_git=True)" ] }, { @@ -129,7 +129,7 @@ } ], "source": [ - "newproj.set_function('hub://load_dataset', 'ingest').doc()" + "newproj.set_function(\"hub://load_dataset\", \"ingest\").doc()" ] }, { @@ -192,10 +192,10 @@ ], "source": [ "# add function with build config (base image, run command)\n", - "fn = code_to_function('tstfunc', filename='handler.py', kind='job')\n", - "fn.build_config(base_image = 'mlrun/mlrun', commands=['pip install pandas'])\n", + "fn = code_to_function(\"tstfunc\", filename=\"handler.py\", kind=\"job\")\n", + "fn.build_config(base_image=\"mlrun/mlrun\", commands=[\"pip install pandas\"])\n", "newproj.set_function(fn)\n", - "print(newproj.func('tstfunc').to_yaml())" + "print(newproj.func(\"tstfunc\").to_yaml())" ] }, { @@ -250,7 +250,7 @@ "metadata": {}, "outputs": [], "source": [ - "newproj.set_workflow('main', 'workflow.py')" + "newproj.set_workflow(\"main\", \"workflow.py\")" ] }, { @@ -311,7 +311,7 @@ "metadata": {}, "outputs": [], "source": [ - "newproj.push('master', 'first push', add=['handler.py', 'workflow.py'])" + "newproj.push(\"master\", \"first push\", add=[\"handler.py\", \"workflow.py\"])" ] }, { @@ -394,7 +394,11 @@ } ], "source": [ - "newproj.run('main', arguments={}, 
artifact_path='v3io:///users/admin/mlrun/kfp/{{workflow.uid}}/')" + "newproj.run(\n", + " \"main\",\n", + " arguments={},\n", + " artifact_path=\"v3io:///users/admin/mlrun/kfp/{{workflow.uid}}/\",\n", + ")" ] }, { diff --git a/examples/remote-spark.ipynb b/examples/remote-spark.ipynb index c017f265cc45..b669c9ab6789 100644 --- a/examples/remote-spark.ipynb +++ b/examples/remote-spark.ipynb @@ -28,32 +28,34 @@ "\n", "from pyspark.sql import SparkSession\n", "\n", - "def describe_spark(context: MLClientCtx, \n", - " dataset: DataItem, \n", - " artifact_path):\n", + "\n", + "def describe_spark(context: MLClientCtx, dataset: DataItem, artifact_path):\n", "\n", " # get file location\n", " location = dataset.local()\n", - " \n", + "\n", " # build spark session\n", " spark = SparkSession.builder.appName(\"Spark job\").getOrCreate()\n", - " \n", + "\n", " # read csv\n", - " df = spark.read.csv(location, header=True, inferSchema= True)\n", - " \n", + " df = spark.read.csv(location, header=True, inferSchema=True)\n", + "\n", " # show\n", " df.show(5)\n", - " \n", + "\n", " # sample for logging\n", " df_to_log = df.sample(False, 0.1).toPandas()\n", - " \n", + "\n", " # log final report\n", - " context.log_dataset(\"df_sample\", \n", - " df=df_to_log,\n", - " format=\"csv\", index=False,\n", - " artifact_path=context.artifact_subpath('data'))\n", - " \n", - " spark.stop()\n" + " context.log_dataset(\n", + " \"df_sample\",\n", + " df=df_to_log,\n", + " format=\"csv\",\n", + " index=False,\n", + " artifact_path=context.artifact_subpath(\"data\"),\n", + " )\n", + "\n", + " spark.stop()" ] }, { diff --git a/examples/v2_model_server.ipynb b/examples/v2_model_server.ipynb index 9fd040eda3be..1ccf8f589fba 100644 --- a/examples/v2_model_server.ipynb +++ b/examples/v2_model_server.ipynb @@ -87,12 +87,12 @@ "class ClassifierModel(mlrun.serving.V2ModelServer):\n", " def load(self):\n", " \"\"\"load and initialize the model and/or other elements\"\"\"\n", - " model_file, extra_data = 
self.get_model('.pkl')\n", - " self.model = load(open(model_file, 'rb'))\n", + " model_file, extra_data = self.get_model(\".pkl\")\n", + " self.model = load(open(model_file, \"rb\"))\n", "\n", " def predict(self, body: dict) -> List:\n", " \"\"\"Generate model predictions from sample.\"\"\"\n", - " feats = np.asarray(body['inputs'])\n", + " feats = np.asarray(body[\"inputs\"])\n", " result: np.ndarray = self.model.predict(feats)\n", " return result.tolist()" ] @@ -126,7 +126,7 @@ "metadata": {}, "outputs": [], "source": [ - "models_path = 'https://s3.wasabisys.com/iguazio/models/iris/model.pkl'" + "models_path = \"https://s3.wasabisys.com/iguazio/models/iris/model.pkl\"" ] }, { @@ -160,12 +160,15 @@ } ], "source": [ - "fn = mlrun.code_to_function('v2-model-server', description=\"generic sklearn model server\",\n", - " categories=['serving', 'ml'],\n", - " labels={'author': 'yaronh', 'framework': 'sklearn'},\n", - " code_output='.')\n", - "fn.spec.default_class = 'ClassifierModel'\n", - "#print(fn.to_yaml())\n", + "fn = mlrun.code_to_function(\n", + " \"v2-model-server\",\n", + " description=\"generic sklearn model server\",\n", + " categories=[\"serving\", \"ml\"],\n", + " labels={\"author\": \"yaronh\", \"framework\": \"sklearn\"},\n", + " code_output=\".\",\n", + ")\n", + "fn.spec.default_class = \"ClassifierModel\"\n", + "# print(fn.to_yaml())\n", "fn.export()" ] }, @@ -182,8 +185,8 @@ "metadata": {}, "outputs": [], "source": [ - "fn.add_model('mymodel', model_path=models_path)\n", - "#fn.verbose = True" + "fn.add_model(\"mymodel\", model_path=models_path)\n", + "# fn.verbose = True" ] }, { @@ -217,8 +220,9 @@ "outputs": [], "source": [ "from sklearn.datasets import load_iris\n", + "\n", "iris = load_iris()\n", - "x = iris['data'].tolist()" + "x = iris[\"data\"].tolist()" ] }, { @@ -285,7 +289,7 @@ ], "source": [ "fn.apply(mlrun.mount_v3io())\n", - "fn.deploy(project='v2-srv')" + "fn.deploy(project=\"v2-srv\")" ] }, { @@ -314,8 +318,8 @@ } ], "source": [ - 
"my_data = '''{\"inputs\":[[5.1, 3.5, 1.4, 0.2],[7.7, 3.8, 6.7, 2.2]]}'''\n", - "fn.invoke('/v2/models/mymodel/infer', my_data)" + "my_data = \"\"\"{\"inputs\":[[5.1, 3.5, 1.4, 0.2],[7.7, 3.8, 6.7, 2.2]]}\"\"\"\n", + "fn.invoke(\"/v2/models/mymodel/infer\", my_data)" ] }, { diff --git a/examples/xgb_serving.ipynb b/examples/xgb_serving.ipynb index 8316d6ff4d98..422269c795b1 100644 --- a/examples/xgb_serving.ipynb +++ b/examples/xgb_serving.ipynb @@ -35,7 +35,7 @@ "source": [ "# mlrun: ignore\n", "# if the nuclio-jupyter package is not installed run !pip install nuclio-jupyter\n", - "import nuclio " + "import nuclio" ] }, { @@ -128,18 +128,18 @@ " # this is called once to load the model\n", " # get_model returns file path (copied to local) and extra data dict (of key: DataItem)\n", " # model object can be accessed at self.model_spec (after running .get_model)\n", - " model_file, _ = self.get_model('.bst')\n", + " model_file, _ = self.get_model(\".bst\")\n", " self._booster = xgb.Booster(model_file=model_file)\n", "\n", " def predict(self, body):\n", " try:\n", " # Use of list as input is deprecated see https://github.com/dmlc/xgboost/pull/3970\n", - " events = np.array(body['instances'])\n", + " events = np.array(body[\"instances\"])\n", " dmatrix = xgb.DMatrix(events)\n", " result: xgb.DMatrix = self._booster.predict(dmatrix)\n", " return result.tolist()\n", " except Exception as exc:\n", - " raise Exception(f\"Failed to predict {exc}\")\n" + " raise Exception(f\"Failed to predict {exc}\")" ] }, { @@ -183,8 +183,8 @@ "outputs": [], "source": [ "# a valist model.bst file MUST EXIST in the model dir\n", - "#model_dir = os.path.abspath('./')\n", - "model_dir = '/User/mlrun/kfp/032e6d59-6bfe-4ee7-bcf6-1fb26e5db550/1' #/model.bst'" + "# model_dir = os.path.abspath('./')\n", + "model_dir = \"/User/mlrun/kfp/032e6d59-6bfe-4ee7-bcf6-1fb26e5db550/1\" # /model.bst'" ] }, { @@ -193,7 +193,7 @@ "metadata": {}, "outputs": [], "source": [ - "my_server = XGBoostModel('my-model', 
model_dir=model_dir)\n", + "my_server = XGBoostModel(\"my-model\", model_dir=model_dir)\n", "my_server.load()" ] }, @@ -291,12 +291,12 @@ } ], "source": [ - "fn = new_model_server('iris-srv', \n", - " models={'iris_v1': model_dir}, \n", - " model_class='XGBoostModel')\n", + "fn = new_model_server(\n", + " \"iris-srv\", models={\"iris_v1\": model_dir}, model_class=\"XGBoostModel\"\n", + ")\n", "\n", "# use mount_v3io() for iguazio volumes or mount_pvc() for k8s PVC volumes\n", - "fn.apply(mount_v3io()) " + "fn.apply(mount_v3io())" ] }, { @@ -339,7 +339,7 @@ "outputs": [], "source": [ "# KFServing protocol event\n", - "event_data = {\"instances\":[[5], [10]]}" + "event_data = {\"instances\": [[5], [10]]}" ] }, { @@ -349,7 +349,8 @@ "outputs": [], "source": [ "import json\n", - "resp = requests.put(addr + '/iris_v1/predict', json=json.dumps(event_data))\n", + "\n", + "resp = requests.put(addr + \"/iris_v1/predict\", json=json.dumps(event_data))\n", "print(resp.text)" ] }, diff --git a/extras-requirements.txt b/extras-requirements.txt index a9f6d5e13b0f..08731198c752 100644 --- a/extras-requirements.txt +++ b/extras-requirements.txt @@ -8,22 +8,21 @@ # in setup.py so that we'll be able to copy and install this in the layer with all other requirements making the last # layer (which is most commonly being re-built) as thin as possible # we have a test test_extras_requirement_file_aligned to verify this file is aligned to setup.py -boto3~=1.9, <1.17.107 -botocore>=1.20.106,<1.20.107 -aiobotocore~=1.4.0 -s3fs~=2021.8.1 +boto3~=1.24.59 +aiobotocore~=2.4.2 +s3fs~=2023.1.0 # https://github.com/Azure/azure-sdk-for-python/issues/24765#issuecomment-1150310498 msrest~=0.6.21 azure-core~=1.24 azure-storage-blob~=12.13 -adlfs~=2021.8.1 +adlfs~=2022.2.0 azure-identity~=1.5 azure-keyvault-secrets~=4.2 # cryptography>=39, which is required by azure, needs this, or else we get # AttributeError: module 'lib' has no attribute 'OpenSSL_add_all_algorithms' (ML-3471) pyopenssl>=23 
bokeh~=2.4, >=2.4.2 -gcsfs~=2021.8.1 +gcsfs~=2023.1.0 # plotly artifact body in 5.12.0 may contain chars that are not encodable in 'latin-1' encoding # so, it cannot be logged as artifact (raised UnicodeEncode error - ML-3255) plotly~=5.4, <5.12.0 @@ -31,5 +30,6 @@ plotly~=5.4, <5.12.0 # required by frames (because it upgrades protobuf from 3.x to 4.x, breaking binary compatibility) google-cloud-bigquery[pandas, bqstorage]~=3.2 kafka-python~=2.0 +avro~=1.11 redis~=4.3 graphviz~=0.20.0 diff --git a/go/Makefile b/go/Makefile index 2568e58ca73e..f805f7783744 100644 --- a/go/Makefile +++ b/go/Makefile @@ -50,6 +50,10 @@ push-log-collector: @echo Pushing log-collector image docker push $(MLRUN_DOCKER_IMAGE_PREFIX)/log-collector:$(MLRUN_DOCKER_TAG) +.PHONY: pull-log-collector +pull-log-collector: + docker pull $(MLRUN_DOCKER_IMAGE_PREFIX)/log-collector:$(MLRUN_DOCKER_TAG) + .PHONY: schemas-compiler schemas-compiler: schemas-compiler @echo Building schemas-compiler image @@ -73,7 +77,7 @@ compile-schemas-local: cleanup compile-schemas-go compile-schemas-python compile-schemas-dockerized: schemas-compiler @echo Compiling schemas in docker container docker run \ - -v $(shell dirname $(PWD)):/app \ + -v $(shell dirname $(CURDIR)):/app \ $(MLRUN_DOCKER_IMAGE_PREFIX)/schemas-compiler:latest \ make compile-schemas-local diff --git a/go/cmd/logcollector/docker/Dockerfile b/go/cmd/logcollector/docker/Dockerfile index a09d205430a8..e94d67dca940 100644 --- a/go/cmd/logcollector/docker/Dockerfile +++ b/go/cmd/logcollector/docker/Dockerfile @@ -35,7 +35,7 @@ RUN GOOS=linux \ FROM alpine:latest as install-health-probe -ARG GRPC_HEALTH_PROBE_VERSION=v0.4.14 +ARG GRPC_HEALTH_PROBE_VERSION=v0.4.19 RUN mkdir /app WORKDIR /app @@ -44,7 +44,7 @@ RUN wget -qO/app/grpc_health_probe \ https://github.com/grpc-ecosystem/grpc-health-probe/releases/download/${GRPC_HEALTH_PROBE_VERSION}/grpc_health_probe-linux-amd64 && \ chmod +x /app/grpc_health_probe -FROM gcr.io/iguazio/alpine:3.17 +FROM 
gcr.io/iguazio/alpine:3.18 COPY --from=build-binary /app/main /main COPY --from=install-health-probe /app/grpc_health_probe /grpc_health_probe diff --git a/go/cmd/schemas_compiler/docker/Dockerfile b/go/cmd/schemas_compiler/docker/Dockerfile index 88b3a4c528ad..fc18b78845df 100644 --- a/go/cmd/schemas_compiler/docker/Dockerfile +++ b/go/cmd/schemas_compiler/docker/Dockerfile @@ -12,21 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +ARG PYTHON_VERSION=3.9 ARG GO_VERSION=1.19 -FROM golang:${GO_VERSION} +FROM golang:${GO_VERSION}-alpine AS golang + +FROM python:${PYTHON_VERSION}-alpine ARG PROTOC_GEN_GO_VERSION=v1.28 ARG PROTOC_GEN_GO_GRPC_VERSION=v1.2 -ARG GRPCIO_TOOLS_VERSION="~=1.41.0" +ARG GRPCIO_TOOLS_VERSION="~=1.54.2" WORKDIR /app/go -RUN apt-get update && apt install -y \ - protobuf-compiler \ - python3 \ - python3-setuptools \ - python3-pip +RUN apk add --no-cache protoc build-base linux-headers + +COPY --from=golang /usr/local/go/ /usr/local/go/ + +# add copied golang binary to path, add go bin to path (where we install go binaries) +ENV PATH="/usr/local/go/bin:/root/go/bin:${PATH}" RUN go install google.golang.org/protobuf/cmd/protoc-gen-go@${PROTOC_GEN_GO_VERSION} && \ go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@${PROTOC_GEN_GO_GRPC_VERSION} diff --git a/go/go.mod b/go/go.mod index ebca7a271b04..7658af2c90b6 100644 --- a/go/go.mod +++ b/go/go.mod @@ -3,7 +3,6 @@ module github.com/mlrun/mlrun go 1.19 require ( - github.com/golang/protobuf v1.5.2 github.com/google/uuid v1.3.0 github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 github.com/nuclio/errors v0.0.4 @@ -11,8 +10,9 @@ require ( github.com/nuclio/loggerus v0.0.6 github.com/sirupsen/logrus v1.8.0 github.com/stretchr/testify v1.8.1 - golang.org/x/sync v0.1.0 - google.golang.org/grpc v1.51.0 + golang.org/x/sync v0.2.0 + google.golang.org/grpc v1.55.0 + google.golang.org/protobuf v1.30.0 k8s.io/api v0.23.15 k8s.io/apimachinery 
v0.23.15 k8s.io/client-go v0.23.15 @@ -23,7 +23,8 @@ require ( github.com/evanphx/json-patch v4.12.0+incompatible // indirect github.com/go-logr/logr v1.2.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect - github.com/google/go-cmp v0.5.6 // indirect + github.com/golang/protobuf v1.5.3 // indirect + github.com/google/go-cmp v0.5.9 // indirect github.com/google/gofuzz v1.1.0 // indirect github.com/googleapis/gnostic v0.5.5 // indirect github.com/imdario/mergo v0.3.5 // indirect @@ -35,15 +36,14 @@ require ( github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - golang.org/x/net v0.7.0 // indirect - golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f // indirect - golang.org/x/sys v0.5.0 // indirect - golang.org/x/term v0.5.0 // indirect - golang.org/x/text v0.7.0 // indirect - golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac // indirect + golang.org/x/net v0.8.0 // indirect + golang.org/x/oauth2 v0.6.0 // indirect + golang.org/x/sys v0.6.0 // indirect + golang.org/x/term v0.6.0 // indirect + golang.org/x/text v0.8.0 // indirect + golang.org/x/time v0.3.0 // indirect google.golang.org/appengine v1.6.7 // indirect - google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1 // indirect - google.golang.org/protobuf v1.28.1 // indirect + google.golang.org/genproto v0.0.0-20230306155012-7f2fa6fef1f4 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go/go.sum b/go/go.sum index d6b28c4cdba6..a2601027c8c1 100644 --- a/go/go.sum +++ b/go/go.sum @@ -124,8 +124,9 @@ github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.1/go.mod 
h1:DopwsBzvsk0Fs44TXzsVbJyPhcCPeIwnvohx4u74HPM= -github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= +github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= @@ -140,8 +141,8 @@ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= -github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.1.0 h1:Hsa8mG0dQ46ij8Sl2AYJDUv1oA9/d6Vk+3LG99Oe02g= github.com/google/gofuzz v1.1.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -354,8 +355,8 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210316092652-d523dce5a7f4/go.mod h1:RBQZq4jEuRlivfhVLdyRGr576XBO4/greRjx4P4O3yc= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net 
v0.0.0-20211209124913-491a49abca63/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= -golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= +golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -367,8 +368,9 @@ golang.org/x/oauth2 v0.0.0-20201208152858-08078c50e5b5/go.mod h1:KelEdhl1UZF7XfJ golang.org/x/oauth2 v0.0.0-20210218202405-ba52d332ba99/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= -golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f h1:Qmd2pbz05z7z6lm0DrgQVVPuBm92jqujBKMHMOlOQEw= golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A= +golang.org/x/oauth2 v0.6.0 h1:Lh8GPgSKBfWSwFvtuWOfeI3aAAnbXTSutYxJiOJFgIw= +golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -380,8 +382,8 @@ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync 
v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o= -golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.2.0 h1:PUR+T4wwASmuSTYdKjYHI5TD22Wy5ogLU5qZCOLxBrI= +golang.org/x/sync v0.2.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -429,12 +431,12 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210831042530-f4d43177bf5e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU= -golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.5.0 h1:n2a8QNdAb0sZNpU9R1ALUXBbY+w51fCQDN+7EdxNBsY= -golang.org/x/term v0.5.0/go.mod 
h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.6.0 h1:clScbb1cHjoCkyRbWwBEUZ5H/tIFu5TAXIqaZD0Gcjw= +golang.org/x/term v0.6.0/go.mod h1:m6U89DPEgQRMq3DNkDClhWw02AUbt2daBVO4cn4Hv9U= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -444,13 +446,14 @@ golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= -golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.8.0 h1:57P1ETyNKtuIjB4SRd15iJxuhj8Gc416Y78H3qgMh68= +golang.org/x/text v0.8.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac h1:7zkz7BUtwNFFqcowJ+RIgu2MaV/MapERkDIy+mwPyjs= golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= +golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools 
v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= @@ -575,8 +578,9 @@ google.golang.org/genproto v0.0.0-20210222152913-aa3ee6e6a81c/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20210303154014-9728d6b83eeb/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210310155132-4ce2db91004e/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20210319143718-93e7006c17a6/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= -google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1 h1:E7wSQBXkH3T3diucK+9Z1kjn4+/9tNG7lZLr75oOhh8= google.golang.org/genproto v0.0.0-20210402141018-6c239bbf2bb1/go.mod h1:9lPAdzaEmUacj36I+k7YKbEc5CXzPIeORRgDAUOu28A= +google.golang.org/genproto v0.0.0-20230306155012-7f2fa6fef1f4 h1:DdoeryqhaXp1LtT/emMP1BRJPHHKFi5akj/nbx/zNTA= +google.golang.org/genproto v0.0.0-20230306155012-7f2fa6fef1f4/go.mod h1:NWraEVixdDnqcqQ30jipen1STv2r/n24Wb7twVTGR4s= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -595,8 +599,8 @@ google.golang.org/grpc v1.34.0/go.mod h1:WotjhfgOW/POjDeRt8vscBtXq+2VjORFy659qA5 google.golang.org/grpc v1.35.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= google.golang.org/grpc v1.36.1/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= -google.golang.org/grpc v1.51.0 h1:E1eGv1FTqoLIdnBCZufiSHgKjlqG6fKFf6pPWtMTh8U= -google.golang.org/grpc v1.51.0/go.mod h1:wgNDFcnuBGmxLKI/qn4T+m5BtEBYXJPvibbUPsAIPww= +google.golang.org/grpc v1.55.0 h1:3Oj82/tFSCeUrRTg/5E/7d/W5A1tj6Ky1ABAuZuv5ag= +google.golang.org/grpc 
v1.55.0/go.mod h1:iYEXKGkEBhg1PjZQvoYEVPTDkHo1/bjTnfwTeGONTY8= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -610,8 +614,8 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= -google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= +google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/go/pkg/services/logcollector/logcollector_test.go b/go/pkg/services/logcollector/logcollector_test.go index 5e8f03aa5557..11fc375a970b 100644 --- a/go/pkg/services/logcollector/logcollector_test.go +++ b/go/pkg/services/logcollector/logcollector_test.go @@ -211,7 +211,7 @@ func (suite *LogCollectorTestSuite) TestStreamPodLogs() { suite.Require().True(started, "Log streaming didn't start") // resolve log file path - logFilePath := suite.logCollectorServer.resolvePodLogFilePath(suite.projectName, runId, pod.Name) + 
logFilePath := suite.logCollectorServer.resolveRunLogFilePath(suite.projectName, runId) // read log file until it has content, or timeout timeout := time.After(30 * time.Second) @@ -259,10 +259,9 @@ func (suite *LogCollectorTestSuite) TestStartLogBestEffort() { func (suite *LogCollectorTestSuite) TestGetLogsSuccessful() { runUID := uuid.New().String() - podName := "my-pod" // creat log file for runUID and pod - logFilePath := suite.logCollectorServer.resolvePodLogFilePath(suite.projectName, runUID, podName) + logFilePath := suite.logCollectorServer.resolveRunLogFilePath(suite.projectName, runUID) // write log file logText := "Some fake pod logs\n" @@ -397,8 +396,6 @@ func (suite *LogCollectorTestSuite) TestReadLogsFromFileWhileWriting() { func (suite *LogCollectorTestSuite) TestHasLogs() { runUID := uuid.New().String() - podName := "my-pod" - request := &log_collector.HasLogsRequest{ RunUID: runUID, ProjectName: suite.projectName, @@ -411,7 +408,7 @@ func (suite *LogCollectorTestSuite) TestHasLogs() { suite.Require().False(hasLogsResponse.HasLogs, "Expected run to not have logs") // create log file for runUID and pod - logFilePath := suite.logCollectorServer.resolvePodLogFilePath(suite.projectName, runUID, podName) + logFilePath := suite.logCollectorServer.resolveRunLogFilePath(suite.projectName, runUID) // write log file logText := "Some fake pod logs\n" @@ -521,7 +518,7 @@ func (suite *LogCollectorTestSuite) TestDeleteLogs() { for i := 0; i < testCase.logsNumToCreate; i++ { runUID := uuid.New().String() runUIDs = append(runUIDs, runUID) - logFilePath := suite.logCollectorServer.resolvePodLogFilePath(projectName, runUID, "pod") + logFilePath := suite.logCollectorServer.resolveRunLogFilePath(projectName, runUID) err := common.WriteToFile(logFilePath, []byte("some log"), false) suite.Require().NoError(err, "Failed to write to file") } @@ -558,7 +555,7 @@ func (suite *LogCollectorTestSuite) TestDeleteProjectLogs() { for i := 0; i < logsNum; i++ { runUID := 
uuid.New().String() runUIDs = append(runUIDs, runUID) - logFilePath := suite.logCollectorServer.resolvePodLogFilePath(projectName, runUID, "pod") + logFilePath := suite.logCollectorServer.resolveRunLogFilePath(projectName, runUID) err := common.WriteToFile(logFilePath, []byte("some log"), false) suite.Require().NoError(err, "Failed to write to file") } @@ -596,38 +593,7 @@ func (suite *LogCollectorTestSuite) TestGetLogFilePath() { suite.Require().NoError(err) // make the run file - runFilePath := suite.logCollectorServer.resolvePodLogFilePath(projectName, runUID, "pod") - err = common.WriteToFile(runFilePath, []byte("some log"), false) - suite.Require().NoError(err, "Failed to write to file") - - // get the log file path - logFilePath, err := suite.logCollectorServer.getLogFilePath(suite.ctx, runUID, projectName) - suite.Require().NoError(err, "Failed to get log file path") - suite.Require().Equal(runFilePath, logFilePath, "Expected log file path to be the same as the run file path") -} - -func (suite *LogCollectorTestSuite) TestGetLogFilePathConcurrently() { - runUID := "1234" - projectName := "someProjectB" - var err error - - projectMutex := &sync.Mutex{} - suite.logCollectorServer.readDirentProjectNameSyncMap = &sync.Map{} - suite.logCollectorServer.readDirentProjectNameSyncMap.Store(projectName, projectMutex) - projectMutex.Lock() - startTime := time.Now() - - // unlock the mutex after 1 second - time.AfterFunc(1500*time.Millisecond, func() { - projectMutex.Unlock() - }) - - // make the project dir - err = os.MkdirAll(path.Join(suite.baseDir, projectName), 0755) - suite.Require().NoError(err) - - // make the run file - runFilePath := suite.logCollectorServer.resolvePodLogFilePath(projectName, runUID, "pod") + runFilePath := suite.logCollectorServer.resolveRunLogFilePath(projectName, runUID) err = common.WriteToFile(runFilePath, []byte("some log"), false) suite.Require().NoError(err, "Failed to write to file") @@ -635,12 +601,6 @@ func (suite 
*LogCollectorTestSuite) TestGetLogFilePathConcurrently() { logFilePath, err := suite.logCollectorServer.getLogFilePath(suite.ctx, runUID, projectName) suite.Require().NoError(err, "Failed to get log file path") suite.Require().Equal(runFilePath, logFilePath, "Expected log file path to be the same as the run file path") - - endTime := time.Since(startTime) - suite.Require().Truef(endTime >= 1*time.Second, "Expected getLogFilePath to take more than a second (took %v)", endTime) - - // make sure the mutex is unlocked - suite.Require().True(projectMutex.TryLock(), "Expected project mutex to be unlocked") } func TestLogCollectorTestSuite(t *testing.T) { diff --git a/go/pkg/services/logcollector/server.go b/go/pkg/services/logcollector/server.go index 20f0623e37b4..c4cd7d0064ea 100644 --- a/go/pkg/services/logcollector/server.go +++ b/go/pkg/services/logcollector/server.go @@ -18,7 +18,6 @@ import ( "context" "fmt" "io" - "io/fs" "math" "os" "path" @@ -41,7 +40,6 @@ import ( "golang.org/x/sync/errgroup" "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" - "k8s.io/apimachinery/pkg/util/cache" "k8s.io/client-go/kubernetes" ) @@ -70,14 +68,6 @@ type Server struct { // interval durations readLogWaitTime time.Duration monitoringInterval time.Duration - - // log file cache to reduce sys calls finding the log file paths. 
- logFilesCache *cache.Expiring - logFilesCacheTTL time.Duration - - // map of project name to its mutex lock - // using project mutex to prevent listing project dir concurrently - readDirentProjectNameSyncMap *sync.Map } // NewLogCollectorServer creates a new log collector server @@ -149,8 +139,6 @@ func NewLogCollectorServer(logger logger.Logger, logCollectionBufferPool := bufferpool.NewSizedBytePool(logCollectionBufferPoolSize, logCollectionBufferSizeBytes) getLogsBufferPool := bufferpool.NewSizedBytePool(getLogsBufferPoolSize, getLogsBufferSizeBytes) - logFilesCache := cache.NewExpiring() - return &Server{ AbstractMlrunGRPCServer: abstractServer, namespace: namespace, @@ -165,18 +153,8 @@ func NewLogCollectorServer(logger logger.Logger, logCollectionBufferSizeBytes: logCollectionBufferSizeBytes, getLogsBufferSizeBytes: getLogsBufferSizeBytes, isChief: isChief, - logFilesCache: logFilesCache, startLogsFindingPodsInterval: 3 * time.Second, startLogsFindingPodsTimeout: 15 * time.Second, - readDirentProjectNameSyncMap: &sync.Map{}, - - // we delete log files only when deleting the project - // that means, if project is gone, log files are gone too - // hasLogFiles is called during get_logs on project runs - // so if no project, no runs, no get_logs, and this one is pretty much safe to cache - // that being said, limit to few minutes (hard coded for now) - // this cache is done to reduce IOs - logFilesCacheTTL: 5 * time.Minute, }, nil } @@ -433,6 +411,10 @@ func (s *Server) GetLogs(request *protologcollector.GetLogsRequest, responseStre // HasLogs returns true if the log file exists for a given run id func (s *Server) HasLogs(ctx context.Context, request *protologcollector.HasLogsRequest) (*protologcollector.HasLogsResponse, error) { + s.Logger.DebugWithCtx(ctx, + "Received has log request", + "runUID", request.RunUID, + "project", request.ProjectName) // get log file path if _, err := s.getLogFilePath(ctx, request.RunUID, request.ProjectName); err != nil { @@ 
-464,7 +446,6 @@ func (s *Server) HasLogs(ctx context.Context, request *protologcollector.HasLogs ErrorMessage: common.GetErrorStack(err, common.DefaultErrorStackDepth), }, nil } - return &protologcollector.HasLogsResponse{ Success: true, HasLogs: true, @@ -656,7 +637,7 @@ func (s *Server) startLogStreaming(ctx context.Context, startedStreamingGoroutine <- true // create a log file to the pod - logFilePath := s.resolvePodLogFilePath(projectName, runUID, podName) + logFilePath := s.resolveRunLogFilePath(projectName, runUID) if err := common.EnsureFileExists(logFilePath); err != nil { s.Logger.ErrorWithCtx(ctx, "Failed to ensure log file", @@ -665,14 +646,14 @@ func (s *Server) startLogStreaming(ctx context.Context, return } - // add log file path to cache - s.logFilesCache.Set(s.getLogFileCacheKey(runUID, projectName), logFilePath, s.logFilesCacheTTL) - // open log file in read/write and append, to allow reading the logs while we write more logs to it openFlags := os.O_RDWR | os.O_APPEND file, err := os.OpenFile(logFilePath, openFlags, 0644) if err != nil { - s.Logger.ErrorWithCtx(ctx, "Failed to open file", "err", err.Error(), "logFilePath", logFilePath) + s.Logger.ErrorWithCtx(ctx, + "Failed to open file", + "err", err.Error(), + "logFilePath", logFilePath) return } defer file.Close() // nolint: errcheck @@ -716,13 +697,23 @@ func (s *Server) startLogStreaming(ctx context.Context, defer stream.Close() // nolint: errcheck for keepLogging { - keepLogging, err = s.streamPodLogs(ctx, runUID, file, stream) if err != nil { s.Logger.WarnWithCtx(ctx, "An error occurred while streaming pod logs", "err", common.GetErrorStack(err, common.DefaultErrorStackDepth)) + + // fatal error, bail out + // note that when function is returned, a defer function will remove the + // log collection from (in memory) state file. + // it ensures us that when log collection monitoring kicks in (it runs periodically) + // it will ignite the run log collection again. 
+ return } + + // breath + // stream pod logs might return fast when there is nothing to read and no error occurred + time.Sleep(100 * time.Millisecond) } s.Logger.DebugWithCtx(ctx, @@ -735,7 +726,10 @@ func (s *Server) startLogStreaming(ctx context.Context, s.Logger.WarnWithCtx(ctx, "Failed to remove log item from state file") } - s.Logger.DebugWithCtx(ctx, "Finished log streaming", "runUID", runUID, "podName", podName) + s.Logger.DebugWithCtx(ctx, + "Finished log streaming", + "runUID", runUID, + "podName", podName) } // streamPodLogs streams logs from a pod to a file @@ -756,7 +750,8 @@ func (s *Server) streamPodLogs(ctx context.Context, // write to file if _, err := logFile.Write(buf[:numBytesRead]); err != nil { - s.Logger.WarnWithCtx(ctx, "Failed to write pod log to file", + s.Logger.WarnWithCtx(ctx, + "Failed to write pod log to file", "err", err.Error(), "runUID", runUID) return true, errors.Wrap(err, "Failed to write pod log to file") @@ -769,47 +764,23 @@ func (s *Server) streamPodLogs(ctx context.Context, return false, nil } - // log error if occurred + // other error occurred if err != nil { - s.Logger.WarnWithCtx(ctx, "Failed to read pod log", - "err", err.Error(), - "runUID", runUID) - - // if error is not nil, and we didn't read anything - a real error occurred, so we stop logging - if numBytesRead != 0 { - return false, errors.Wrap(err, "Failed to read pod logs") - } + return false, errors.Wrap(err, "Failed to read pod logs") } // nothing happened, continue return true, nil } -// resolvePodLogFilePath returns the path to the pod log file -func (s *Server) resolvePodLogFilePath(projectName, runUID, podName string) string { - return path.Join(s.baseDir, projectName, fmt.Sprintf("%s_%s", runUID, podName)) +// resolveRunLogFilePath returns the path to the pod log file +func (s *Server) resolveRunLogFilePath(projectName, runUID string) string { + return path.Join(s.baseDir, projectName, runUID) } // getLogFilePath returns the path to the run's latest log 
file func (s *Server) getLogFilePath(ctx context.Context, runUID, projectName string) (string, error) { - - // first try load from cache - if filePath, found := s.logFilesCache.Get(s.getLogFileCacheKey(runUID, projectName)); found { - return filePath.(string), nil - } - - // get project mutex or create one - projectMutex, _ := s.readDirentProjectNameSyncMap.LoadOrStore(projectName, &sync.Mutex{}) - - // lock project mutex, we want only one project dir to be read at a time - projectMutex.(*sync.Mutex).Lock() - - // unlock project mutex when done - defer projectMutex.(*sync.Mutex).Unlock() - var logFilePath string - var latestModTime time.Time - var retryCount int if err := common.RetryUntilSuccessful(5*time.Second, 1*time.Second, func() (bool, error) { defer func() { @@ -836,58 +807,33 @@ func (s *Server) getLogFilePath(ctx context.Context, runUID, projectName string) return false, errors.Wrap(err, "Failed to get project directory") } - // list all files in project directory - if err := filepath.WalkDir(filepath.Join(s.baseDir, projectName), - func(path string, dirEntry fs.DirEntry, err error) error { - if err != nil { - s.Logger.WarnWithCtx(ctx, - "Failed to walk path", - "retryCount", retryCount, - "path", path, - "err", errors.GetErrorStackString(err, 10)) - return errors.Wrapf(err, "Failed to walk path %s", path) - } - - // skip directories - if dirEntry.IsDir() { - return nil - } - - // if file name starts with run id, it's a log file - if strings.HasPrefix(dirEntry.Name(), runUID) { - info, err := dirEntry.Info() - if err != nil { - return errors.Wrapf(err, "Failed to get file info for %s", path) - } - - // if it's the first file, set it as the log file path - // otherwise, check if it's the latest modified file - if logFilePath == "" || info.ModTime().After(latestModTime) { - logFilePath = path - latestModTime = info.ModTime() - } - } - - return nil - }); err != nil { + // get run log file path + runLogFilePath := s.resolveRunLogFilePath(projectName, runUID) 
- // retry - return true, errors.Wrap(err, "Failed to list files in base directory") - } - - if logFilePath == "" { - return true, errors.Errorf("Log file not found for run %s", runUID) + if exists, err := common.FileExists(runLogFilePath); err != nil { + s.Logger.WarnWithCtx(ctx, + "Failed to get run log file path", + "retryCount", retryCount, + "runUID", runUID, + "projectName", projectName, + "err", err.Error()) + return false, errors.Wrap(err, "Failed to get project directory") + } else if !exists { + s.Logger.WarnWithCtx(ctx, + "Run log file not found", + "retryCount", retryCount, + "runUID", runUID, + "projectName", projectName) + return true, errors.New("Run log file not found") } - // found log file + // found it + logFilePath = runLogFilePath return false, nil - }); err != nil { return "", errors.Wrap(err, "Exhausted getting log file path") } - // store in cache - s.logFilesCache.Set(s.getLogFileCacheKey(runUID, projectName), logFilePath, s.logFilesCacheTTL) return logFilePath, nil } @@ -1108,7 +1054,7 @@ func (s *Server) successfulBaseResponse() *protologcollector.BaseResponse { func (s *Server) deleteRunLogFiles(ctx context.Context, runUID, project string) error { // get all files that have the runUID as a prefix - pattern := path.Join(s.baseDir, project, fmt.Sprintf("%s_*", runUID)) + pattern := path.Join(s.baseDir, project, runUID) files, err := filepath.Glob(pattern) if err != nil { return errors.Wrap(err, "Failed to get log files") @@ -1147,7 +1093,3 @@ func (s *Server) deleteProjectLogs(project string) error { return nil } - -func (s *Server) getLogFileCacheKey(runUID, project string) string { - return fmt.Sprintf("%s/%s", runUID, project) -} diff --git a/mlrun/__init__.py b/mlrun/__init__.py index f17741f6f3e7..ca84c2caf317 100644 --- a/mlrun/__init__.py +++ b/mlrun/__init__.py @@ -35,6 +35,7 @@ from .errors import MLRunInvalidArgumentError, MLRunNotFoundError from .execution import MLClientCtx from .model import RunObject, RunTemplate, new_task 
+from .package import ArtifactType, DefaultPackager, Packager, handler from .platforms import ( VolumeMount, auto_mount, @@ -62,14 +63,13 @@ get_object, get_or_create_ctx, get_pipeline, - handler, import_function, new_function, run_local, run_pipeline, wait_for_pipeline_completion, ) -from .runtimes import ArtifactType, new_model_server +from .runtimes import new_model_server from .secrets import get_secret_or_env from .utils.version import Version diff --git a/mlrun/__main__.py b/mlrun/__main__.py index fded06f63325..7fc6ae7406fb 100644 --- a/mlrun/__main__.py +++ b/mlrun/__main__.py @@ -34,11 +34,9 @@ import mlrun -from .builder import upload_tarball from .config import config as mlconf from .db import get_run_db from .errors import err_to_str -from .k8s_utils import K8sHelper from .model import RunTemplate from .platforms import auto_mount as auto_mount_modifier from .projects import load_project @@ -545,7 +543,7 @@ def build( logger.info(f"uploading data from {src} to {archive}") target = archive if archive.endswith("/") else archive + "/" target += f"src-{meta.project}-{meta.name}-{meta.tag or 'latest'}.tar.gz" - upload_tarball(src, target) + mlrun.datastore.utils.upload_tarball(src, target) # todo: replace function.yaml inside the tar b.source = target @@ -700,20 +698,6 @@ def deploy( fp.write(function.status.nuclio_name) -@main.command(context_settings=dict(ignore_unknown_options=True)) -@click.argument("pod", type=str, callback=validate_base_argument) -@click.option("--namespace", "-n", help="kubernetes namespace") -@click.option( - "--timeout", "-t", default=600, show_default=True, help="timeout in seconds" -) -def watch(pod, namespace, timeout): - """Read current or previous task (pod) logs.""" - print("This command will be deprecated in future version !!!\n") - k8s = K8sHelper(namespace) - status = k8s.watch(pod, namespace, timeout) - print(f"Pod {pod} last status is: {status}") - - @main.command(context_settings=dict(ignore_unknown_options=True)) 
@click.argument("kind", type=str, callback=validate_base_argument) @click.argument( @@ -1052,6 +1036,15 @@ def logs(uid, project, offset, db, watch): is_flag=True, help="Store the project secrets as k8s secrets", ) +@click.option( + "--notifications", + "--notification", + "-nt", + multiple=True, + help="To have a notification for the run set notification file " + "destination define: file=notification.json or a " + 'dictionary configuration e.g \'{"slack":{"webhook":""}}\'', +) def project( context, name, @@ -1077,6 +1070,7 @@ def project( timeout, ensure_project, schedule, + notifications, overwrite_schedule, save_secrets, save, @@ -1152,6 +1146,8 @@ def project( "token": proj.get_param("GIT_TOKEN"), }, ) + if notifications: + load_notification(notifications, proj) try: proj.run( name=run, @@ -1169,11 +1165,9 @@ def project( timeout=timeout, overwrite=overwrite_schedule, ) - - except Exception as exc: + except Exception as err: print(traceback.format_exc()) - message = f"failed to run pipeline, {err_to_str(exc)}" - proj.notifiers.push(message, "error") + send_workflow_error_notification(run, proj, err) exit(1) elif sync: @@ -1450,5 +1444,48 @@ def func_url_to_runtime(func_url, ensure_project: bool = False): return runtime +def load_notification(notifications: str, project: mlrun.projects.MlrunProject): + """ + A dictionary or json file containing notification dictionaries can be used by the user to set notifications. + Each notification is stored in a tuple called notifications. + The code then goes through each value in the notifications tuple and check + if the notification starts with "file=", such as "file=notification.json," in those cases it loads the + notification.json file and uses add_notification_to_project to add the notifications from the file to + the project. If not, it adds the notification dictionary to the project. 
+ :param notifications: Notifications file or a dictionary to be added to the project + :param project: The object to which the notifications will be added + :return: + """ + for notification in notifications: + if notification.startswith("file="): + file_path = notification.split("=")[-1] + notification = open(file_path, "r") + notification = json.load(notification) + else: + notification = json.loads(notification) + add_notification_to_project(notification, project) + + +def add_notification_to_project( + notification: str, project: mlrun.projects.MlrunProject +): + for notification_type, notification_params in notification.items(): + project.notifiers.add_notification( + notification_type=notification_type, params=notification_params + ) + + +def send_workflow_error_notification( + run_id: str, project: mlrun.projects.MlrunProject, error: KeyError +): + message = ( + f":x: Failed to run scheduled workflow {run_id} in Project {project.name} !\n" + f"error: ```{err_to_str(error)}```" + ) + project.notifiers.push( + message=message, severity=mlrun.common.schemas.NotificationSeverity.ERROR + ) + + if __name__ == "__main__": main() diff --git a/mlrun/api/api/api.py b/mlrun/api/api/api.py index 8d1966499f38..3d9ae2eebd85 100644 --- a/mlrun/api/api/api.py +++ b/mlrun/api/api/api.py @@ -28,9 +28,9 @@ functions, grafana_proxy, healthz, + hub, internal, logs, - marketplace, model_endpoints, operations, pipelines, @@ -125,8 +125,8 @@ api_router.include_router(grafana_proxy.router, tags=["grafana", "model-endpoints"]) api_router.include_router(model_endpoints.router, tags=["model-endpoints"]) api_router.include_router( - marketplace.router, - tags=["marketplace"], + hub.router, + tags=["hub"], dependencies=[Depends(mlrun.api.api.deps.authenticate_request)], ) api_router.include_router( diff --git a/mlrun/api/api/deps.py b/mlrun/api/api/deps.py index 5d31f297ef5f..50eda225ba26 100644 --- a/mlrun/api/api/deps.py +++ b/mlrun/api/api/deps.py @@ -20,9 +20,9 @@ import mlrun 
import mlrun.api.db.session -import mlrun.api.schemas import mlrun.api.utils.auth.verifier import mlrun.api.utils.clients.iguazio +import mlrun.common.schemas def get_db_session() -> typing.Generator[Session, None, None]: @@ -35,7 +35,7 @@ def get_db_session() -> typing.Generator[Session, None, None]: mlrun.api.db.session.close_session(db_session) -async def authenticate_request(request: Request) -> mlrun.api.schemas.AuthInfo: +async def authenticate_request(request: Request) -> mlrun.common.schemas.AuthInfo: return await mlrun.api.utils.auth.verifier.AuthVerifier().authenticate_request( request ) @@ -46,7 +46,7 @@ def verify_api_state(request: Request): request.scope ) path = path_with_query_string.split("?")[0] - if mlrun.mlconf.httpdb.state == mlrun.api.schemas.APIStates.offline: + if mlrun.mlconf.httpdb.state == mlrun.common.schemas.APIStates.offline: enabled_endpoints = [ # we want to stay healthy "healthz", @@ -56,10 +56,10 @@ def verify_api_state(request: Request): if not any(enabled_endpoint in path for enabled_endpoint in enabled_endpoints): raise mlrun.errors.MLRunPreconditionFailedError("API is in offline state") if mlrun.mlconf.httpdb.state in [ - mlrun.api.schemas.APIStates.waiting_for_migrations, - mlrun.api.schemas.APIStates.migrations_in_progress, - mlrun.api.schemas.APIStates.migrations_failed, - mlrun.api.schemas.APIStates.waiting_for_chief, + mlrun.common.schemas.APIStates.waiting_for_migrations, + mlrun.common.schemas.APIStates.migrations_in_progress, + mlrun.common.schemas.APIStates.migrations_failed, + mlrun.common.schemas.APIStates.waiting_for_chief, ]: enabled_endpoints = [ "healthz", @@ -70,20 +70,9 @@ def verify_api_state(request: Request): "memory-reports", ] if not any(enabled_endpoint in path for enabled_endpoint in enabled_endpoints): - message = ( - "API is waiting for migrations to be triggered. 
Send POST request to /api/operations/migrations to" - " trigger it" - ) - if ( - mlrun.mlconf.httpdb.state - == mlrun.api.schemas.APIStates.migrations_in_progress - ): - message = "Migrations are in progress" - elif ( + message = mlrun.common.schemas.APIStates.description( mlrun.mlconf.httpdb.state - == mlrun.api.schemas.APIStates.migrations_failed - ): - message = "Migrations failed, API can't be started" + ) raise mlrun.errors.MLRunPreconditionFailedError(message) diff --git a/mlrun/api/api/endpoints/artifacts.py b/mlrun/api/api/endpoints/artifacts.py index f8a29ef53c65..028ac924446a 100644 --- a/mlrun/api/api/endpoints/artifacts.py +++ b/mlrun/api/api/endpoints/artifacts.py @@ -22,18 +22,23 @@ import mlrun.api.crud import mlrun.api.utils.auth.verifier import mlrun.api.utils.singletons.project_member -from mlrun.api import schemas +import mlrun.common.schemas from mlrun.api.api import deps from mlrun.api.api.utils import log_and_raise -from mlrun.api.schemas.artifact import ArtifactsFormat +from mlrun.common.schemas.artifact import ArtifactsFormat from mlrun.config import config from mlrun.utils import is_legacy_artifact, logger router = APIRouter() -# TODO /artifact/{project}/{uid}/{key:path} should be deprecated in 1.4 -@router.post("/artifact/{project}/{uid}/{key:path}") +# TODO: remove /artifact/{project}/{uid}/{key:path} in 1.6.0 +@router.post( + "/artifact/{project}/{uid}/{key:path}", + deprecated=True, + description="/artifact/{project}/{uid}/{key:path} is deprecated in 1.4.0 and will be removed in 1.6.0, " + "use /projects/{project}/artifacts/{uid}/{key:path} instead", +) @router.post("/projects/{project}/artifacts/{uid}/{key:path}") async def store_artifact( request: Request, @@ -42,7 +47,7 @@ async def store_artifact( key: str, tag: str = "", iter: int = 0, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = 
Depends(deps.get_db_session), ): await run_in_threadpool( @@ -52,10 +57,10 @@ async def store_artifact( auth_info=auth_info, ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.artifact, + mlrun.common.schemas.AuthorizationResourceTypes.artifact, project, key, - mlrun.api.schemas.AuthorizationAction.store, + mlrun.common.schemas.AuthorizationAction.store, auth_info, ) @@ -65,7 +70,9 @@ async def store_artifact( except ValueError: log_and_raise(HTTPStatus.BAD_REQUEST.value, reason="bad JSON body") - logger.debug("Storing artifact", data=data) + logger.debug( + "Storing artifact", project=project, uid=uid, key=key, tag=tag, iter=iter + ) await run_in_threadpool( mlrun.api.crud.Artifacts().store_artifact, db_session, @@ -82,13 +89,13 @@ async def store_artifact( @router.get("/projects/{project}/artifact-tags") async def list_artifact_tags( project: str, - category: schemas.ArtifactCategories = None, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + category: mlrun.common.schemas.ArtifactCategories = None, + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( project, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) tag_tuples = await run_in_threadpool( @@ -96,7 +103,7 @@ async def list_artifact_tags( ) artifact_key_to_tag = {tag_tuple[1]: tag_tuple[2] for tag_tuple in tag_tuples} allowed_artifact_keys = await mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.artifact, + mlrun.common.schemas.AuthorizationResourceTypes.artifact, list(artifact_key_to_tag.keys()), lambda artifact_key: ( project, @@ -116,8 +123,13 @@ async def list_artifact_tags( } -# TODO 
/projects/{project}/artifact/{key:path} should be deprecated in 1.4 -@router.get("/projects/{project}/artifact/{key:path}") +# TODO: remove /projects/{project}/artifact/{key:path} in 1.6.0 +@router.get( + "/projects/{project}/artifact/{key:path}", + deprecated=True, + description="/projects/{project}/artifact/{key:path} is deprecated in 1.4.0 and will be removed in 1.6.0, " + "use /projects/{project}/artifacts/{key:path} instead", +) @router.get("/projects/{project}/artifacts/{key:path}") async def get_artifact( project: str, @@ -125,7 +137,7 @@ async def get_artifact( tag: str = "latest", iter: int = 0, format_: ArtifactsFormat = Query(ArtifactsFormat.full, alias="format"), - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): data = await run_in_threadpool( @@ -138,10 +150,10 @@ async def get_artifact( format_, ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.artifact, + mlrun.common.schemas.AuthorizationResourceTypes.artifact, project, key, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) return { @@ -149,22 +161,27 @@ async def get_artifact( } -# TODO /artifact/{project}/{uid} should be deprecated in 1.4 -@router.delete("/artifact/{project}/{uid}") +# TODO: remove /artifact/{project}/{uid} in 1.6.0 +@router.delete( + "/artifact/{project}/{uid}", + deprecated=True, + description="/artifact/{project}/{uid} is deprecated in 1.4.0 and will be removed in 1.6.0, " + "use /projects/{project}/artifacts/{uid} instead", +) @router.delete("/projects/{project}/artifacts/{uid}") async def delete_artifact( project: str, uid: str, key: str, tag: str = "", - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = 
Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.artifact, + mlrun.common.schemas.AuthorizationResourceTypes.artifact, project, key, - mlrun.api.schemas.AuthorizationAction.delete, + mlrun.common.schemas.AuthorizationAction.delete, auth_info, ) await run_in_threadpool( @@ -173,27 +190,32 @@ async def delete_artifact( return {} -# TODO /artifacts should be deprecated in 1.4 -@router.get("/artifacts") +# TODO: remove /artifacts in 1.6.0 +@router.get( + "/artifacts", + deprecated=True, + description="/artifacts is deprecated in 1.4.0 and will be removed in 1.6.0, " + "use /projects/{project}/artifacts instead", +) @router.get("/projects/{project}/artifacts") async def list_artifacts( project: str = None, name: str = None, tag: str = None, kind: str = None, - category: schemas.ArtifactCategories = None, + category: mlrun.common.schemas.ArtifactCategories = None, labels: List[str] = Query([], alias="label"), iter: int = Query(None, ge=0), best_iteration: bool = Query(False, alias="best-iteration"), format_: ArtifactsFormat = Query(ArtifactsFormat.full, alias="format"), - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): if project is None: project = config.default_project await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( project, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) @@ -212,7 +234,7 @@ async def list_artifacts( ) artifacts = await mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.artifact, + mlrun.common.schemas.AuthorizationResourceTypes.artifact, artifacts, 
_artifact_project_and_resource_name_extractor, auth_info, @@ -222,14 +244,19 @@ async def list_artifacts( } -# TODO /artifacts should be deprecated in 1.4 -@router.delete("/artifacts") +# TODO: remove /artifacts in 1.6.0 +@router.delete( + "/artifacts", + deprecated=True, + description="/artifacts is deprecated in 1.4.0 and will be removed in 1.6.0, " + "use /projects/{project}/artifacts instead", +) async def delete_artifacts_legacy( project: str = mlrun.mlconf.default_project, name: str = "", tag: str = "", labels: List[str] = Query([], alias="label"), - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): return await _delete_artifacts( @@ -248,7 +275,7 @@ async def delete_artifacts( name: str = "", tag: str = "", labels: List[str] = Query([], alias="label"), - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): return await _delete_artifacts( @@ -266,7 +293,7 @@ async def _delete_artifacts( name: str = None, tag: str = None, labels: List[str] = None, - auth_info: mlrun.api.schemas.AuthInfo = None, + auth_info: mlrun.common.schemas.AuthInfo = None, db_session: Session = None, ): artifacts = await run_in_threadpool( @@ -278,10 +305,10 @@ async def _delete_artifacts( labels, ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resources_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.artifact, + mlrun.common.schemas.AuthorizationResourceTypes.artifact, artifacts, _artifact_project_and_resource_name_extractor, - mlrun.api.schemas.AuthorizationAction.delete, + mlrun.common.schemas.AuthorizationAction.delete, auth_info, ) await run_in_threadpool( diff --git a/mlrun/api/api/endpoints/auth.py b/mlrun/api/api/endpoints/auth.py 
index 4d3fd2b2938a..58b8b6ff3574 100644 --- a/mlrun/api/api/endpoints/auth.py +++ b/mlrun/api/api/endpoints/auth.py @@ -15,16 +15,16 @@ import fastapi import mlrun.api.api.deps -import mlrun.api.schemas import mlrun.api.utils.auth.verifier +import mlrun.common.schemas router = fastapi.APIRouter() @router.post("/authorization/verifications") async def verify_authorization( - authorization_verification_input: mlrun.api.schemas.AuthorizationVerificationInput, - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + authorization_verification_input: mlrun.common.schemas.AuthorizationVerificationInput, + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), ): diff --git a/mlrun/api/api/endpoints/background_tasks.py b/mlrun/api/api/endpoints/background_tasks.py index 4586e09878d5..26c80e183716 100644 --- a/mlrun/api/api/endpoints/background_tasks.py +++ b/mlrun/api/api/endpoints/background_tasks.py @@ -18,10 +18,10 @@ from fastapi.concurrency import run_in_threadpool import mlrun.api.api.deps -import mlrun.api.schemas import mlrun.api.utils.auth.verifier import mlrun.api.utils.background_tasks import mlrun.api.utils.clients.chief +import mlrun.common.schemas from mlrun.utils import logger router = fastapi.APIRouter() @@ -29,12 +29,12 @@ @router.get( "/projects/{project}/background-tasks/{name}", - response_model=mlrun.api.schemas.BackgroundTask, + response_model=mlrun.common.schemas.BackgroundTask, ) async def get_project_background_task( project: str, name: str, - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), db_session: sqlalchemy.orm.Session = fastapi.Depends( @@ -44,10 +44,10 @@ async def get_project_background_task( # Since there's no not-found option on get_project_background_task - we authorize before getting (unlike other # get endpoint) await 
mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.project_background_task, + mlrun.common.schemas.AuthorizationResourceTypes.project_background_task, project, name, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) return await run_in_threadpool( @@ -60,12 +60,12 @@ async def get_project_background_task( @router.get( "/background-tasks/{name}", - response_model=mlrun.api.schemas.BackgroundTask, + response_model=mlrun.common.schemas.BackgroundTask, ) async def get_internal_background_task( name: str, request: fastapi.Request, - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), ): @@ -76,14 +76,14 @@ async def get_internal_background_task( igz_version = mlrun.mlconf.get_parsed_igz_version() if igz_version and igz_version >= semver.VersionInfo.parse("3.7.0-b1"): await mlrun.api.utils.auth.verifier.AuthVerifier().query_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.background_task, + mlrun.common.schemas.AuthorizationResourceTypes.background_task, name, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) if ( mlrun.mlconf.httpdb.clusterization.role - != mlrun.api.schemas.ClusterizationRole.chief + != mlrun.common.schemas.ClusterizationRole.chief ): logger.info( "Requesting internal background task, re-routing to chief", diff --git a/mlrun/api/api/endpoints/client_spec.py b/mlrun/api/api/endpoints/client_spec.py index 788276eb90c7..3fd477fbf4d2 100644 --- a/mlrun/api/api/endpoints/client_spec.py +++ b/mlrun/api/api/endpoints/client_spec.py @@ -17,21 +17,21 @@ from fastapi import APIRouter, Header import mlrun.api.crud -import mlrun.api.schemas +import mlrun.common.schemas router = APIRouter() @router.get( "/client-spec", - 
response_model=mlrun.api.schemas.ClientSpec, + response_model=mlrun.common.schemas.ClientSpec, ) def get_client_spec( client_version: typing.Optional[str] = Header( - None, alias=mlrun.api.schemas.HeaderNames.client_version + None, alias=mlrun.common.schemas.HeaderNames.client_version ), client_python_version: typing.Optional[str] = Header( - None, alias=mlrun.api.schemas.HeaderNames.python_version + None, alias=mlrun.common.schemas.HeaderNames.python_version ), ): return mlrun.api.crud.ClientSpec().get_client_spec( diff --git a/mlrun/api/api/endpoints/clusterization_spec.py b/mlrun/api/api/endpoints/clusterization_spec.py index fc14e755a501..005e490ac461 100644 --- a/mlrun/api/api/endpoints/clusterization_spec.py +++ b/mlrun/api/api/endpoints/clusterization_spec.py @@ -15,17 +15,19 @@ import fastapi import mlrun.api.crud -import mlrun.api.schemas import mlrun.api.utils.clients.chief +import mlrun.common.schemas router = fastapi.APIRouter() -@router.get("/clusterization-spec", response_model=mlrun.api.schemas.ClusterizationSpec) +@router.get( + "/clusterization-spec", response_model=mlrun.common.schemas.ClusterizationSpec +) async def clusterization_spec(): if ( mlrun.mlconf.httpdb.clusterization.role - != mlrun.api.schemas.ClusterizationRole.chief + != mlrun.common.schemas.ClusterizationRole.chief ): chief_client = mlrun.api.utils.clients.chief.Client() return await chief_client.get_clusterization_spec() diff --git a/mlrun/api/api/endpoints/feature_store.py b/mlrun/api/api/endpoints/feature_store.py index 127af270ba2e..1c5c3f3849b9 100644 --- a/mlrun/api/api/endpoints/feature_store.py +++ b/mlrun/api/api/endpoints/feature_store.py @@ -23,10 +23,10 @@ import mlrun.api.crud import mlrun.api.utils.auth.verifier import mlrun.api.utils.singletons.project_member +import mlrun.common.schemas import mlrun.errors import mlrun.feature_store from mlrun import v3io_cred -from mlrun.api import schemas from mlrun.api.api import deps from mlrun.api.api.utils import 
log_and_raise, parse_reference from mlrun.data_types import InferOptions @@ -34,15 +34,15 @@ from mlrun.feature_store.api import RunConfig, ingest from mlrun.model import DataSource, DataTargetBase -router = APIRouter() +router = APIRouter(prefix="/projects/{project}") -@router.post("/projects/{project}/feature-sets", response_model=schemas.FeatureSet) +@router.post("/feature-sets", response_model=mlrun.common.schemas.FeatureSet) async def create_feature_set( project: str, - feature_set: schemas.FeatureSet, + feature_set: mlrun.common.schemas.FeatureSet, versioned: bool = True, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await run_in_threadpool( @@ -52,10 +52,10 @@ async def create_feature_set( auth_info=auth_info, ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.feature_set, + mlrun.common.schemas.AuthorizationResourceTypes.feature_set, project, feature_set.metadata.name, - mlrun.api.schemas.AuthorizationAction.create, + mlrun.common.schemas.AuthorizationAction.create, auth_info, ) feature_set_uid = await run_in_threadpool( @@ -77,16 +77,16 @@ async def create_feature_set( @router.put( - "/projects/{project}/feature-sets/{name}/references/{reference}", - response_model=schemas.FeatureSet, + "/feature-sets/{name}/references/{reference}", + response_model=mlrun.common.schemas.FeatureSet, ) async def store_feature_set( project: str, name: str, reference: str, - feature_set: schemas.FeatureSet, + feature_set: mlrun.common.schemas.FeatureSet, versioned: bool = True, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await run_in_threadpool( @@ -96,10 +96,10 
@@ async def store_feature_set( auth_info=auth_info, ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.feature_set, + mlrun.common.schemas.AuthorizationResourceTypes.feature_set, project, name, - mlrun.api.schemas.AuthorizationAction.store, + mlrun.common.schemas.AuthorizationAction.store, auth_info, ) tag, uid = parse_reference(reference) @@ -123,23 +123,24 @@ async def store_feature_set( ) -@router.patch("/projects/{project}/feature-sets/{name}/references/{reference}") +@router.patch("/feature-sets/{name}/references/{reference}") async def patch_feature_set( project: str, name: str, feature_set_update: dict, reference: str, - patch_mode: schemas.PatchMode = Header( - schemas.PatchMode.replace, alias=schemas.HeaderNames.patch_mode + patch_mode: mlrun.common.schemas.PatchMode = Header( + mlrun.common.schemas.PatchMode.replace, + alias=mlrun.common.schemas.HeaderNames.patch_mode, ), - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.feature_set, + mlrun.common.schemas.AuthorizationResourceTypes.feature_set, project, name, - mlrun.api.schemas.AuthorizationAction.update, + mlrun.common.schemas.AuthorizationAction.update, auth_info, ) tag, uid = parse_reference(reference) @@ -157,14 +158,14 @@ async def patch_feature_set( @router.get( - "/projects/{project}/feature-sets/{name}/references/{reference}", - response_model=schemas.FeatureSet, + "/feature-sets/{name}/references/{reference}", + response_model=mlrun.common.schemas.FeatureSet, ) async def get_feature_set( project: str, name: str, reference: str, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: 
mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): tag, uid = parse_reference(reference) @@ -177,29 +178,29 @@ async def get_feature_set( uid, ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.feature_set, + mlrun.common.schemas.AuthorizationResourceTypes.feature_set, project, name, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) return feature_set -@router.delete("/projects/{project}/feature-sets/{name}") -@router.delete("/projects/{project}/feature-sets/{name}/references/{reference}") +@router.delete("/feature-sets/{name}") +@router.delete("/feature-sets/{name}/references/{reference}") async def delete_feature_set( project: str, name: str, reference: str = None, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.feature_set, + mlrun.common.schemas.AuthorizationResourceTypes.feature_set, project, name, - mlrun.api.schemas.AuthorizationAction.delete, + mlrun.common.schemas.AuthorizationAction.delete, auth_info, ) tag = uid = None @@ -217,7 +218,8 @@ async def delete_feature_set( @router.get( - "/projects/{project}/feature-sets", response_model=schemas.FeatureSetsOutput + "/feature-sets", + response_model=mlrun.common.schemas.FeatureSetsOutput, ) async def list_feature_sets( project: str, @@ -227,20 +229,22 @@ async def list_feature_sets( entities: List[str] = Query(None, alias="entity"), features: List[str] = Query(None, alias="feature"), labels: List[str] = Query(None, alias="label"), - partition_by: schemas.FeatureStorePartitionByField = Query( + 
partition_by: mlrun.common.schemas.FeatureStorePartitionByField = Query( None, alias="partition-by" ), rows_per_partition: int = Query(1, alias="rows-per-partition", gt=0), - partition_sort_by: schemas.SortField = Query(None, alias="partition-sort-by"), - partition_order: schemas.OrderType = Query( - schemas.OrderType.desc, alias="partition-order" + partition_sort_by: mlrun.common.schemas.SortField = Query( + None, alias="partition-sort-by" ), - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + partition_order: mlrun.common.schemas.OrderType = Query( + mlrun.common.schemas.OrderType.desc, alias="partition-order" + ), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( project, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) feature_sets = await run_in_threadpool( @@ -259,7 +263,7 @@ async def list_feature_sets( partition_order, ) feature_sets = await mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.feature_set, + mlrun.common.schemas.AuthorizationResourceTypes.feature_set, feature_sets.feature_sets, lambda feature_set: ( feature_set.metadata.project, @@ -267,17 +271,17 @@ async def list_feature_sets( ), auth_info, ) - return mlrun.api.schemas.FeatureSetsOutput(feature_sets=feature_sets) + return mlrun.common.schemas.FeatureSetsOutput(feature_sets=feature_sets) @router.get( - "/projects/{project}/feature-sets/{name}/tags", - response_model=schemas.FeatureSetsTagsOutput, + "/feature-sets/{name}/tags", + response_model=mlrun.common.schemas.FeatureSetsTagsOutput, ) async def list_feature_sets_tags( project: str, name: str, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo 
= Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): if name != "*": @@ -286,7 +290,7 @@ async def list_feature_sets_tags( ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( project, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) tag_tuples = await run_in_threadpool( @@ -298,7 +302,7 @@ async def list_feature_sets_tags( auth_verifier = mlrun.api.utils.auth.verifier.AuthVerifier() allowed_feature_set_names = ( await auth_verifier.filter_project_resources_by_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.feature_set, + mlrun.common.schemas.AuthorizationResourceTypes.feature_set, list(feature_set_name_to_tag.keys()), lambda feature_set_name: ( project, @@ -312,7 +316,7 @@ async def list_feature_sets_tags( for tag_tuple in tag_tuples if tag_tuple[1] in allowed_feature_set_names } - return mlrun.api.schemas.FeatureSetsTagsOutput(tags=list(tags)) + return mlrun.common.schemas.FeatureSetsTagsOutput(tags=list(tags)) def _has_v3io_path(data_source, data_targets, feature_set): @@ -341,8 +345,8 @@ def _has_v3io_path(data_source, data_targets, feature_set): @router.post( - "/projects/{project}/feature-sets/{name}/references/{reference}/ingest", - response_model=schemas.FeatureSetIngestOutput, + "/feature-sets/{name}/references/{reference}/ingest", + response_model=mlrun.common.schemas.FeatureSetIngestOutput, status_code=HTTPStatus.ACCEPTED.value, ) async def ingest_feature_set( @@ -350,10 +354,10 @@ async def ingest_feature_set( name: str, reference: str, ingest_parameters: Optional[ - schemas.FeatureSetIngestInput - ] = schemas.FeatureSetIngestInput(), + mlrun.common.schemas.FeatureSetIngestInput + ] = mlrun.common.schemas.FeatureSetIngestInput(), username: str = Header(None, alias="x-remote-user"), - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = 
Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): """ @@ -361,17 +365,17 @@ async def ingest_feature_set( that already being happen on client side """ await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.feature_set, + mlrun.common.schemas.AuthorizationResourceTypes.feature_set, project, name, - mlrun.api.schemas.AuthorizationAction.update, + mlrun.common.schemas.AuthorizationAction.update, auth_info, ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.run, + mlrun.common.schemas.AuthorizationResourceTypes.run, project, "", - mlrun.api.schemas.AuthorizationAction.create, + mlrun.common.schemas.AuthorizationAction.create, auth_info, ) data_source = data_targets = None @@ -379,10 +383,10 @@ async def ingest_feature_set( data_source = DataSource.from_dict(ingest_parameters.source.dict()) if data_source.schedule: await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.schedule, + mlrun.common.schemas.AuthorizationResourceTypes.schedule, project, "", - mlrun.api.schemas.AuthorizationAction.create, + mlrun.common.schemas.AuthorizationAction.create, auth_info, ) tag, uid = parse_reference(reference) @@ -398,10 +402,10 @@ async def ingest_feature_set( if feature_set.spec.function and feature_set.spec.function.function_object: function = feature_set.spec.function.function_object await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.function, + mlrun.common.schemas.AuthorizationResourceTypes.function, function.metadata.project, function.metadata.name, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) # Need to override the default rundb since we're in the 
server. @@ -447,25 +451,25 @@ async def ingest_feature_set( run_config=run_config, ) # ingest may modify the feature-set contents, so returning the updated feature-set. - result_feature_set = schemas.FeatureSet(**feature_set.to_dict()) - return schemas.FeatureSetIngestOutput( + result_feature_set = mlrun.common.schemas.FeatureSet(**feature_set.to_dict()) + return mlrun.common.schemas.FeatureSetIngestOutput( feature_set=result_feature_set, run_object=run_params.to_dict() ) -@router.get("/projects/{project}/features", response_model=schemas.FeaturesOutput) +@router.get("/features", response_model=mlrun.common.schemas.FeaturesOutput) async def list_features( project: str, name: str = None, tag: str = None, entities: List[str] = Query(None, alias="entity"), labels: List[str] = Query(None, alias="label"), - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( project, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) features = await run_in_threadpool( @@ -478,7 +482,7 @@ async def list_features( labels, ) features = await mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.feature, + mlrun.common.schemas.AuthorizationResourceTypes.feature, features.features, lambda feature_list_output: ( feature_list_output.feature_set_digest.metadata.project, @@ -486,21 +490,21 @@ async def list_features( ), auth_info, ) - return mlrun.api.schemas.FeaturesOutput(features=features) + return mlrun.common.schemas.FeaturesOutput(features=features) -@router.get("/projects/{project}/entities", response_model=schemas.EntitiesOutput) +@router.get("/entities", response_model=mlrun.common.schemas.EntitiesOutput) async def 
list_entities( project: str, name: str = None, tag: str = None, labels: List[str] = Query(None, alias="label"), - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( project, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) entities = await run_in_threadpool( @@ -512,7 +516,7 @@ async def list_entities( labels, ) entities = await mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.entity, + mlrun.common.schemas.AuthorizationResourceTypes.entity, entities.entities, lambda entity_list_output: ( entity_list_output.feature_set_digest.metadata.project, @@ -520,17 +524,18 @@ async def list_entities( ), auth_info, ) - return mlrun.api.schemas.EntitiesOutput(entities=entities) + return mlrun.common.schemas.EntitiesOutput(entities=entities) @router.post( - "/projects/{project}/feature-vectors", response_model=schemas.FeatureVector + "/feature-vectors", + response_model=mlrun.common.schemas.FeatureVector, ) async def create_feature_vector( project: str, - feature_vector: schemas.FeatureVector, + feature_vector: mlrun.common.schemas.FeatureVector, versioned: bool = True, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await run_in_threadpool( @@ -540,10 +545,10 @@ async def create_feature_vector( auth_info=auth_info, ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.feature_vector, + mlrun.common.schemas.AuthorizationResourceTypes.feature_vector, project, 
feature_vector.metadata.name, - mlrun.api.schemas.AuthorizationAction.create, + mlrun.common.schemas.AuthorizationAction.create, auth_info, ) await _verify_feature_vector_features_permissions( @@ -568,14 +573,14 @@ async def create_feature_vector( @router.get( - "/projects/{project}/feature-vectors/{name}/references/{reference}", - response_model=schemas.FeatureVector, + "/feature-vectors/{name}/references/{reference}", + response_model=mlrun.common.schemas.FeatureVector, ) async def get_feature_vector( project: str, name: str, reference: str, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): tag, uid = parse_reference(reference) @@ -588,10 +593,10 @@ async def get_feature_vector( uid, ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.feature_vector, + mlrun.common.schemas.AuthorizationResourceTypes.feature_vector, project, name, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) await _verify_feature_vector_features_permissions( @@ -601,7 +606,8 @@ async def get_feature_vector( @router.get( - "/projects/{project}/feature-vectors", response_model=schemas.FeatureVectorsOutput + "/feature-vectors", + response_model=mlrun.common.schemas.FeatureVectorsOutput, ) async def list_feature_vectors( project: str, @@ -609,20 +615,22 @@ async def list_feature_vectors( state: str = None, tag: str = None, labels: List[str] = Query(None, alias="label"), - partition_by: schemas.FeatureStorePartitionByField = Query( + partition_by: mlrun.common.schemas.FeatureStorePartitionByField = Query( None, alias="partition-by" ), rows_per_partition: int = Query(1, alias="rows-per-partition", gt=0), - partition_sort_by: schemas.SortField = Query(None, alias="partition-sort-by"), - 
partition_order: schemas.OrderType = Query( - schemas.OrderType.desc, alias="partition-order" + partition_sort_by: mlrun.common.schemas.SortField = Query( + None, alias="partition-sort-by" + ), + partition_order: mlrun.common.schemas.OrderType = Query( + mlrun.common.schemas.OrderType.desc, alias="partition-order" ), - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( project, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) feature_vectors = await run_in_threadpool( @@ -639,7 +647,7 @@ async def list_feature_vectors( partition_order, ) feature_vectors = await mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.feature_vector, + mlrun.common.schemas.AuthorizationResourceTypes.feature_vector, feature_vectors.feature_vectors, lambda feature_vector: ( feature_vector.metadata.project, @@ -653,17 +661,17 @@ async def list_feature_vectors( for fv in feature_vectors ] ) - return mlrun.api.schemas.FeatureVectorsOutput(feature_vectors=feature_vectors) + return mlrun.common.schemas.FeatureVectorsOutput(feature_vectors=feature_vectors) @router.get( - "/projects/{project}/feature-vectors/{name}/tags", - response_model=schemas.FeatureVectorsTagsOutput, + "/feature-vectors/{name}/tags", + response_model=mlrun.common.schemas.FeatureVectorsTagsOutput, ) async def list_feature_vectors_tags( project: str, name: str, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): if name != "*": @@ -672,7 +680,7 @@ async def list_feature_vectors_tags( ) await 
mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( project, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) tag_tuples = await run_in_threadpool( @@ -686,7 +694,7 @@ async def list_feature_vectors_tags( auth_verifier = mlrun.api.utils.auth.verifier.AuthVerifier() allowed_feature_vector_names = ( await auth_verifier.filter_project_resources_by_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.feature_vector, + mlrun.common.schemas.AuthorizationResourceTypes.feature_vector, list(feature_vector_name_to_tag.keys()), lambda feature_vector_name: ( project, @@ -700,20 +708,20 @@ async def list_feature_vectors_tags( for tag_tuple in tag_tuples if tag_tuple[1] in allowed_feature_vector_names } - return mlrun.api.schemas.FeatureVectorsTagsOutput(tags=list(tags)) + return mlrun.common.schemas.FeatureVectorsTagsOutput(tags=list(tags)) @router.put( - "/projects/{project}/feature-vectors/{name}/references/{reference}", - response_model=schemas.FeatureVector, + "/feature-vectors/{name}/references/{reference}", + response_model=mlrun.common.schemas.FeatureVector, ) async def store_feature_vector( project: str, name: str, reference: str, - feature_vector: schemas.FeatureVector, + feature_vector: mlrun.common.schemas.FeatureVector, versioned: bool = True, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await run_in_threadpool( @@ -723,10 +731,10 @@ async def store_feature_vector( auth_info=auth_info, ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.feature_vector, + mlrun.common.schemas.AuthorizationResourceTypes.feature_vector, project, name, - mlrun.api.schemas.AuthorizationAction.update, + mlrun.common.schemas.AuthorizationAction.update, 
auth_info, ) await _verify_feature_vector_features_permissions( @@ -754,23 +762,24 @@ async def store_feature_vector( ) -@router.patch("/projects/{project}/feature-vectors/{name}/references/{reference}") +@router.patch("/feature-vectors/{name}/references/{reference}") async def patch_feature_vector( project: str, name: str, feature_vector_patch: dict, reference: str, - patch_mode: schemas.PatchMode = Header( - schemas.PatchMode.replace, alias=schemas.HeaderNames.patch_mode + patch_mode: mlrun.common.schemas.PatchMode = Header( + mlrun.common.schemas.PatchMode.replace, + alias=mlrun.common.schemas.HeaderNames.patch_mode, ), - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.feature_vector, + mlrun.common.schemas.AuthorizationResourceTypes.feature_vector, project, name, - mlrun.api.schemas.AuthorizationAction.update, + mlrun.common.schemas.AuthorizationAction.update, auth_info, ) await _verify_feature_vector_features_permissions( @@ -790,20 +799,20 @@ async def patch_feature_vector( return Response(status_code=HTTPStatus.OK.value) -@router.delete("/projects/{project}/feature-vectors/{name}") -@router.delete("/projects/{project}/feature-vectors/{name}/references/{reference}") +@router.delete("/feature-vectors/{name}") +@router.delete("/feature-vectors/{name}/references/{reference}") async def delete_feature_vector( project: str, name: str, reference: str = None, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - 
mlrun.api.schemas.AuthorizationResourceTypes.feature_vector, + mlrun.common.schemas.AuthorizationResourceTypes.feature_vector, project, name, - mlrun.api.schemas.AuthorizationAction.delete, + mlrun.common.schemas.AuthorizationAction.delete, auth_info, ) tag = uid = None @@ -821,7 +830,7 @@ async def delete_feature_vector( async def _verify_feature_vector_features_permissions( - auth_info: mlrun.api.schemas.AuthInfo, project: str, feature_vector: dict + auth_info: mlrun.common.schemas.AuthInfo, project: str, feature_vector: dict ): features = [] if feature_vector.get("spec", {}).get("features"): @@ -840,12 +849,12 @@ async def _verify_feature_vector_features_permissions( for name in names: feature_set_project_name_tuples.append((_project, name)) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resources_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.feature_set, + mlrun.common.schemas.AuthorizationResourceTypes.feature_set, feature_set_project_name_tuples, lambda feature_set_project_name_tuple: ( feature_set_project_name_tuple[0], feature_set_project_name_tuple[1], ), - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) diff --git a/mlrun/api/api/endpoints/files.py b/mlrun/api/api/endpoints/files.py index e2777d4d7d9c..4a52d76a6964 100644 --- a/mlrun/api/api/endpoints/files.py +++ b/mlrun/api/api/endpoints/files.py @@ -20,8 +20,8 @@ import mlrun.api.api.deps import mlrun.api.crud.secrets -import mlrun.api.schemas import mlrun.api.utils.auth.verifier +import mlrun.common.schemas from mlrun.api.api.utils import get_obj_path, get_secrets, log_and_raise from mlrun.datastore import store_manager from mlrun.errors import err_to_str @@ -37,7 +37,7 @@ def get_files( user: str = "", size: int = 0, offset: int = 0, - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), ): @@ -53,13 
+53,13 @@ async def get_files_with_project_secrets( size: int = 0, offset: int = 0, use_secrets: bool = fastapi.Query(True, alias="use-secrets"), - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( project, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) @@ -76,7 +76,7 @@ async def get_files_with_project_secrets( def get_filestat( schema: str = "", path: str = "", - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), user: str = "", @@ -89,7 +89,7 @@ async def get_filestat_with_project_secrets( project: str, schema: str = "", path: str = "", - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), user: str = "", @@ -97,7 +97,7 @@ async def get_filestat_with_project_secrets( ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( project, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) @@ -116,7 +116,7 @@ def _get_files( user: str, size: int, offset: int, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, secrets: dict = None, ): _, filename = objpath.split(objpath) @@ -162,7 +162,7 @@ def _get_filestat( schema: str, path: str, user: str, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, secrets: dict = None, ): _, filename = path.split(path) @@ -197,16 +197,16 @@ def _get_filestat( async def _verify_and_get_project_secrets(project, auth_info): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - 
mlrun.api.schemas.AuthorizationResourceTypes.secret, + mlrun.common.schemas.AuthorizationResourceTypes.secret, project, - mlrun.api.schemas.SecretProviderName.kubernetes, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.SecretProviderName.kubernetes, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) secrets_data = await run_in_threadpool( mlrun.api.crud.Secrets().list_project_secrets, project, - mlrun.api.schemas.SecretProviderName.kubernetes, + mlrun.common.schemas.SecretProviderName.kubernetes, allow_secrets_from_k8s=True, ) return secrets_data.secrets or {} diff --git a/mlrun/api/api/endpoints/frontend_spec.py b/mlrun/api/api/endpoints/frontend_spec.py index d180d302920f..164e5038a14a 100644 --- a/mlrun/api/api/endpoints/frontend_spec.py +++ b/mlrun/api/api/endpoints/frontend_spec.py @@ -18,9 +18,9 @@ import semver import mlrun.api.api.deps -import mlrun.api.schemas +import mlrun.api.utils.builder import mlrun.api.utils.clients.iguazio -import mlrun.builder +import mlrun.common.schemas import mlrun.runtimes import mlrun.runtimes.utils import mlrun.utils.helpers @@ -33,19 +33,15 @@ @router.get( "/frontend-spec", - response_model=mlrun.api.schemas.FrontendSpec, + response_model=mlrun.common.schemas.FrontendSpec, ) def get_frontend_spec( - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), - # In Iguazio 3.0 auth is turned off, but for this endpoint specifically the session is a must, so getting it from - # the cookie like it was before - # TODO: remove when Iguazio 3.0 is no longer relevant - session: typing.Optional[str] = fastapi.Cookie(None), ): jobs_dashboard_url = None - session = auth_info.session or session + session = auth_info.session if session and is_iguazio_session_cookie(session): jobs_dashboard_url = _resolve_jobs_dashboard_url(session) feature_flags = _resolve_feature_flags() @@ -66,7 +62,7 @@ def 
get_frontend_spec( function_target_image_name_prefix_template = ( config.httpdb.builder.function_target_image_name_prefix_template ) - return mlrun.api.schemas.FrontendSpec( + return mlrun.common.schemas.FrontendSpec( jobs_dashboard_url=jobs_dashboard_url, abortable_function_kinds=mlrun.runtimes.RuntimeKinds.abortable_runtimes(), feature_flags=feature_flags, @@ -76,7 +72,7 @@ def get_frontend_spec( function_deployment_target_image_template=function_deployment_target_image_template, function_deployment_target_image_name_prefix_template=function_target_image_name_prefix_template, function_deployment_target_image_registries_to_enforce_prefix=registries_to_enforce_prefix, - function_deployment_mlrun_command=mlrun.builder.resolve_mlrun_install_command(), + function_deployment_mlrun_command=_resolve_function_deployment_mlrun_command(), auto_mount_type=config.storage.auto_mount_type, auto_mount_params=config.get_storage_auto_mount_params(), default_artifact_path=config.artifact_path, @@ -90,6 +86,15 @@ def get_frontend_spec( ) +def _resolve_function_deployment_mlrun_command(): + # TODO: When UI adds a requirements section, mlrun should be specified there instead of the commands section i.e. 
+ # frontend spec will contain only the mlrun_version_specifier instead of the full command + mlrun_version_specifier = ( + mlrun.api.utils.builder.resolve_mlrun_install_command_version() + ) + return f'python -m pip install "{mlrun_version_specifier}"' + + def _resolve_jobs_dashboard_url(session: str) -> typing.Optional[str]: iguazio_client = mlrun.api.utils.clients.iguazio.Client() grafana_service_url = iguazio_client.try_get_grafana_service_url(session) @@ -102,25 +107,25 @@ def _resolve_jobs_dashboard_url(session: str) -> typing.Optional[str]: return None -def _resolve_feature_flags() -> mlrun.api.schemas.FeatureFlags: - project_membership = mlrun.api.schemas.ProjectMembershipFeatureFlag.disabled +def _resolve_feature_flags() -> mlrun.common.schemas.FeatureFlags: + project_membership = mlrun.common.schemas.ProjectMembershipFeatureFlag.disabled if mlrun.mlconf.httpdb.authorization.mode == "opa": - project_membership = mlrun.api.schemas.ProjectMembershipFeatureFlag.enabled - authentication = mlrun.api.schemas.AuthenticationFeatureFlag( + project_membership = mlrun.common.schemas.ProjectMembershipFeatureFlag.enabled + authentication = mlrun.common.schemas.AuthenticationFeatureFlag( mlrun.mlconf.httpdb.authentication.mode ) - nuclio_streams = mlrun.api.schemas.NuclioStreamsFeatureFlag.disabled + nuclio_streams = mlrun.common.schemas.NuclioStreamsFeatureFlag.disabled if mlrun.mlconf.get_parsed_igz_version() and semver.VersionInfo.parse( mlrun.runtimes.utils.resolve_nuclio_version() ) >= semver.VersionInfo.parse("1.7.8"): - nuclio_streams = mlrun.api.schemas.NuclioStreamsFeatureFlag.enabled + nuclio_streams = mlrun.common.schemas.NuclioStreamsFeatureFlag.enabled - preemption_nodes = mlrun.api.schemas.PreemptionNodesFeatureFlag.disabled + preemption_nodes = mlrun.common.schemas.PreemptionNodesFeatureFlag.disabled if mlrun.mlconf.is_preemption_nodes_configured(): - preemption_nodes = mlrun.api.schemas.PreemptionNodesFeatureFlag.enabled + preemption_nodes = 
mlrun.common.schemas.PreemptionNodesFeatureFlag.enabled - return mlrun.api.schemas.FeatureFlags( + return mlrun.common.schemas.FeatureFlags( project_membership=project_membership, authentication=authentication, nuclio_streams=nuclio_streams, diff --git a/mlrun/api/api/endpoints/functions.py b/mlrun/api/api/endpoints/functions.py index 51fd608ad170..fcb35843b254 100644 --- a/mlrun/api/api/endpoints/functions.py +++ b/mlrun/api/api/endpoints/functions.py @@ -31,26 +31,29 @@ Response, ) from fastapi.concurrency import run_in_threadpool +from kubernetes.client.rest import ApiException from sqlalchemy.orm import Session import mlrun.api.crud +import mlrun.api.crud.runtimes.nuclio.function import mlrun.api.db.session -import mlrun.api.schemas +import mlrun.api.launcher import mlrun.api.utils.auth.verifier import mlrun.api.utils.background_tasks import mlrun.api.utils.clients.chief +import mlrun.api.utils.singletons.k8s import mlrun.api.utils.singletons.project_member +import mlrun.common.model_monitoring +import mlrun.common.schemas from mlrun.api.api import deps from mlrun.api.api.utils import get_run_db_instance, log_and_raise, log_path from mlrun.api.crud.secrets import Secrets, SecretsClientType -from mlrun.api.schemas import SecretProviderName, SecretsData -from mlrun.api.utils.singletons.k8s import get_k8s -from mlrun.builder import build_runtime +from mlrun.api.utils.builder import build_runtime +from mlrun.api.utils.singletons.scheduler import get_scheduler from mlrun.config import config from mlrun.errors import MLRunRuntimeError, err_to_str from mlrun.run import new_function from mlrun.runtimes import RuntimeKinds, ServingRuntime, runtime_resources_map -from mlrun.runtimes.function import deploy_nuclio_function, get_nuclio_deploy_status from mlrun.runtimes.utils import get_item_name from mlrun.utils import get_in, logger, parse_versioned_object_uri, update_in from mlrun.utils.model_monitoring import parse_model_endpoint_store_prefix @@ -58,14 +61,20 @@ router = 
APIRouter() -@router.post("/func/{project}/{name}") +@router.post( + "/func/{project}/{name}", + deprecated=True, + description="/func/{project}/{name} is deprecated in 1.4.0 and will be removed in 1.6.0, " + "use /projects/{project}/functions/{name} instead", +) +@router.post("/projects/{project}/functions/{name}") async def store_function( request: Request, project: str, name: str, tag: str = "", versioned: bool = False, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await run_in_threadpool( @@ -75,10 +84,10 @@ async def store_function( auth_info=auth_info, ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.function, + mlrun.common.schemas.AuthorizationResourceTypes.function, project, name, - mlrun.api.schemas.AuthorizationAction.store, + mlrun.common.schemas.AuthorizationAction.store, auth_info, ) data = None @@ -87,7 +96,7 @@ async def store_function( except ValueError: log_and_raise(HTTPStatus.BAD_REQUEST.value, reason="bad JSON body") - logger.debug("Storing function", project=project, name=name, tag=tag, data=data) + logger.debug("Storing function", project=project, name=name, tag=tag) hash_key = await run_in_threadpool( mlrun.api.crud.Functions().store_function, db_session, @@ -103,13 +112,19 @@ async def store_function( } -@router.get("/func/{project}/{name}") +@router.get( + "/func/{project}/{name}", + deprecated=True, + description="/func/{project}/{name} is deprecated in 1.4.0 and will be removed in 1.6.0, " + "use /projects/{project}/functions/{name} instead", +) +@router.get("/projects/{project}/functions/{name}") async def get_function( project: str, name: str, tag: str = "", hash_key="", - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: 
mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): func = await run_in_threadpool( @@ -121,10 +136,10 @@ async def get_function( hash_key, ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.function, + mlrun.common.schemas.AuthorizationResourceTypes.function, project, name, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) return { @@ -138,37 +153,73 @@ async def get_function( async def delete_function( project: str, name: str, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.function, + mlrun.common.schemas.AuthorizationResourceTypes.function, project, name, - mlrun.api.schemas.AuthorizationAction.delete, + mlrun.common.schemas.AuthorizationAction.delete, auth_info, ) + # If the requested function has a schedule, we must delete it before deleting the function + try: + function_schedule = await run_in_threadpool( + get_scheduler().get_schedule, + db_session, + project, + name, + ) + except mlrun.errors.MLRunNotFoundError: + function_schedule = None + + if function_schedule: + # when deleting a function, we should also delete its schedules if exists + # schedules are only supposed to be run by the chief, therefore, if the function has a schedule, + # and we are running in worker, we send the request to the chief client + if ( + mlrun.mlconf.httpdb.clusterization.role + != mlrun.common.schemas.ClusterizationRole.chief + ): + logger.info( + "Function has a schedule, deleting", + function=name, + project=project, + ) + chief_client = 
mlrun.api.utils.clients.chief.Client() + await chief_client.delete_schedule(project=project, name=name) + else: + await run_in_threadpool( + get_scheduler().delete_schedule, db_session, project, name + ) await run_in_threadpool( mlrun.api.crud.Functions().delete_function, db_session, project, name ) return Response(status_code=HTTPStatus.NO_CONTENT.value) -@router.get("/funcs") +@router.get( + "/funcs", + deprecated=True, + description="/funcs is deprecated in 1.4.0 and will be removed in 1.6.0, " + "use /projects/{project}/functions instead", +) +@router.get("/projects/{project}/functions") async def list_functions( project: str = None, name: str = None, tag: str = None, labels: List[str] = Query([], alias="label"), hash_key: str = None, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): if project is None: project = config.default_project await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( project, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) functions = await run_in_threadpool( @@ -181,7 +232,7 @@ async def list_functions( hash_key=hash_key, ) functions = await mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.function, + mlrun.common.schemas.AuthorizationResourceTypes.function, functions, lambda function: ( function.get("metadata", {}).get("project", mlrun.mlconf.default_project), @@ -198,13 +249,13 @@ async def list_functions( @router.post("/build/function/") async def build_function( request: Request, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), client_version: 
Optional[str] = Header( - None, alias=mlrun.api.schemas.HeaderNames.client_version + None, alias=mlrun.common.schemas.HeaderNames.client_version ), client_python_version: Optional[str] = Header( - None, alias=mlrun.api.schemas.HeaderNames.python_version + None, alias=mlrun.common.schemas.HeaderNames.python_version ), ): data = None @@ -224,10 +275,10 @@ async def build_function( auth_info=auth_info, ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.function, + mlrun.common.schemas.AuthorizationResourceTypes.function, project, function_name, - mlrun.api.schemas.AuthorizationAction.update, + mlrun.common.schemas.AuthorizationAction.update, auth_info, ) @@ -239,7 +290,7 @@ async def build_function( ).get("track_models", False): if ( mlrun.mlconf.httpdb.clusterization.role - != mlrun.api.schemas.ClusterizationRole.chief + != mlrun.common.schemas.ClusterizationRole.chief ): logger.info( "Requesting to deploy serving function with track models, re-routing to chief", @@ -274,18 +325,18 @@ async def build_function( } -@router.post("/start/function", response_model=mlrun.api.schemas.BackgroundTask) -@router.post("/start/function/", response_model=mlrun.api.schemas.BackgroundTask) +@router.post("/start/function", response_model=mlrun.common.schemas.BackgroundTask) +@router.post("/start/function/", response_model=mlrun.common.schemas.BackgroundTask) async def start_function( request: Request, background_tasks: BackgroundTasks, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), client_version: Optional[str] = Header( - None, alias=mlrun.api.schemas.HeaderNames.client_version + None, alias=mlrun.common.schemas.HeaderNames.client_version ), client_python_version: Optional[str] = Header( - None, 
alias=mlrun.api.schemas.HeaderNames.python_version + None, alias=mlrun.common.schemas.HeaderNames.python_version ), ): # TODO: ensure project here !!! for background task @@ -299,10 +350,10 @@ async def start_function( function = await run_in_threadpool(_parse_start_function_body, db_session, data) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.function, + mlrun.common.schemas.AuthorizationResourceTypes.function, function.metadata.project, function.metadata.name, - mlrun.api.schemas.AuthorizationAction.update, + mlrun.common.schemas.AuthorizationAction.update, auth_info, ) background_timeout = mlrun.mlconf.background_tasks.default_timeouts.runtimes.dask @@ -328,7 +379,7 @@ async def start_function( @router.post("/status/function/") async def function_status( request: Request, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), ): data = None try: @@ -352,16 +403,16 @@ async def build_status( logs: bool = True, last_log_timestamp: float = 0.0, verbose: bool = False, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.function, + mlrun.common.schemas.AuthorizationResourceTypes.function, project or mlrun.mlconf.default_project, name, # store since with the current mechanism we update the status (and store the function) in the DB when a client # query for the status - mlrun.api.schemas.AuthorizationAction.store, + mlrun.common.schemas.AuthorizationAction.store, auth_info, ) fn = await run_in_threadpool( @@ -387,13 +438,10 @@ async def build_status( return await run_in_threadpool( 
_handle_job_deploy_status, db_session, - auth_info, fn, name, project, tag, - last_log_timestamp, - verbose, offset, logs, ) @@ -401,13 +449,10 @@ async def build_status( def _handle_job_deploy_status( db_session, - auth_info, fn, name, project, tag, - last_log_timestamp, - verbose, offset, logs, ): @@ -417,7 +462,7 @@ def _handle_job_deploy_status( image = get_in(fn, "spec.build.image", "") out = b"" if not pod: - if state == mlrun.api.schemas.FunctionState.ready: + if state == mlrun.common.schemas.FunctionState.ready: # when the function has been built we set the created image into the `spec.image` for reference see at the # end of the function where we resolve if the status is ready and then set the spec.build.image to # spec.image @@ -439,8 +484,7 @@ def _handle_job_deploy_status( terminal_states = ["failed", "error", "ready"] log_file = log_path(project, f"build_{name}__{tag or 'latest'}") if state in terminal_states and log_file.exists(): - - if state == mlrun.api.schemas.FunctionState.ready: + if state == mlrun.common.schemas.FunctionState.ready: # when the function has been built we set the created image into the `spec.image` for reference see at the # end of the function where we resolve if the status is ready and then set the spec.build.image to # spec.image @@ -463,20 +507,34 @@ def _handle_job_deploy_status( }, ) - logger.info(f"get pod {pod} status") - state = get_k8s().get_pod_status(pod) - logger.info(f"pod state={state}") + # TODO: change state to pod_status + state = mlrun.api.utils.singletons.k8s.get_k8s_helper(silent=False).get_pod_status( + pod + ) + logger.info("Resolved pod status", pod_status=state, pod_name=pod) if state == "succeeded": - logger.info("build completed successfully") - state = mlrun.api.schemas.FunctionState.ready + logger.info("Build completed successfully") + state = mlrun.common.schemas.FunctionState.ready if state in ["failed", "error"]: - logger.error(f"build {state}, watch the build pod logs: {pod}") - state = 
mlrun.api.schemas.FunctionState.error + logger.error("Build failed", pod_name=pod, pod_status=state) + state = mlrun.common.schemas.FunctionState.error if (logs and state != "pending") or state in terminal_states: - resp = get_k8s().logs(pod) + try: + resp = mlrun.api.utils.singletons.k8s.get_k8s_helper(silent=False).logs(pod) + except ApiException as exc: + logger.warning( + "Failed to get build logs", + function_name=name, + function_state=state, + pod=pod, + exc_info=exc, + ) + resp = "" + if state in terminal_states: + # TODO: move to log collector log_file.parent.mkdir(parents=True, exist_ok=True) with log_file.open("wb") as fp: fp.write(resp.encode()) @@ -486,11 +544,11 @@ def _handle_job_deploy_status( out = resp[offset:].encode() update_in(fn, "status.state", state) - if state == mlrun.api.schemas.FunctionState.ready: + if state == mlrun.common.schemas.FunctionState.ready: update_in(fn, "spec.image", image) versioned = False - if state == mlrun.api.schemas.FunctionState.ready: + if state == mlrun.common.schemas.FunctionState.ready: versioned = True mlrun.api.crud.Functions().store_function( db_session, @@ -523,7 +581,7 @@ def _handle_nuclio_deploy_status( last_log_timestamp, text, status, - ) = get_nuclio_deploy_status( + ) = mlrun.api.crud.runtimes.nuclio.function.get_nuclio_deploy_status( name, project, tag, @@ -595,7 +653,7 @@ def _handle_nuclio_deploy_status( def _build_function( db_session, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, function, with_mlrun=True, skip_deployed=False, @@ -617,9 +675,9 @@ def _build_function( try: run_db = get_run_db_instance(db_session) fn.set_db_connection(run_db) + mlrun.api.launcher.ServerSideLauncher.enrich_runtime(runtime=fn) fn.save(versioned=False) if fn.kind in RuntimeKinds.nuclio_runtimes(): - mlrun.api.api.utils.apply_enrichment_and_validation_on_function( fn, auth_info, @@ -630,25 +688,37 @@ def _build_function( try: if fn.spec.track_models: logger.info("Tracking enabled, 
initializing model monitoring") - _init_serving_function_stream_args(fn=fn) - # get model monitoring access key - model_monitoring_access_key = _process_model_monitoring_secret( - db_session, - fn.metadata.project, - "MODEL_MONITORING_ACCESS_KEY", - ) - # initialize model monitoring stream - _create_model_monitoring_stream(project=fn.metadata.project) + + # Generating model monitoring access key + model_monitoring_access_key = None + if not mlrun.mlconf.is_ce_mode(): + model_monitoring_access_key = _process_model_monitoring_secret( + db_session, + fn.metadata.project, + mlrun.common.model_monitoring.ProjectSecretKeys.ACCESS_KEY, + ) + + stream_path = mlrun.utils.model_monitoring.get_stream_path( + project=fn.metadata.project + ) + + if stream_path.startswith("v3io://"): + # Initialize model monitoring V3IO stream + _create_model_monitoring_stream( + project=fn.metadata.project, + function=fn, + stream_path=stream_path, + ) if fn.spec.tracking_policy: - # convert to `TrackingPolicy` object as `fn.spec.tracking_policy` is provided as a dict + # Convert to `TrackingPolicy` object as `fn.spec.tracking_policy` is provided as a dict fn.spec.tracking_policy = ( mlrun.utils.model_monitoring.TrackingPolicy.from_dict( fn.spec.tracking_policy ) ) else: - # initialize tracking policy with default values + # Initialize tracking policy with default values fn.spec.tracking_policy = ( mlrun.utils.model_monitoring.TrackingPolicy() ) @@ -656,10 +726,10 @@ def _build_function( # deploy both model monitoring stream and model monitoring batch job mlrun.api.crud.ModelEndpoints().deploy_monitoring_functions( project=fn.metadata.project, - model_monitoring_access_key=model_monitoring_access_key, db_session=db_session, auth_info=auth_info, tracking_policy=fn.spec.tracking_policy, + model_monitoring_access_key=model_monitoring_access_key, ) except Exception as exc: logger.warning( @@ -669,7 +739,7 @@ def _build_function( traceback=traceback.format_exc(), ) - deploy_nuclio_function( + 
mlrun.api.crud.runtimes.nuclio.function.deploy_nuclio_function( fn, auth_info=auth_info, client_version=client_version, @@ -698,7 +768,7 @@ def _build_function( client_python_version=client_python_version, ) fn.save(versioned=True) - logger.info("Fn:\n %s", fn.to_yaml()) + logger.info("Resolved function", fn=fn.to_yaml()) except Exception as err: logger.error(traceback.format_exc()) log_and_raise( @@ -731,7 +801,7 @@ def _parse_start_function_body(db_session, data): def _start_function( function, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, client_version: str = None, client_python_version: str = None, ): @@ -769,7 +839,7 @@ def _start_function( mlrun.api.db.session.close_session(db_session) -async def _get_function_status(data, auth_info: mlrun.api.schemas.AuthInfo): +async def _get_function_status(data, auth_info: mlrun.common.schemas.AuthInfo): logger.info(f"function_status:\n{data}") selector = data.get("selector") kind = data.get("kind") @@ -785,10 +855,10 @@ async def _get_function_status(data, auth_info: mlrun.api.schemas.AuthInfo): project, name, _ = mlrun.runtimes.utils.parse_function_selector(selector) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.function, + mlrun.common.schemas.AuthorizationResourceTypes.function, project, name, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) @@ -810,10 +880,8 @@ async def _get_function_status(data, auth_info: mlrun.api.schemas.AuthInfo): ) -def _create_model_monitoring_stream(project: str): - stream_path = config.model_endpoint_monitoring.store_prefixes.default.format( - project=project, kind="stream" - ) +def _create_model_monitoring_stream(project: str, function, stream_path): + _init_serving_function_stream_args(fn=function) _, container, stream_path = parse_model_endpoint_store_prefix(stream_path) @@ -874,8 +942,7 @@ def 
_process_model_monitoring_secret(db_session, project_name: str, secret_key: logger.info( "Getting project secret", project_name=project_name, namespace=config.namespace ) - - provider = SecretProviderName.kubernetes + provider = mlrun.common.schemas.SecretProviderName.kubernetes secret_value = Secrets().get_project_secret( project_name, provider, @@ -896,8 +963,6 @@ def _process_model_monitoring_secret(db_session, project_name: str, secret_key: allow_internal_secrets=True, ) if not secret_value: - import mlrun.api.utils.singletons.project_member - project_owner = mlrun.api.utils.singletons.project_member.get_project_member().get_project_owner( db_session, project_name ) @@ -914,7 +979,9 @@ def _process_model_monitoring_secret(db_session, project_name: str, secret_key: project_owner=project_owner.username, ) - secrets = SecretsData(provider=provider, secrets={internal_key_name: secret_value}) + secrets = mlrun.common.schemas.SecretsData( + provider=provider, secrets={internal_key_name: secret_value} + ) Secrets().store_project_secrets(project_name, secrets, allow_internal_secrets=True) if user_provided_key: logger.info( diff --git a/mlrun/api/api/endpoints/grafana_proxy.py b/mlrun/api/api/endpoints/grafana_proxy.py index eedc09c7dd87..bc6bd44e0142 100644 --- a/mlrun/api/api/endpoints/grafana_proxy.py +++ b/mlrun/api/api/endpoints/grafana_proxy.py @@ -13,83 +13,54 @@ # limitations under the License. 
# import asyncio -import json +import warnings from http import HTTPStatus -from typing import Any, Dict, List, Optional, Set, Union +from typing import List, Union -import numpy as np -import pandas as pd from fastapi import APIRouter, Depends, Request, Response from fastapi.concurrency import run_in_threadpool from sqlalchemy.orm import Session import mlrun.api.crud -import mlrun.api.schemas +import mlrun.api.crud.model_monitoring.grafana import mlrun.api.utils.auth.verifier +import mlrun.common.model_monitoring +import mlrun.common.schemas from mlrun.api.api import deps -from mlrun.api.schemas import ( - GrafanaColumn, - GrafanaDataPoint, - GrafanaNumberColumn, - GrafanaTable, - GrafanaTimeSeriesTarget, - ProjectsFormat, -) -from mlrun.api.utils.singletons.project_member import get_project_member -from mlrun.errors import MLRunBadRequestError -from mlrun.utils import config, logger -from mlrun.utils.model_monitoring import parse_model_endpoint_store_prefix -from mlrun.utils.v3io_clients import get_frames_client -router = APIRouter() +router = APIRouter(prefix="/grafana-proxy/model-endpoints") + +NAME_TO_SEARCH_FUNCTION_DICTIONARY = { + "list_projects": mlrun.api.crud.model_monitoring.grafana.grafana_list_projects, +} +NAME_TO_QUERY_FUNCTION_DICTIONARY = { + "list_endpoints": mlrun.api.crud.model_monitoring.grafana.grafana_list_endpoints, + "individual_feature_analysis": mlrun.api.crud.model_monitoring.grafana.grafana_individual_feature_analysis, + "overall_feature_analysis": mlrun.api.crud.model_monitoring.grafana.grafana_overall_feature_analysis, + "incoming_features": mlrun.api.crud.model_monitoring.grafana.grafana_incoming_features, +} + +SUPPORTED_QUERY_FUNCTIONS = set(NAME_TO_QUERY_FUNCTION_DICTIONARY.keys()) +SUPPORTED_SEARCH_FUNCTIONS = set(NAME_TO_SEARCH_FUNCTION_DICTIONARY) -@router.get("/grafana-proxy/model-endpoints", status_code=HTTPStatus.OK.value) +@router.get("", status_code=HTTPStatus.OK.value) def grafana_proxy_model_endpoints_check_connection( 
- auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), ): """ Root of grafana proxy for the model-endpoints API, used for validating the model-endpoints data source connectivity. """ - mlrun.api.crud.ModelEndpoints().get_access_key(auth_info) + if not mlrun.mlconf.is_ce_mode(): + mlrun.api.crud.ModelEndpoints().get_access_key(auth_info) return Response(status_code=HTTPStatus.OK.value) -@router.post( - "/grafana-proxy/model-endpoints/query", - response_model=List[Union[GrafanaTable, GrafanaTimeSeriesTarget]], -) -async def grafana_proxy_model_endpoints_query( - request: Request, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), -) -> List[Union[GrafanaTable, GrafanaTimeSeriesTarget]]: - """ - Query route for model-endpoints grafana proxy API, used for creating an interface between grafana queries and - model-endpoints logic. - - This implementation requires passing target_endpoint query parameter in order to dispatch different - model-endpoint monitoring functions. - """ - body = await request.json() - query_parameters = _parse_query_parameters(body) - _validate_query_parameters(query_parameters, SUPPORTED_QUERY_FUNCTIONS) - query_parameters = _drop_grafana_escape_chars(query_parameters) - - # At this point everything is validated and we can access everything that is needed without performing all previous - # checks again. 
- target_endpoint = query_parameters["target_endpoint"] - function = NAME_TO_QUERY_FUNCTION_DICTIONARY[target_endpoint] - if asyncio.iscoroutinefunction(function): - return await function(body, query_parameters, auth_info) - result = await run_in_threadpool(function, body, query_parameters, auth_info) - return result - - -@router.post("/grafana-proxy/model-endpoints/search", response_model=List[str]) +@router.post("/search", response_model=List[str]) async def grafana_proxy_model_endpoints_search( request: Request, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ) -> List[str]: """ @@ -98,411 +69,87 @@ async def grafana_proxy_model_endpoints_search( This implementation requires passing target_endpoint query parameter in order to dispatch different model-endpoint monitoring functions. + + :param request: An api request with the required target and parameters. + :param auth_info: The auth info of the request. + :param db_session: A session that manages the current dialog with the database. + + :return: List of results. e.g. list of available project names. """ - mlrun.api.crud.ModelEndpoints().get_access_key(auth_info) + if not mlrun.mlconf.is_ce_mode(): + mlrun.api.crud.ModelEndpoints().get_access_key(auth_info) body = await request.json() - query_parameters = _parse_search_parameters(body) - - _validate_query_parameters(query_parameters, SUPPORTED_SEARCH_FUNCTIONS) + query_parameters = mlrun.api.crud.model_monitoring.grafana.parse_search_parameters( + body + ) + mlrun.api.crud.model_monitoring.grafana.validate_query_parameters( + query_parameters, SUPPORTED_SEARCH_FUNCTIONS + ) # At this point everything is validated and we can access everything that is needed without performing all previous # checks again. 
target_endpoint = query_parameters["target_endpoint"] function = NAME_TO_SEARCH_FUNCTION_DICTIONARY[target_endpoint] - if asyncio.iscoroutinefunction(function): - return await function(db_session, auth_info) - result = await run_in_threadpool(function, db_session, auth_info) - return result - -def grafana_list_projects( - db_session: Session, auth_info: mlrun.api.schemas.AuthInfo -) -> List[str]: - projects_output = get_project_member().list_projects( - db_session, format_=ProjectsFormat.name_only, leader_session=auth_info.session - ) - return projects_output.projects - - -async def grafana_list_endpoints( - body: Dict[str, Any], - query_parameters: Dict[str, str], - auth_info: mlrun.api.schemas.AuthInfo, -) -> List[GrafanaTable]: - project = query_parameters.get("project") - - # Filters - model = query_parameters.get("model", None) - function = query_parameters.get("function", None) - labels = query_parameters.get("labels", "") - labels = labels.split(",") if labels else [] - - # Metrics to include - metrics = query_parameters.get("metrics", "") - metrics = metrics.split(",") if metrics else [] - - # Time range for metrics - start = body.get("rangeRaw", {}).get("start", "now-1h") - end = body.get("rangeRaw", {}).get("end", "now") - - if project: - await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( - project, - mlrun.api.schemas.AuthorizationAction.read, - auth_info, + if asyncio.iscoroutinefunction(function): + result = await function(db_session, auth_info, query_parameters) + else: + result = await run_in_threadpool( + function, db_session, auth_info, query_parameters ) - endpoint_list = await run_in_threadpool( - mlrun.api.crud.ModelEndpoints().list_model_endpoints, - auth_info=auth_info, - project=project, - model=model, - function=function, - labels=labels, - metrics=metrics, - start=start, - end=end, - ) - allowed_endpoints = await mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( - 
mlrun.api.schemas.AuthorizationResourceTypes.model_endpoint, - endpoint_list.endpoints, - lambda _endpoint: ( - _endpoint.metadata.project, - _endpoint.metadata.uid, - ), - auth_info, - ) - endpoint_list.endpoints = allowed_endpoints - - columns = [ - GrafanaColumn(text="endpoint_id", type="string"), - GrafanaColumn(text="endpoint_function", type="string"), - GrafanaColumn(text="endpoint_model", type="string"), - GrafanaColumn(text="endpoint_model_class", type="string"), - GrafanaColumn(text="first_request", type="time"), - GrafanaColumn(text="last_request", type="time"), - GrafanaColumn(text="accuracy", type="number"), - GrafanaColumn(text="error_count", type="number"), - GrafanaColumn(text="drift_status", type="number"), - ] - - metric_columns = [] - - found_metrics = set() - for endpoint in endpoint_list.endpoints: - if endpoint.status.metrics is not None: - for key in endpoint.status.metrics.keys(): - if key not in found_metrics: - found_metrics.add(key) - metric_columns.append(GrafanaColumn(text=key, type="number")) - - columns = columns + metric_columns - table = GrafanaTable(columns=columns) - - for endpoint in endpoint_list.endpoints: - row = [ - endpoint.metadata.uid, - endpoint.spec.function_uri, - endpoint.spec.model, - endpoint.spec.model_class, - endpoint.status.first_request, - endpoint.status.last_request, - endpoint.status.accuracy, - endpoint.status.error_count, - endpoint.status.drift_status, - ] - - if endpoint.status.metrics is not None and metric_columns: - for metric_column in metric_columns: - row.append(endpoint.status.metrics[metric_column.text]) - - table.add_row(*row) - - return [table] - - -async def grafana_individual_feature_analysis( - body: Dict[str, Any], - query_parameters: Dict[str, str], - auth_info: mlrun.api.schemas.AuthInfo, -): - endpoint_id = query_parameters.get("endpoint_id") - project = query_parameters.get("project") - await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - 
mlrun.api.schemas.AuthorizationResourceTypes.model_endpoint, - project, - endpoint_id, - mlrun.api.schemas.AuthorizationAction.read, - auth_info, - ) - - endpoint = await run_in_threadpool( - mlrun.api.crud.ModelEndpoints().get_model_endpoint, - auth_info=auth_info, - project=project, - endpoint_id=endpoint_id, - feature_analysis=True, - ) + return result - # Load JSON data from KV, make sure not to fail if a field is missing - feature_stats = endpoint.status.feature_stats or {} - current_stats = endpoint.status.current_stats or {} - drift_measures = endpoint.status.drift_measures or {} - table = GrafanaTable( - columns=[ - GrafanaColumn(text="feature_name", type="string"), - GrafanaColumn(text="actual_min", type="number"), - GrafanaColumn(text="actual_mean", type="number"), - GrafanaColumn(text="actual_max", type="number"), - GrafanaColumn(text="expected_min", type="number"), - GrafanaColumn(text="expected_mean", type="number"), - GrafanaColumn(text="expected_max", type="number"), - GrafanaColumn(text="tvd", type="number"), - GrafanaColumn(text="hellinger", type="number"), - GrafanaColumn(text="kld", type="number"), +@router.post( + "/query", + response_model=List[ + Union[ + mlrun.common.schemas.GrafanaTable, + mlrun.common.schemas.GrafanaTimeSeriesTarget, ] - ) - - for feature, base_stat in feature_stats.items(): - current_stat = current_stats.get(feature, {}) - drift_measure = drift_measures.get(feature, {}) - - table.add_row( - feature, - current_stat.get("min"), - current_stat.get("mean"), - current_stat.get("max"), - base_stat.get("min"), - base_stat.get("mean"), - base_stat.get("max"), - drift_measure.get("tvd"), - drift_measure.get("hellinger"), - drift_measure.get("kld"), - ) - - return [table] - + ], +) +async def grafana_proxy_model_endpoints_query( + request: Request, + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), +) -> List[ + Union[ + mlrun.common.schemas.GrafanaTable, mlrun.common.schemas.GrafanaTimeSeriesTarget + ] 
+]: + """ + Query route for model-endpoints grafana proxy API, used for creating an interface between grafana queries and + model-endpoints logic. -async def grafana_overall_feature_analysis( - body: Dict[str, Any], - query_parameters: Dict[str, str], - auth_info: mlrun.api.schemas.AuthInfo, -): - endpoint_id = query_parameters.get("endpoint_id") - project = query_parameters.get("project") - await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.model_endpoint, - project, - endpoint_id, - mlrun.api.schemas.AuthorizationAction.read, - auth_info, - ) - endpoint = await run_in_threadpool( - mlrun.api.crud.ModelEndpoints().get_model_endpoint, - auth_info=auth_info, - project=project, - endpoint_id=endpoint_id, - feature_analysis=True, - ) + This implementation requires passing target_endpoint query parameter in order to dispatch different + model-endpoint monitoring functions. + """ - table = GrafanaTable( - columns=[ - GrafanaNumberColumn(text="tvd_sum"), - GrafanaNumberColumn(text="tvd_mean"), - GrafanaNumberColumn(text="hellinger_sum"), - GrafanaNumberColumn(text="hellinger_mean"), - GrafanaNumberColumn(text="kld_sum"), - GrafanaNumberColumn(text="kld_mean"), - ] + warnings.warn( + "This api is deprecated in 1.3.1 and will be removed in 1.5.0. 
" + "Please update grafana model monitoring dashboards that use a different data source", + # TODO: remove in 1.5.0 + FutureWarning, ) - if endpoint.status.drift_measures: - table.add_row( - endpoint.status.drift_measures.get("tvd_sum"), - endpoint.status.drift_measures.get("tvd_mean"), - endpoint.status.drift_measures.get("hellinger_sum"), - endpoint.status.drift_measures.get("hellinger_mean"), - endpoint.status.drift_measures.get("kld_sum"), - endpoint.status.drift_measures.get("kld_mean"), - ) - - return [table] - - -async def grafana_incoming_features( - body: Dict[str, Any], - query_parameters: Dict[str, str], - auth_info: mlrun.api.schemas.AuthInfo, -): - endpoint_id = query_parameters.get("endpoint_id") - project = query_parameters.get("project") - start = body.get("rangeRaw", {}).get("from", "now-1h") - end = body.get("rangeRaw", {}).get("to", "now") - - await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.model_endpoint, - project, - endpoint_id, - mlrun.api.schemas.AuthorizationAction.read, - auth_info, + body = await request.json() + query_parameters = mlrun.api.crud.model_monitoring.grafana.parse_query_parameters( + body ) - - endpoint = await run_in_threadpool( - mlrun.api.crud.ModelEndpoints().get_model_endpoint, - auth_info=auth_info, - project=project, - endpoint_id=endpoint_id, + mlrun.api.crud.model_monitoring.grafana.validate_query_parameters( + query_parameters, SUPPORTED_QUERY_FUNCTIONS ) - - time_series = [] - - feature_names = endpoint.spec.feature_names - - if not feature_names: - logger.warn( - "'feature_names' is either missing or not initialized in endpoint record", - endpoint_id=endpoint.metadata.uid, + query_parameters = ( + mlrun.api.crud.model_monitoring.grafana.drop_grafana_escape_chars( + query_parameters ) - return time_series - - path = config.model_endpoint_monitoring.store_prefixes.default.format( - project=project, 
kind=mlrun.api.schemas.ModelMonitoringStoreKinds.EVENTS - ) - _, container, path = parse_model_endpoint_store_prefix(path) - - client = get_frames_client( - token=auth_info.data_session, - address=config.v3io_framesd, - container=container, ) - data: pd.DataFrame = await run_in_threadpool( - client.read, - backend="tsdb", - table=path, - columns=feature_names, - filter=f"endpoint_id=='{endpoint_id}'", - start=start, - end=end, - ) - - data.drop(["endpoint_id"], axis=1, inplace=True, errors="ignore") - data.index = data.index.astype(np.int64) // 10**6 - - for feature, indexed_values in data.to_dict().items(): - target = GrafanaTimeSeriesTarget(target=feature) - for index, value in indexed_values.items(): - data_point = GrafanaDataPoint(value=float(value), timestamp=index) - target.add_data_point(data_point) - time_series.append(target) - - return time_series - - -def _parse_query_parameters(request_body: Dict[str, Any]) -> Dict[str, str]: - """ - This function searches for the target field in Grafana's SimpleJson json. Once located, the target string is - parsed by splitting on semi-colons (;). Each part in the resulting list is then split by an equal sign (=) to be - read as key-value pairs. - """ - - # Try to get the target - targets = request_body.get("targets", []) - - if len(targets) > 1: - logger.warn( - f"The 'targets' list contains more then one element ({len(targets)}), all targets except the first one are " - f"ignored." - ) - - target_obj = targets[0] if targets else {} - target_query = target_obj.get("target") if target_obj else "" - - if not target_query: - raise MLRunBadRequestError(f"Target missing in request body:\n {request_body}") - - parameters = _parse_parameters(target_query) - - return parameters - - -def _parse_search_parameters(request_body: Dict[str, Any]) -> Dict[str, str]: - """ - This function searches for the target field in Grafana's SimpleJson json. Once located, the target string is - parsed by splitting on semi-colons (;). 
Each part in the resulting list is then split by an equal sign (=) to be - read as key-value pairs. - """ - - # Try to get the target - target = request_body.get("target") - - if not target: - raise MLRunBadRequestError(f"Target missing in request body:\n {request_body}") - - parameters = _parse_parameters(target) - - return parameters - - -def _parse_parameters(target_query): - parameters = {} - for query in filter(lambda q: q, target_query.split(";")): - query_parts = query.split("=") - if len(query_parts) < 2: - raise MLRunBadRequestError( - f"Query must contain both query key and query value. Expected query_key=query_value, found {query} " - f"instead." - ) - parameters[query_parts[0]] = query_parts[1] - return parameters - - -def _drop_grafana_escape_chars(query_parameters: Dict[str, str]): - query_parameters = dict(query_parameters) - endpoint_id = query_parameters.get("endpoint_id") - if endpoint_id is not None: - query_parameters["endpoint_id"] = endpoint_id.replace("\\", "") - return query_parameters - - -def _validate_query_parameters( - query_parameters: Dict[str, str], supported_endpoints: Optional[Set[str]] = None -): - """Validates the parameters sent via Grafana's SimpleJson query""" - if "target_endpoint" not in query_parameters: - raise MLRunBadRequestError( - f"Expected 'target_endpoint' field in query, found {query_parameters} instead" - ) - - if ( - supported_endpoints is not None - and query_parameters["target_endpoint"] not in supported_endpoints - ): - raise MLRunBadRequestError( - f"{query_parameters['target_endpoint']} unsupported in query parameters: {query_parameters}. 
" - f"Currently supports: {','.join(supported_endpoints)}" - ) - - -def _json_loads_or_default(string: Optional[str], default: Any): - if string is None: - return default - obj = json.loads(string) - if not obj: - return default - return obj - - -NAME_TO_QUERY_FUNCTION_DICTIONARY = { - "list_endpoints": grafana_list_endpoints, - "individual_feature_analysis": grafana_individual_feature_analysis, - "overall_feature_analysis": grafana_overall_feature_analysis, - "incoming_features": grafana_incoming_features, -} - -NAME_TO_SEARCH_FUNCTION_DICTIONARY = { - "list_projects": grafana_list_projects, -} - -SUPPORTED_QUERY_FUNCTIONS = set(NAME_TO_QUERY_FUNCTION_DICTIONARY.keys()) -SUPPORTED_SEARCH_FUNCTIONS = set(NAME_TO_SEARCH_FUNCTION_DICTIONARY) + # At this point everything is validated and we can access everything that is needed without performing all previous + # checks again. + target_endpoint = query_parameters["target_endpoint"] + function = NAME_TO_QUERY_FUNCTION_DICTIONARY[target_endpoint] + if asyncio.iscoroutinefunction(function): + return await function(body, query_parameters, auth_info) + result = await run_in_threadpool(function, body, query_parameters, auth_info) + return result diff --git a/mlrun/api/api/endpoints/healthz.py b/mlrun/api/api/endpoints/healthz.py index d1cb2a1f73c6..bc0924329d30 100644 --- a/mlrun/api/api/endpoints/healthz.py +++ b/mlrun/api/api/endpoints/healthz.py @@ -12,20 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import http + from fastapi import APIRouter -import mlrun.api.crud -import mlrun.api.schemas +import mlrun.common.schemas +from mlrun.config import config as mlconfig router = APIRouter() @router.get( "/healthz", - response_model=mlrun.api.schemas.ClientSpec, + status_code=http.HTTPStatus.OK.value, ) def health(): - # TODO: From 0.7.0 client uses the /client-spec endpoint, - # when this is the oldest relevant client, remove this logic from the healthz endpoint - return mlrun.api.crud.ClientSpec().get_client_spec() + # offline is the initial state + # waiting for chief is set for workers waiting for chief to be ready and then clusterize against it + if mlconfig.httpdb.state in [ + mlrun.common.schemas.APIStates.offline, + mlrun.common.schemas.APIStates.waiting_for_chief, + ]: + raise mlrun.errors.MLRunServiceUnavailableError() + + return { + # for old `align_mlrun.sh` scripts expecting `version` in the response + # TODO: remove on mlrun >= 1.6.0 + "version": mlconfig.version, + "status": "ok", + } diff --git a/mlrun/api/api/endpoints/hub.py b/mlrun/api/api/endpoints/hub.py new file mode 100644 index 000000000000..91e5bbc80456 --- /dev/null +++ b/mlrun/api/api/endpoints/hub.py @@ -0,0 +1,320 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import mimetypes +from http import HTTPStatus +from typing import List, Optional + +from fastapi import APIRouter, Depends, Query, Response +from fastapi.concurrency import run_in_threadpool +from sqlalchemy.orm import Session + +import mlrun +import mlrun.api.api.deps +import mlrun.api.crud +import mlrun.api.utils.auth.verifier +import mlrun.api.utils.singletons.db +import mlrun.common.schemas +import mlrun.common.schemas.hub + +router = APIRouter(prefix="/hub/sources") + + +@router.post( + path="", + status_code=HTTPStatus.CREATED.value, + response_model=mlrun.common.schemas.hub.IndexedHubSource, +) +async def create_source( + source: mlrun.common.schemas.hub.IndexedHubSource, + db_session: Session = Depends(mlrun.api.api.deps.get_db_session), + auth_info: mlrun.common.schemas.AuthInfo = Depends( + mlrun.api.api.deps.authenticate_request + ), +): + await mlrun.api.utils.auth.verifier.AuthVerifier().query_global_resource_permissions( + mlrun.common.schemas.AuthorizationResourceTypes.hub_source, + mlrun.common.schemas.AuthorizationAction.create, + auth_info, + ) + + await run_in_threadpool( + mlrun.api.utils.singletons.db.get_db().create_hub_source, db_session, source + ) + # Handle credentials if they exist + await run_in_threadpool(mlrun.api.crud.Hub().add_source, source.source) + return await run_in_threadpool( + mlrun.api.utils.singletons.db.get_db().get_hub_source, + db_session, + source.source.metadata.name, + ) + + +@router.get( + path="", + response_model=List[mlrun.common.schemas.hub.IndexedHubSource], +) +async def list_sources( + db_session: Session = Depends(mlrun.api.api.deps.get_db_session), + auth_info: mlrun.common.schemas.AuthInfo = Depends( + mlrun.api.api.deps.authenticate_request + ), +): + await mlrun.api.utils.auth.verifier.AuthVerifier().query_global_resource_permissions( + mlrun.common.schemas.AuthorizationResourceTypes.hub_source, + mlrun.common.schemas.AuthorizationAction.read, + auth_info, + ) + + return await run_in_threadpool( + 
mlrun.api.utils.singletons.db.get_db().list_hub_sources, db_session + ) + + +@router.delete( + path="/{source_name}", + status_code=HTTPStatus.NO_CONTENT.value, +) +async def delete_source( + source_name: str, + db_session: Session = Depends(mlrun.api.api.deps.get_db_session), + auth_info: mlrun.common.schemas.AuthInfo = Depends( + mlrun.api.api.deps.authenticate_request + ), +): + await mlrun.api.utils.auth.verifier.AuthVerifier().query_global_resource_permissions( + mlrun.common.schemas.AuthorizationResourceTypes.hub_source, + mlrun.common.schemas.AuthorizationAction.delete, + auth_info, + ) + + await run_in_threadpool( + mlrun.api.utils.singletons.db.get_db().delete_hub_source, + db_session, + source_name, + ) + await run_in_threadpool(mlrun.api.crud.Hub().remove_source, source_name) + + +@router.get( + path="/{source_name}", + response_model=mlrun.common.schemas.hub.IndexedHubSource, +) +async def get_source( + source_name: str, + db_session: Session = Depends(mlrun.api.api.deps.get_db_session), + auth_info: mlrun.common.schemas.AuthInfo = Depends( + mlrun.api.api.deps.authenticate_request + ), +): + hub_source = await run_in_threadpool( + mlrun.api.utils.singletons.db.get_db().get_hub_source, db_session, source_name + ) + await mlrun.api.utils.auth.verifier.AuthVerifier().query_global_resource_permissions( + mlrun.common.schemas.AuthorizationResourceTypes.hub_source, + mlrun.common.schemas.AuthorizationAction.read, + auth_info, + ) + + return hub_source + + +@router.put( + path="/{source_name}", + response_model=mlrun.common.schemas.hub.IndexedHubSource, +) +async def store_source( + source_name: str, + source: mlrun.common.schemas.hub.IndexedHubSource, + db_session: Session = Depends(mlrun.api.api.deps.get_db_session), + auth_info: mlrun.common.schemas.AuthInfo = Depends( + mlrun.api.api.deps.authenticate_request + ), +): + await mlrun.api.utils.auth.verifier.AuthVerifier().query_global_resource_permissions( + 
mlrun.common.schemas.AuthorizationResourceTypes.hub_source, + mlrun.common.schemas.AuthorizationAction.store, + auth_info, + ) + + await run_in_threadpool( + mlrun.api.utils.singletons.db.get_db().store_hub_source, + db_session, + source_name, + source, + ) + # Handle credentials if they exist + await run_in_threadpool(mlrun.api.crud.Hub().add_source, source.source) + + return await run_in_threadpool( + mlrun.api.utils.singletons.db.get_db().get_hub_source, db_session, source_name + ) + + +@router.get( + path="/{source_name}/items", + response_model=mlrun.common.schemas.hub.HubCatalog, +) +async def get_catalog( + source_name: str, + version: Optional[str] = Query(None), + tag: Optional[str] = Query(None), + force_refresh: Optional[bool] = Query(False, alias="force-refresh"), + db_session: Session = Depends(mlrun.api.api.deps.get_db_session), + auth_info: mlrun.common.schemas.AuthInfo = Depends( + mlrun.api.api.deps.authenticate_request + ), +): + ordered_source = await run_in_threadpool( + mlrun.api.utils.singletons.db.get_db().get_hub_source, db_session, source_name + ) + await mlrun.api.utils.auth.verifier.AuthVerifier().query_global_resource_permissions( + mlrun.common.schemas.AuthorizationResourceTypes.hub_source, + mlrun.common.schemas.AuthorizationAction.read, + auth_info, + ) + + return await run_in_threadpool( + mlrun.api.crud.Hub().get_source_catalog, + ordered_source.source, + version, + tag, + force_refresh, + ) + + +@router.get( + "/{source_name}/items/{item_name}", + response_model=mlrun.common.schemas.hub.HubItem, +) +async def get_item( + source_name: str, + item_name: str, + version: Optional[str] = Query(None), + tag: Optional[str] = Query("latest"), + force_refresh: Optional[bool] = Query(False, alias="force-refresh"), + db_session: Session = Depends(mlrun.api.api.deps.get_db_session), + auth_info: mlrun.common.schemas.AuthInfo = Depends( + mlrun.api.api.deps.authenticate_request + ), +): + ordered_source = await run_in_threadpool( + 
mlrun.api.utils.singletons.db.get_db().get_hub_source, db_session, source_name + ) + await mlrun.api.utils.auth.verifier.AuthVerifier().query_global_resource_permissions( + mlrun.common.schemas.AuthorizationResourceTypes.hub_source, + mlrun.common.schemas.AuthorizationAction.read, + auth_info, + ) + + return await run_in_threadpool( + mlrun.api.crud.Hub().get_item, + ordered_source.source, + item_name, + version, + tag, + force_refresh, + ) + + +@router.get( + "/{source_name}/item-object", +) +async def get_object( + source_name: str, + url: str, + db_session: Session = Depends(mlrun.api.api.deps.get_db_session), + auth_info: mlrun.common.schemas.AuthInfo = Depends( + mlrun.api.api.deps.authenticate_request + ), +): + ordered_source = await run_in_threadpool( + mlrun.api.utils.singletons.db.get_db().get_hub_source, db_session, source_name + ) + object_data = await run_in_threadpool( + mlrun.api.crud.Hub().get_item_object_using_source_credentials, + ordered_source.source, + url, + ) + await mlrun.api.utils.auth.verifier.AuthVerifier().query_global_resource_permissions( + mlrun.common.schemas.AuthorizationResourceTypes.hub_source, + mlrun.common.schemas.AuthorizationAction.read, + auth_info, + ) + + if url.endswith("/"): + return object_data + + ctype, _ = mimetypes.guess_type(url) + if not ctype: + ctype = "application/octet-stream" + return Response(content=object_data, media_type=ctype) + + +@router.get("/{source_name}/items/{item_name}/assets/{asset_name}") +async def get_asset( + source_name: str, + item_name: str, + asset_name: str, + tag: Optional[str] = Query("latest"), + version: Optional[str] = Query(None), + db_session: Session = Depends(mlrun.api.api.deps.get_db_session), + auth_info: mlrun.common.schemas.AuthInfo = Depends( + mlrun.api.api.deps.authenticate_request + ), +): + """ + Retrieve asset from a specific item in specific hub source. 
+ + :param source_name: hub source name + :param item_name: the name of the item + :param asset_name: the name of the asset to retrieve + :param tag: tag of item - latest or version number + :param version: item version + :param db_session: a session that manages the current dialog with the database + :param auth_info: the auth info of the request + + :return: fastapi response with the asset in content + """ + source = await run_in_threadpool( + mlrun.api.utils.singletons.db.get_db().get_hub_source, db_session, source_name + ) + + await mlrun.api.utils.auth.verifier.AuthVerifier().query_global_resource_permissions( + mlrun.common.schemas.AuthorizationResourceTypes.hub_source, + mlrun.common.schemas.AuthorizationAction.read, + auth_info, + ) + # Getting the relevant item which hold the asset information + item = await run_in_threadpool( + mlrun.api.crud.Hub().get_item, + source.source, + item_name, + version, + tag, + ) + + # Getting the asset from the item + asset, url = await run_in_threadpool( + mlrun.api.crud.Hub().get_asset, + source.source, + item, + asset_name, + ) + + ctype, _ = mimetypes.guess_type(url) + if not ctype: + ctype = "application/octet-stream" + return Response(content=asset, media_type=ctype) diff --git a/mlrun/api/api/endpoints/internal/memory_reports.py b/mlrun/api/api/endpoints/internal/memory_reports.py index fadc04a1e85d..9ba58c71f8b5 100644 --- a/mlrun/api/api/endpoints/internal/memory_reports.py +++ b/mlrun/api/api/endpoints/internal/memory_reports.py @@ -14,26 +14,26 @@ # import fastapi -import mlrun.api.schemas import mlrun.api.utils.memory_reports +import mlrun.common.schemas -router = fastapi.APIRouter() +router = fastapi.APIRouter(prefix="/memory-reports") @router.get( - "/memory-reports/common-types", - response_model=mlrun.api.schemas.MostCommonObjectTypesReport, + "/common-types", + response_model=mlrun.common.schemas.MostCommonObjectTypesReport, ) def get_most_common_objects_report(): report = ( 
mlrun.api.utils.memory_reports.MemoryUsageReport().create_most_common_objects_report() ) - return mlrun.api.schemas.MostCommonObjectTypesReport(object_types=report) + return mlrun.common.schemas.MostCommonObjectTypesReport(object_types=report) @router.get( - "/memory-reports/{object_type}", - response_model=mlrun.api.schemas.ObjectTypeReport, + "/{object_type}", + response_model=mlrun.common.schemas.ObjectTypeReport, ) def get_memory_usage_report( object_type: str, @@ -47,7 +47,7 @@ def get_memory_usage_report( object_type, sample_size, start_index, create_graph, max_depth ) ) - return mlrun.api.schemas.ObjectTypeReport( + return mlrun.common.schemas.ObjectTypeReport( object_type=object_type, sample_size=sample_size, start_index=start_index, diff --git a/mlrun/api/api/endpoints/logs.py b/mlrun/api/api/endpoints/logs.py index 24e45e0bedc8..3a0df343ac80 100644 --- a/mlrun/api/api/endpoints/logs.py +++ b/mlrun/api/api/endpoints/logs.py @@ -18,27 +18,27 @@ import mlrun.api.api.deps import mlrun.api.crud -import mlrun.api.schemas import mlrun.api.utils.auth.verifier +import mlrun.common.schemas -router = fastapi.APIRouter() +router = fastapi.APIRouter(prefix="/log/{project}") -@router.post("/log/{project}/{uid}") +@router.post("/{uid}") async def store_log( request: fastapi.Request, project: str, uid: str, append: bool = True, - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.log, + mlrun.common.schemas.AuthorizationResourceTypes.log, project, uid, - mlrun.api.schemas.AuthorizationAction.store, + mlrun.common.schemas.AuthorizationAction.store, auth_info, ) body = await request.body() @@ -52,13 +52,13 @@ async def store_log( return {} -@router.get("/log/{project}/{uid}") +@router.get("/{uid}") async def get_log( project: 
str, uid: str, size: int = -1, offset: int = 0, - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), db_session: sqlalchemy.orm.Session = fastapi.Depends( @@ -66,10 +66,10 @@ async def get_log( ), ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.log, + mlrun.common.schemas.AuthorizationResourceTypes.log, project, uid, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) run_state, log_stream = await mlrun.api.crud.Logs().get_logs( diff --git a/mlrun/api/api/endpoints/marketplace.py b/mlrun/api/api/endpoints/marketplace.py deleted file mode 100644 index 3118dd597537..000000000000 --- a/mlrun/api/api/endpoints/marketplace.py +++ /dev/null @@ -1,253 +0,0 @@ -# Copyright 2018 Iguazio -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -import mimetypes -from http import HTTPStatus -from typing import List, Optional - -from fastapi import APIRouter, Depends, Query, Response -from fastapi.concurrency import run_in_threadpool -from sqlalchemy.orm import Session - -import mlrun -import mlrun.api.api.deps -import mlrun.api.crud -import mlrun.api.utils.auth.verifier -from mlrun.api.schemas import AuthorizationAction -from mlrun.api.schemas.marketplace import ( - IndexedMarketplaceSource, - MarketplaceCatalog, - MarketplaceItem, -) -from mlrun.api.utils.singletons.db import get_db - -router = APIRouter() - - -@router.post( - path="/marketplace/sources", - status_code=HTTPStatus.CREATED.value, - response_model=IndexedMarketplaceSource, -) -async def create_source( - source: IndexedMarketplaceSource, - db_session: Session = Depends(mlrun.api.api.deps.get_db_session), - auth_info: mlrun.api.schemas.AuthInfo = Depends( - mlrun.api.api.deps.authenticate_request - ), -): - await mlrun.api.utils.auth.verifier.AuthVerifier().query_global_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.marketplace_source, - AuthorizationAction.create, - auth_info, - ) - - await run_in_threadpool(get_db().create_marketplace_source, db_session, source) - # Handle credentials if they exist - await run_in_threadpool(mlrun.api.crud.Marketplace().add_source, source.source) - return await run_in_threadpool( - get_db().get_marketplace_source, db_session, source.source.metadata.name - ) - - -@router.get( - path="/marketplace/sources", - response_model=List[IndexedMarketplaceSource], -) -async def list_sources( - db_session: Session = Depends(mlrun.api.api.deps.get_db_session), - auth_info: mlrun.api.schemas.AuthInfo = Depends( - mlrun.api.api.deps.authenticate_request - ), -): - await mlrun.api.utils.auth.verifier.AuthVerifier().query_global_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.marketplace_source, - AuthorizationAction.read, - auth_info, - ) - - return await 
run_in_threadpool(get_db().list_marketplace_sources, db_session) - - -@router.delete( - path="/marketplace/sources/{source_name}", - status_code=HTTPStatus.NO_CONTENT.value, -) -async def delete_source( - source_name: str, - db_session: Session = Depends(mlrun.api.api.deps.get_db_session), - auth_info: mlrun.api.schemas.AuthInfo = Depends( - mlrun.api.api.deps.authenticate_request - ), -): - await mlrun.api.utils.auth.verifier.AuthVerifier().query_global_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.marketplace_source, - AuthorizationAction.delete, - auth_info, - ) - - await run_in_threadpool(get_db().delete_marketplace_source, db_session, source_name) - await run_in_threadpool(mlrun.api.crud.Marketplace().remove_source, source_name) - - -@router.get( - path="/marketplace/sources/{source_name}", - response_model=IndexedMarketplaceSource, -) -async def get_source( - source_name: str, - db_session: Session = Depends(mlrun.api.api.deps.get_db_session), - auth_info: mlrun.api.schemas.AuthInfo = Depends( - mlrun.api.api.deps.authenticate_request - ), -): - marketplace_source = await run_in_threadpool( - get_db().get_marketplace_source, db_session, source_name - ) - await mlrun.api.utils.auth.verifier.AuthVerifier().query_global_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.marketplace_source, - AuthorizationAction.read, - auth_info, - ) - - return marketplace_source - - -@router.put( - path="/marketplace/sources/{source_name}", response_model=IndexedMarketplaceSource -) -async def store_source( - source_name: str, - source: IndexedMarketplaceSource, - db_session: Session = Depends(mlrun.api.api.deps.get_db_session), - auth_info: mlrun.api.schemas.AuthInfo = Depends( - mlrun.api.api.deps.authenticate_request - ), -): - await mlrun.api.utils.auth.verifier.AuthVerifier().query_global_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.marketplace_source, - AuthorizationAction.store, - auth_info, - ) - - await 
run_in_threadpool( - get_db().store_marketplace_source, db_session, source_name, source - ) - # Handle credentials if they exist - await run_in_threadpool(mlrun.api.crud.Marketplace().add_source, source.source) - - return await run_in_threadpool( - get_db().get_marketplace_source, db_session, source_name - ) - - -@router.get( - path="/marketplace/sources/{source_name}/items", - response_model=MarketplaceCatalog, -) -async def get_catalog( - source_name: str, - version: Optional[str] = Query(None), - tag: Optional[str] = Query(None), - force_refresh: Optional[bool] = Query(False, alias="force-refresh"), - db_session: Session = Depends(mlrun.api.api.deps.get_db_session), - auth_info: mlrun.api.schemas.AuthInfo = Depends( - mlrun.api.api.deps.authenticate_request - ), -): - ordered_source = await run_in_threadpool( - get_db().get_marketplace_source, db_session, source_name - ) - await mlrun.api.utils.auth.verifier.AuthVerifier().query_global_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.marketplace_source, - AuthorizationAction.read, - auth_info, - ) - - return await run_in_threadpool( - mlrun.api.crud.Marketplace().get_source_catalog, - ordered_source.source, - version, - tag, - force_refresh, - ) - - -@router.get( - "/marketplace/sources/{source_name}/items/{item_name}", - response_model=MarketplaceItem, -) -async def get_item( - source_name: str, - item_name: str, - version: Optional[str] = Query(None), - tag: Optional[str] = Query("latest"), - force_refresh: Optional[bool] = Query(False, alias="force-refresh"), - db_session: Session = Depends(mlrun.api.api.deps.get_db_session), - auth_info: mlrun.api.schemas.AuthInfo = Depends( - mlrun.api.api.deps.authenticate_request - ), -): - ordered_source = await run_in_threadpool( - get_db().get_marketplace_source, db_session, source_name - ) - await mlrun.api.utils.auth.verifier.AuthVerifier().query_global_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.marketplace_source, - 
AuthorizationAction.read, - auth_info, - ) - - return await run_in_threadpool( - mlrun.api.crud.Marketplace().get_item, - ordered_source.source, - item_name, - version, - tag, - force_refresh, - ) - - -@router.get( - "/marketplace/sources/{source_name}/item-object", -) -async def get_object( - source_name: str, - url: str, - db_session: Session = Depends(mlrun.api.api.deps.get_db_session), - auth_info: mlrun.api.schemas.AuthInfo = Depends( - mlrun.api.api.deps.authenticate_request - ), -): - ordered_source = await run_in_threadpool( - get_db().get_marketplace_source, db_session, source_name - ) - object_data = await run_in_threadpool( - mlrun.api.crud.Marketplace().get_item_object_using_source_credentials, - ordered_source.source, - url, - ) - await mlrun.api.utils.auth.verifier.AuthVerifier().query_global_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.marketplace_source, - AuthorizationAction.read, - auth_info, - ) - - if url.endswith("/"): - return object_data - - ctype, _ = mimetypes.guess_type(url) - if not ctype: - ctype = "application/octet-stream" - return Response(content=object_data, media_type=ctype) diff --git a/mlrun/api/api/endpoints/model_endpoints.py b/mlrun/api/api/endpoints/model_endpoints.py index 4171e8e91744..567c17815cc1 100644 --- a/mlrun/api/api/endpoints/model_endpoints.py +++ b/mlrun/api/api/endpoints/model_endpoints.py @@ -24,28 +24,28 @@ import mlrun.api.api.deps import mlrun.api.crud -import mlrun.api.schemas import mlrun.api.utils.auth.verifier +import mlrun.common.schemas from mlrun.errors import MLRunConflictError -router = APIRouter() +router = APIRouter(prefix="/projects/{project}/model-endpoints") @router.put( - "/projects/{project}/model-endpoints/{endpoint_id}", - status_code=HTTPStatus.NO_CONTENT.value, + "/{endpoint_id}", + response_model=mlrun.common.schemas.ModelEndpoint, ) async def create_or_patch( project: str, endpoint_id: str, - model_endpoint: mlrun.api.schemas.ModelEndpoint, - auth_info: 
mlrun.api.schemas.AuthInfo = Depends( + model_endpoint: mlrun.common.schemas.ModelEndpoint, + auth_info: mlrun.common.schemas.AuthInfo = Depends( mlrun.api.api.deps.authenticate_request ), db_session: Session = Depends(mlrun.api.api.deps.get_db_session), -): +) -> mlrun.common.schemas.ModelEndpoint: """ - Either create or updates the record of a given ModelEndpoint object. + Either create or update the record of a given `ModelEndpoint` object. Leaving here for backwards compatibility. """ @@ -57,10 +57,10 @@ async def create_or_patch( ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.model_endpoint, + mlrun.common.schemas.AuthorizationResourceTypes.model_endpoint, project, endpoint_id, - mlrun.api.schemas.AuthorizationAction.store, + mlrun.common.schemas.AuthorizationAction.store, auth_info, ) # get_access_key will validate the needed auth (which is used later) exists in the request @@ -76,7 +76,7 @@ async def create_or_patch( ) # Since the endpoint records are created automatically, at point of serving function deployment, we need to use # V3IO_ACCESS_KEY here - await run_in_threadpool( + return await run_in_threadpool( mlrun.api.crud.ModelEndpoints().create_or_patch, db_session=db_session, access_key=os.environ.get("V3IO_ACCESS_KEY"), @@ -86,20 +86,20 @@ async def create_or_patch( @router.post( - "/projects/{project}/model-endpoints/{endpoint_id}", - response_model=mlrun.api.schemas.ModelEndpoint, + "/{endpoint_id}", + response_model=mlrun.common.schemas.ModelEndpoint, ) async def create_model_endpoint( project: str, endpoint_id: str, - model_endpoint: mlrun.api.schemas.ModelEndpoint, - auth_info: mlrun.api.schemas.AuthInfo = Depends( + model_endpoint: mlrun.common.schemas.ModelEndpoint, + auth_info: mlrun.common.schemas.AuthInfo = Depends( mlrun.api.api.deps.authenticate_request ), db_session: Session = Depends(mlrun.api.api.deps.get_db_session), -) -> 
mlrun.api.schemas.ModelEndpoint: +) -> mlrun.common.schemas.ModelEndpoint: """ - Create a DB record of a given ModelEndpoint object. + Create a DB record of a given `ModelEndpoint` object. :param project: The name of the project. :param endpoint_id: The unique id of the model endpoint. @@ -111,11 +111,12 @@ async def create_model_endpoint( :return: A Model endpoint object. """ + await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - resource_type=mlrun.api.schemas.AuthorizationResourceTypes.model_endpoint, + resource_type=mlrun.common.schemas.AuthorizationResourceTypes.model_endpoint, project_name=project, resource_name=endpoint_id, - action=mlrun.api.schemas.AuthorizationAction.store, + action=mlrun.common.schemas.AuthorizationAction.store, auth_info=auth_info, ) @@ -137,26 +138,25 @@ async def create_model_endpoint( @router.patch( - "/projects/{project}/model-endpoints/{endpoint_id}", - response_model=mlrun.api.schemas.ModelEndpoint, + "/{endpoint_id}", + response_model=mlrun.common.schemas.ModelEndpoint, ) async def patch_model_endpoint( project: str, endpoint_id: str, attributes: str = None, - auth_info: mlrun.api.schemas.AuthInfo = Depends( + auth_info: mlrun.common.schemas.AuthInfo = Depends( mlrun.api.api.deps.authenticate_request ), -) -> mlrun.api.schemas.ModelEndpoint: +) -> mlrun.common.schemas.ModelEndpoint: """ - Update a DB record of a given ModelEndpoint object. + Update a DB record of a given `ModelEndpoint` object. :param project: The name of the project. :param endpoint_id: The unique id of the model endpoint. :param attributes: Attributes that will be updated. The input is provided in a json structure that will be converted into a dictionary before applying the patch process. Note that the keys of - dictionary should exist in the DB target. More details about the model endpoint available - attributes can be found under :py:class:`~mlrun.api.schemas.ModelEndpoint`. + the dictionary should exist in the DB target. 
example:: @@ -168,10 +168,10 @@ async def patch_model_endpoint( """ await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - resource_type=mlrun.api.schemas.AuthorizationResourceTypes.model_endpoint, + resource_type=mlrun.common.schemas.AuthorizationResourceTypes.model_endpoint, project_name=project, resource_name=endpoint_id, - action=mlrun.api.schemas.AuthorizationAction.update, + action=mlrun.common.schemas.AuthorizationAction.update, auth_info=auth_info, ) @@ -188,13 +188,13 @@ async def patch_model_endpoint( @router.delete( - "/projects/{project}/model-endpoints/{endpoint_id}", + "/{endpoint_id}", status_code=HTTPStatus.NO_CONTENT.value, ) async def delete_model_endpoint( project: str, endpoint_id: str, - auth_info: mlrun.api.schemas.AuthInfo = Depends( + auth_info: mlrun.common.schemas.AuthInfo = Depends( mlrun.api.api.deps.authenticate_request ), ): @@ -208,10 +208,10 @@ async def delete_model_endpoint( """ await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - resource_type=mlrun.api.schemas.AuthorizationResourceTypes.model_endpoint, + resource_type=mlrun.common.schemas.AuthorizationResourceTypes.model_endpoint, project_name=project, resource_name=endpoint_id, - action=mlrun.api.schemas.AuthorizationAction.delete, + action=mlrun.common.schemas.AuthorizationAction.delete, auth_info=auth_info, ) @@ -223,8 +223,8 @@ async def delete_model_endpoint( @router.get( - "/projects/{project}/model-endpoints", - response_model=mlrun.api.schemas.ModelEndpointList, + "", + response_model=mlrun.common.schemas.ModelEndpointList, ) async def list_model_endpoints( project: str, @@ -236,16 +236,16 @@ async def list_model_endpoints( metrics: List[str] = Query([], alias="metric"), top_level: bool = Query(False, alias="top-level"), uids: List[str] = Query(None, alias="uid"), - auth_info: mlrun.api.schemas.AuthInfo = Depends( + auth_info: mlrun.common.schemas.AuthInfo = Depends( 
mlrun.api.api.deps.authenticate_request ), -) -> mlrun.api.schemas.ModelEndpointList: +) -> mlrun.common.schemas.ModelEndpointList: """ Returns a list of endpoints of type 'ModelEndpoint', supports filtering by model, function, tag, labels or top level. By default, when no filters are applied, all available endpoints for the given project will be listed. - If uids are passed: will return ModelEndpointList of endpoints with uid in uids + If uids are passed: will return `ModelEndpointList` of endpoints with uid in uids Labels can be used to filter on the existence of a label: api/projects/{project}/model-endpoints/?label=mylabel @@ -264,11 +264,11 @@ async def list_model_endpoints( :param model: The name of the model to filter by. :param function: The name of the function to filter by. :param labels: A list of labels to filter by. Label filters work by either filtering a specific value of a label - (i.e. list("key==value")) or by looking for the existence of a given key (i.e. "key"). - :param metrics: A list of metrics to return for each endpoint. There are pre-defined metrics for model endpoints - such as predictions_per_second and latency_avg_5m but also custom metrics defined by the user. - Please note that these metrics are stored in the time series DB and the results will be appeared - under model_endpoint.spec.metrics of each endpoint. + (i.e. list("key=value")) or by looking for the existence of a given key (i.e. "key"). + :param metrics: A list of real-time metrics to return for each endpoint. There are pre-defined real-time metrics + for model endpoints such as predictions_per_second and latency_avg_5m but also custom metrics + defined by the user. Please note that these metrics are stored in the time series DB and the + results will be appeared under model_endpoint.spec.metrics of each endpoint. :param start: The start time of the metrics. 
Can be represented by a string containing an RFC 3339 time, a Unix timestamp in milliseconds, a relative time (`'now'` or `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the earliest time. @@ -276,15 +276,15 @@ async def list_model_endpoints( time, a Unix timestamp in milliseconds, a relative time (`'now'` or `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the earliest time. :param top_level: If True will return only routers and endpoint that are NOT children of any router. - :param uids: Will return ModelEndpointList of endpoints with uid in uids. + :param uids: Will return `ModelEndpointList` of endpoints with uid in uids. - :return: An object of ModelEndpointList which is literally a list of model endpoints along with some metadata. To + :return: An object of `ModelEndpointList` which is literally a list of model endpoints along with some metadata. To get a standard list of model endpoints use ModelEndpointList.endpoints. 
""" await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( project_name=project, - action=mlrun.api.schemas.AuthorizationAction.read, + action=mlrun.common.schemas.AuthorizationAction.read, auth_info=auth_info, ) @@ -302,7 +302,7 @@ async def list_model_endpoints( uids=uids, ) allowed_endpoints = await mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.model_endpoint, + mlrun.common.schemas.AuthorizationResourceTypes.model_endpoint, endpoints.endpoints, lambda _endpoint: ( _endpoint.metadata.project, @@ -316,8 +316,8 @@ async def list_model_endpoints( @router.get( - "/projects/{project}/model-endpoints/{endpoint_id}", - response_model=mlrun.api.schemas.ModelEndpoint, + "/{endpoint_id}", + response_model=mlrun.common.schemas.ModelEndpoint, ) async def get_model_endpoint( project: str, @@ -326,36 +326,40 @@ async def get_model_endpoint( end: str = Query(default="now"), metrics: List[str] = Query([], alias="metric"), feature_analysis: bool = Query(default=False), - auth_info: mlrun.api.schemas.AuthInfo = Depends( + auth_info: mlrun.common.schemas.AuthInfo = Depends( mlrun.api.api.deps.authenticate_request ), -) -> mlrun.api.schemas.ModelEndpoint: +) -> mlrun.common.schemas.ModelEndpoint: """Get a single model endpoint object. You can apply different time series metrics that will be added to the result. - :param project: The name of the project. - :param endpoint_id: The unique id of the model endpoint. - :param start: The start time of the metrics. Can be represented by a string containing an RFC 3339 - time, a Unix timestamp in milliseconds, a relative time (`'now'` or `'now-[0-9]+[mhd]'`, - where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the earliest time. - :param end: The end time of the metrics. 
Can be represented by a string containing an RFC 3339 - time, a Unix timestamp in milliseconds, a relative time (`'now'` or `'now-[0-9]+[mhd]'`, - where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the earliest time. - :param metrics: A list of metrics to return for the model endpoint. There are pre-defined metrics for model - endpoints such as predictions_per_second and latency_avg_5m but also custom metrics - defined by the user. Please note that these metrics are stored in the time series DB and - the results will be appeared under model_endpoint.spec.metrics. - :param feature_analysis: When True, the base feature statistics and current feature statistics will be added to - the output of the resulting object. - :param auth_info: The auth info of the request. - - :return: A ModelEndpoint object. + + :param project: The name of the project + :param endpoint_id: The unique id of the model endpoint. + :param start: The start time of the metrics. Can be represented by a string containing an + RFC 3339 time, a Unix timestamp in milliseconds, a relative time (`'now'` or + `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or + 0 for the earliest time. + :param end: The end time of the metrics. Can be represented by a string containing an + RFC 3339 time, a Unix timestamp in milliseconds, a relative time (`'now'` or + `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or + 0 for the earliest time. + :param metrics: A list of real-time metrics to return for the model endpoint. There are + pre-defined real-time metrics for model endpoints such as predictions_per_second + and latency_avg_5m but also custom metrics defined by the user. Please note that + these metrics are stored in the time series DB and the results will be + appeared under model_endpoint.spec.metrics. + :param feature_analysis: When True, the base feature statistics and current feature statistics will + be added to the output of the resulting object. 
+ :param auth_info: The auth info of the request + + :return: A `ModelEndpoint` object. """ await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.model_endpoint, + mlrun.common.schemas.AuthorizationResourceTypes.model_endpoint, project, endpoint_id, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) diff --git a/mlrun/api/api/endpoints/operations.py b/mlrun/api/api/endpoints/operations.py index 6751478cb485..527e99543a73 100644 --- a/mlrun/api/api/endpoints/operations.py +++ b/mlrun/api/api/endpoints/operations.py @@ -21,9 +21,9 @@ import mlrun.api.api.deps import mlrun.api.crud import mlrun.api.initial_data -import mlrun.api.schemas import mlrun.api.utils.background_tasks import mlrun.api.utils.clients.chief +import mlrun.common.schemas from mlrun.utils import logger router = fastapi.APIRouter() @@ -36,7 +36,7 @@ "/operations/migrations", responses={ http.HTTPStatus.OK.value: {}, - http.HTTPStatus.ACCEPTED.value: {"model": mlrun.api.schemas.BackgroundTask}, + http.HTTPStatus.ACCEPTED.value: {"model": mlrun.common.schemas.BackgroundTask}, }, ) async def trigger_migrations( @@ -47,7 +47,7 @@ async def trigger_migrations( # only chief can execute migrations, redirecting request to chief if ( mlrun.mlconf.httpdb.clusterization.role - != mlrun.api.schemas.ClusterizationRole.chief + != mlrun.common.schemas.ClusterizationRole.chief ): logger.info("Requesting to trigger migrations, re-routing to chief") chief_client = mlrun.api.utils.clients.chief.Client() @@ -72,18 +72,22 @@ async def trigger_migrations( def _get_or_create_migration_background_task( task_name: str, background_tasks -) -> typing.Optional[mlrun.api.schemas.BackgroundTask]: - if mlrun.mlconf.httpdb.state == mlrun.api.schemas.APIStates.migrations_in_progress: +) -> typing.Optional[mlrun.common.schemas.BackgroundTask]: + if ( + mlrun.mlconf.httpdb.state + == 
mlrun.common.schemas.APIStates.migrations_in_progress + ): background_task = mlrun.api.utils.background_tasks.InternalBackgroundTasksHandler().get_background_task( task_name ) return background_task - elif mlrun.mlconf.httpdb.state == mlrun.api.schemas.APIStates.migrations_failed: + elif mlrun.mlconf.httpdb.state == mlrun.common.schemas.APIStates.migrations_failed: raise mlrun.errors.MLRunPreconditionFailedError( "Migrations were already triggered and failed. Restart the API to retry" ) elif ( - mlrun.mlconf.httpdb.state != mlrun.api.schemas.APIStates.waiting_for_migrations + mlrun.mlconf.httpdb.state + != mlrun.common.schemas.APIStates.waiting_for_migrations ): return None @@ -102,4 +106,4 @@ async def _perform_migration(): mlrun.api.initial_data.init_data, perform_migrations_if_needed=True ) await mlrun.api.main.move_api_to_online() - mlrun.mlconf.httpdb.state = mlrun.api.schemas.APIStates.online + mlrun.mlconf.httpdb.state = mlrun.common.schemas.APIStates.online diff --git a/mlrun/api/api/endpoints/pipelines.py b/mlrun/api/api/endpoints/pipelines.py index fbeefe3a946a..c210e895ccf6 100644 --- a/mlrun/api/api/endpoints/pipelines.py +++ b/mlrun/api/api/endpoints/pipelines.py @@ -23,32 +23,30 @@ from sqlalchemy.orm import Session import mlrun.api.crud -import mlrun.api.schemas import mlrun.api.utils.auth.verifier +import mlrun.api.utils.singletons.k8s +import mlrun.common.schemas import mlrun.errors from mlrun.api.api import deps from mlrun.api.api.utils import log_and_raise from mlrun.config import config -from mlrun.k8s_utils import get_k8s_helper from mlrun.utils import logger -router = APIRouter() +router = APIRouter(prefix="/projects/{project}/pipelines") -@router.get( - "/projects/{project}/pipelines", response_model=mlrun.api.schemas.PipelinesOutput -) +@router.get("", response_model=mlrun.common.schemas.PipelinesOutput) async def list_pipelines( project: str, namespace: str = None, sort_by: str = "", page_token: str = "", filter_: str = Query("", 
alias="filter"), - format_: mlrun.api.schemas.PipelinesFormat = Query( - mlrun.api.schemas.PipelinesFormat.metadata_only, alias="format" + format_: mlrun.common.schemas.PipelinesFormat = Query( + mlrun.common.schemas.PipelinesFormat.metadata_only, alias="format" ), page_size: int = Query(None, gt=0, le=200), - auth_info: mlrun.api.schemas.AuthInfo = Depends( + auth_info: mlrun.common.schemas.AuthInfo = Depends( mlrun.api.api.deps.authenticate_request ), db_session: Session = Depends(deps.get_db_session), @@ -58,16 +56,18 @@ async def list_pipelines( if project != "*": await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( project, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) total_size, next_page_token, runs = None, None, [] - if get_k8s_helper(silent=True).is_running_inside_kubernetes_cluster(): + if mlrun.api.utils.singletons.k8s.get_k8s_helper( + silent=True + ).is_running_inside_kubernetes_cluster(): # we need to resolve the project from the returned run for the opa enforcement (project query param might be # "*"), so we can't really get back only the names here computed_format = ( - mlrun.api.schemas.PipelinesFormat.metadata_only - if format_ == mlrun.api.schemas.PipelinesFormat.name_only + mlrun.common.schemas.PipelinesFormat.metadata_only + if format_ == mlrun.common.schemas.PipelinesFormat.name_only else format_ ) total_size, next_page_token, runs = await run_in_threadpool( @@ -82,7 +82,7 @@ async def list_pipelines( page_size, ) allowed_runs = await mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.pipeline, + mlrun.common.schemas.AuthorizationResourceTypes.pipeline, runs, lambda run: ( run["project"], @@ -90,23 +90,23 @@ async def list_pipelines( ), auth_info, ) - if format_ == mlrun.api.schemas.PipelinesFormat.name_only: + if format_ == mlrun.common.schemas.PipelinesFormat.name_only: 
allowed_runs = [run["name"] for run in allowed_runs] - return mlrun.api.schemas.PipelinesOutput( + return mlrun.common.schemas.PipelinesOutput( runs=allowed_runs, total_size=total_size or 0, next_page_token=next_page_token or None, ) -@router.post("/projects/{project}/pipelines") +@router.post("") async def create_pipeline( project: str, request: Request, namespace: str = None, experiment_name: str = Query("Default", alias="experiment"), run_name: str = Query("", alias="run"), - auth_info: mlrun.api.schemas.AuthInfo = Depends( + auth_info: mlrun.common.schemas.AuthInfo = Depends( mlrun.api.api.deps.authenticate_request ), ): @@ -118,8 +118,76 @@ async def create_pipeline( return response +@router.get("/{run_id}") +async def get_pipeline( + run_id: str, + project: str, + namespace: str = Query(config.namespace), + format_: mlrun.common.schemas.PipelinesFormat = Query( + mlrun.common.schemas.PipelinesFormat.summary, alias="format" + ), + auth_info: mlrun.common.schemas.AuthInfo = Depends( + mlrun.api.api.deps.authenticate_request + ), + db_session: Session = Depends(deps.get_db_session), +): + pipeline = await run_in_threadpool( + mlrun.api.crud.Pipelines().get_pipeline, + db_session, + run_id, + project, + namespace, + format_, + ) + if project == "*": + # In some flows the user may use SDK functions that won't require them to specify the pipeline's project (for + # backwards compatibility reasons), so the client will just send * in the project, in that case we use the + # legacy flow in which we first get the pipeline, resolve the project out of it, and only then query permissions + # we don't use the return value from this function since the user may have asked for a different format than + # summary which is the one used inside + await _get_pipeline_without_project(db_session, auth_info, run_id, namespace) + else: + await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( + mlrun.common.schemas.AuthorizationResourceTypes.pipeline, + 
project, + run_id, + mlrun.common.schemas.AuthorizationAction.read, + auth_info, + ) + return pipeline + + +async def _get_pipeline_without_project( + db_session: Session, + auth_info: mlrun.common.schemas.AuthInfo, + run_id: str, + namespace: str, +): + """ + This function is for when we receive a get pipeline request without the client specifying the project + So we first get the pipeline, resolve the project out of it, and now that we know the project, we can verify + permissions + """ + run = await run_in_threadpool( + mlrun.api.crud.Pipelines().get_pipeline, + db_session, + run_id, + namespace=namespace, + # minimal format that includes the project + format_=mlrun.common.schemas.PipelinesFormat.summary, + ) + await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( + mlrun.common.schemas.AuthorizationResourceTypes.pipeline, + run["run"]["project"], + run["run"]["id"], + mlrun.common.schemas.AuthorizationAction.read, + auth_info, + ) + return run + + async def _create_pipeline( - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, request: Request, namespace: str, experiment_name: str, @@ -129,10 +197,10 @@ async def _create_pipeline( # If we have the project (new clients from 0.7.0 uses the new endpoint in which it's mandatory) - check auth now if project: await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.pipeline, + mlrun.common.schemas.AuthorizationResourceTypes.pipeline, project, "", - mlrun.api.schemas.AuthorizationAction.create, + mlrun.common.schemas.AuthorizationAction.create, auth_info, ) run_name = run_name or experiment_name + " " + datetime.now().strftime( @@ -153,16 +221,16 @@ async def _create_pipeline( ) else: await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.pipeline, + 
mlrun.common.schemas.AuthorizationResourceTypes.pipeline, project, "", - mlrun.api.schemas.AuthorizationAction.create, + mlrun.common.schemas.AuthorizationAction.create, auth_info, ) arguments = {} arguments_data = request.headers.get( - mlrun.api.schemas.HeaderNames.pipeline_arguments + mlrun.common.schemas.HeaderNames.pipeline_arguments ) if arguments_data: arguments = ast.literal_eval(arguments_data) @@ -196,66 +264,3 @@ def _try_resolve_project_from_body( return mlrun.api.crud.Pipelines().resolve_project_from_workflow_manifest( workflow_manifest ) - - -@router.get("/projects/{project}/pipelines/{run_id}") -async def get_pipeline( - run_id: str, - project: str, - namespace: str = Query(config.namespace), - format_: mlrun.api.schemas.PipelinesFormat = Query( - mlrun.api.schemas.PipelinesFormat.summary, alias="format" - ), - auth_info: mlrun.api.schemas.AuthInfo = Depends( - mlrun.api.api.deps.authenticate_request - ), - db_session: Session = Depends(deps.get_db_session), -): - pipeline = mlrun.api.crud.Pipelines().get_pipeline( - db_session, run_id, project, namespace, format_ - ) - if project == "*": - # In some flows the user may use SDK functions that won't require them to specify the pipeline's project (for - # backwards compatibility reasons), so the client will just send * in the project, in that case we use the - # legacy flow in which we first get the pipeline, resolve the project out of it, and only then query permissions - # we don't use the return value from this function since the user may have asked for a different format than - # summary which is the one used inside - await _get_pipeline_without_project(db_session, auth_info, run_id, namespace) - else: - await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.pipeline, - project, - run_id, - mlrun.api.schemas.AuthorizationAction.read, - auth_info, - ) - return pipeline - - -async def _get_pipeline_without_project( - 
db_session: Session, - auth_info: mlrun.api.schemas.AuthInfo, - run_id: str, - namespace: str, -): - """ - This function is for when we receive a get pipeline request without the client specifying the project - So we first get the pipeline, resolve the project out of it, and now that we know the project, we can verify - permissions - """ - run = await run_in_threadpool( - mlrun.api.crud.Pipelines().get_pipeline, - db_session, - run_id, - namespace=namespace, - # minimal format that includes the project - format_=mlrun.api.schemas.PipelinesFormat.summary, - ) - await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.pipeline, - run["run"]["project"], - run["run"]["id"], - mlrun.api.schemas.AuthorizationAction.read, - auth_info, - ) - return run diff --git a/mlrun/api/api/endpoints/projects.py b/mlrun/api/api/endpoints/projects.py index 2987bcba9251..e4a950a4100d 100644 --- a/mlrun/api/api/endpoints/projects.py +++ b/mlrun/api/api/endpoints/projects.py @@ -20,9 +20,9 @@ from fastapi.concurrency import run_in_threadpool import mlrun.api.api.deps -import mlrun.api.schemas import mlrun.api.utils.auth.verifier import mlrun.api.utils.clients.chief +import mlrun.common.schemas from mlrun.api.utils.singletons.project_member import get_project_member from mlrun.utils import logger @@ -32,17 +32,17 @@ @router.post( "/projects", responses={ - http.HTTPStatus.CREATED.value: {"model": mlrun.api.schemas.Project}, + http.HTTPStatus.CREATED.value: {"model": mlrun.common.schemas.Project}, http.HTTPStatus.ACCEPTED.value: {}, }, ) def create_project( - project: mlrun.api.schemas.Project, + project: mlrun.common.schemas.Project, response: fastapi.Response, # TODO: we're in a http request context here, therefore it doesn't make sense that by default it will hold the # request until the process will be completed - after UI supports waiting - change default to False wait_for_completion: bool = 
fastapi.Query(True, alias="wait-for-completion"), - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), db_session: sqlalchemy.orm.Session = fastapi.Depends( @@ -65,17 +65,17 @@ def create_project( @router.put( "/projects/{name}", responses={ - http.HTTPStatus.OK.value: {"model": mlrun.api.schemas.Project}, + http.HTTPStatus.OK.value: {"model": mlrun.common.schemas.Project}, http.HTTPStatus.ACCEPTED.value: {}, }, ) def store_project( - project: mlrun.api.schemas.Project, + project: mlrun.common.schemas.Project, name: str, # TODO: we're in a http request context here, therefore it doesn't make sense that by default it will hold the # request until the process will be completed - after UI supports waiting - change default to False wait_for_completion: bool = fastapi.Query(True, alias="wait-for-completion"), - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), db_session: sqlalchemy.orm.Session = fastapi.Depends( @@ -98,21 +98,21 @@ def store_project( @router.patch( "/projects/{name}", responses={ - http.HTTPStatus.OK.value: {"model": mlrun.api.schemas.Project}, + http.HTTPStatus.OK.value: {"model": mlrun.common.schemas.Project}, http.HTTPStatus.ACCEPTED.value: {}, }, ) def patch_project( project: dict, name: str, - patch_mode: mlrun.api.schemas.PatchMode = fastapi.Header( - mlrun.api.schemas.PatchMode.replace, - alias=mlrun.api.schemas.HeaderNames.patch_mode, + patch_mode: mlrun.common.schemas.PatchMode = fastapi.Header( + mlrun.common.schemas.PatchMode.replace, + alias=mlrun.common.schemas.HeaderNames.patch_mode, ), # TODO: we're in a http request context here, therefore it doesn't make sense that by default it will hold the # request until the process will be completed - after UI supports waiting - change default to False wait_for_completion: bool 
= fastapi.Query(True, alias="wait-for-completion"), - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), db_session: sqlalchemy.orm.Session = fastapi.Depends( @@ -133,13 +133,13 @@ def patch_project( return project -@router.get("/projects/{name}", response_model=mlrun.api.schemas.Project) +@router.get("/projects/{name}", response_model=mlrun.common.schemas.Project) async def get_project( name: str, db_session: sqlalchemy.orm.Session = fastapi.Depends( mlrun.api.api.deps.get_db_session ), - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), ): @@ -150,7 +150,7 @@ async def get_project( if not _is_request_from_leader(auth_info.projects_role): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( name, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) return project @@ -166,14 +166,14 @@ async def get_project( async def delete_project( name: str, request: fastapi.Request, - deletion_strategy: mlrun.api.schemas.DeletionStrategy = fastapi.Header( - mlrun.api.schemas.DeletionStrategy.default(), - alias=mlrun.api.schemas.HeaderNames.deletion_strategy, + deletion_strategy: mlrun.common.schemas.DeletionStrategy = fastapi.Header( + mlrun.common.schemas.DeletionStrategy.default(), + alias=mlrun.common.schemas.HeaderNames.deletion_strategy, ), # TODO: we're in a http request context here, therefore it doesn't make sense that by default it will hold the # request until the process will be completed - after UI supports waiting - change default to False wait_for_completion: bool = fastapi.Query(True, alias="wait-for-completion"), - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( 
mlrun.api.api.deps.authenticate_request ), db_session: sqlalchemy.orm.Session = fastapi.Depends( @@ -184,7 +184,7 @@ async def delete_project( # that is why we re-route requests to chief if ( mlrun.mlconf.httpdb.clusterization.role - != mlrun.api.schemas.ClusterizationRole.chief + != mlrun.common.schemas.ClusterizationRole.chief ): logger.info( "Requesting to delete project, re-routing to chief", @@ -209,15 +209,15 @@ async def delete_project( return fastapi.Response(status_code=http.HTTPStatus.NO_CONTENT.value) -@router.get("/projects", response_model=mlrun.api.schemas.ProjectsOutput) +@router.get("/projects", response_model=mlrun.common.schemas.ProjectsOutput) async def list_projects( - format_: mlrun.api.schemas.ProjectsFormat = fastapi.Query( - mlrun.api.schemas.ProjectsFormat.full, alias="format" + format_: mlrun.common.schemas.ProjectsFormat = fastapi.Query( + mlrun.common.schemas.ProjectsFormat.full, alias="format" ), owner: str = None, labels: typing.List[str] = fastapi.Query(None, alias="label"), - state: mlrun.api.schemas.ProjectState = None, - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + state: mlrun.common.schemas.ProjectState = None, + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), db_session: sqlalchemy.orm.Session = fastapi.Depends( @@ -231,7 +231,7 @@ async def list_projects( get_project_member().list_projects, db_session, owner, - mlrun.api.schemas.ProjectsFormat.name_only, + mlrun.common.schemas.ProjectsFormat.name_only, labels, state, auth_info.projects_role, @@ -257,13 +257,13 @@ async def list_projects( @router.get( - "/project-summaries", response_model=mlrun.api.schemas.ProjectSummariesOutput + "/project-summaries", response_model=mlrun.common.schemas.ProjectSummariesOutput ) async def list_project_summaries( owner: str = None, labels: typing.List[str] = fastapi.Query(None, alias="label"), - state: mlrun.api.schemas.ProjectState = None, - auth_info: 
mlrun.api.schemas.AuthInfo = fastapi.Depends( + state: mlrun.common.schemas.ProjectState = None, + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), db_session: sqlalchemy.orm.Session = fastapi.Depends( @@ -274,7 +274,7 @@ async def list_project_summaries( get_project_member().list_projects, db_session, owner, - mlrun.api.schemas.ProjectsFormat.name_only, + mlrun.common.schemas.ProjectsFormat.name_only, labels, state, auth_info.projects_role, @@ -299,14 +299,14 @@ async def list_project_summaries( @router.get( - "/project-summaries/{name}", response_model=mlrun.api.schemas.ProjectSummary + "/project-summaries/{name}", response_model=mlrun.common.schemas.ProjectSummary ) async def get_project_summary( name: str, db_session: sqlalchemy.orm.Session = fastapi.Depends( mlrun.api.api.deps.get_db_session ), - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), ): @@ -317,14 +317,14 @@ async def get_project_summary( if not _is_request_from_leader(auth_info.projects_role): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( name, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) return project_summary def _is_request_from_leader( - projects_role: typing.Optional[mlrun.api.schemas.ProjectsRole], + projects_role: typing.Optional[mlrun.common.schemas.ProjectsRole], ) -> bool: if projects_role and projects_role.value == mlrun.mlconf.httpdb.projects.leader: return True diff --git a/mlrun/api/api/endpoints/runs.py b/mlrun/api/api/endpoints/runs.py index e3f0c1111b99..2032bacf52c8 100644 --- a/mlrun/api/api/endpoints/runs.py +++ b/mlrun/api/api/endpoints/runs.py @@ -16,17 +16,16 @@ from http import HTTPStatus from typing import List -from fastapi import APIRouter, Depends, Query, Request +from fastapi import APIRouter, Body, Depends, 
Query, Request, Response from fastapi.concurrency import run_in_threadpool from sqlalchemy.orm import Session import mlrun.api.crud -import mlrun.api.schemas import mlrun.api.utils.auth.verifier import mlrun.api.utils.singletons.project_member +import mlrun.common.schemas from mlrun.api.api import deps from mlrun.api.api.utils import log_and_raise -from mlrun.utils import logger from mlrun.utils.helpers import datetime_from_iso router = APIRouter() @@ -38,7 +37,7 @@ async def store_run( project: str, uid: str, iter: int = 0, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await run_in_threadpool( @@ -48,10 +47,10 @@ async def store_run( auth_info=auth_info, ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.run, + mlrun.common.schemas.AuthorizationResourceTypes.run, project, uid, - mlrun.api.schemas.AuthorizationAction.store, + mlrun.common.schemas.AuthorizationAction.store, auth_info, ) data = None @@ -60,7 +59,6 @@ async def store_run( except ValueError: log_and_raise(HTTPStatus.BAD_REQUEST.value, reason="bad JSON body") - logger.info("Storing run", data=data) await run_in_threadpool( mlrun.api.crud.Runs().store_run, db_session, @@ -78,14 +76,14 @@ async def update_run( project: str, uid: str, iter: int = 0, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.run, + mlrun.common.schemas.AuthorizationResourceTypes.run, project, uid, - mlrun.api.schemas.AuthorizationAction.update, + 
mlrun.common.schemas.AuthorizationAction.update, auth_info, ) data = None @@ -110,17 +108,17 @@ async def get_run( project: str, uid: str, iter: int = 0, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): data = await run_in_threadpool( mlrun.api.crud.Runs().get_run, db_session, uid, iter, project ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.run, + mlrun.common.schemas.AuthorizationResourceTypes.run, project, uid, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) return { @@ -133,14 +131,14 @@ async def delete_run( project: str, uid: str, iter: int = 0, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.run, + mlrun.common.schemas.AuthorizationResourceTypes.run, project, uid, - mlrun.api.schemas.AuthorizationAction.delete, + mlrun.common.schemas.AuthorizationAction.delete, auth_info, ) await run_in_threadpool( @@ -167,25 +165,25 @@ async def list_runs( start_time_to: str = None, last_update_time_from: str = None, last_update_time_to: str = None, - partition_by: mlrun.api.schemas.RunPartitionByField = Query( + partition_by: mlrun.common.schemas.RunPartitionByField = Query( None, alias="partition-by" ), rows_per_partition: int = Query(1, alias="rows-per-partition", gt=0), - partition_sort_by: mlrun.api.schemas.SortField = Query( + partition_sort_by: mlrun.common.schemas.SortField = Query( None, alias="partition-sort-by" ), - partition_order: mlrun.api.schemas.OrderType 
= Query( - mlrun.api.schemas.OrderType.desc, alias="partition-order" + partition_order: mlrun.common.schemas.OrderType = Query( + mlrun.common.schemas.OrderType.desc, alias="partition-order" ), max_partitions: int = Query(0, alias="max-partitions", ge=0), with_notifications: bool = Query(False, alias="with-notifications"), - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): if project != "*": await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( project, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) runs = await run_in_threadpool( @@ -211,7 +209,7 @@ async def list_runs( with_notifications=with_notifications, ) filtered_runs = await mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.run, + mlrun.common.schemas.AuthorizationResourceTypes.run, runs, lambda run: ( run.get("metadata", {}).get("project", mlrun.mlconf.default_project), @@ -231,7 +229,7 @@ async def delete_runs( labels: List[str] = Query([], alias="label"), state: str = None, days_ago: int = None, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): if not project or project != "*": @@ -239,10 +237,10 @@ async def delete_runs( # Meaning there is no reason at the moment to query the permission for each run under the project # TODO check for every run when we will manage permission per run inside a project await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.run, + mlrun.common.schemas.AuthorizationResourceTypes.run, project or 
mlrun.mlconf.default_project, "", - mlrun.api.schemas.AuthorizationAction.delete, + mlrun.common.schemas.AuthorizationAction.delete, auth_info, ) else: @@ -268,10 +266,10 @@ async def delete_runs( # currently we fail if the user doesn't has permissions to delete runs to one of the projects in the system # TODO Delete only runs from projects that user has permissions to await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.run, + mlrun.common.schemas.AuthorizationResourceTypes.run, run_project, "", - mlrun.api.schemas.AuthorizationAction.delete, + mlrun.common.schemas.AuthorizationAction.delete, auth_info, ) @@ -285,3 +283,43 @@ async def delete_runs( days_ago, ) return {} + + +@router.put( + "/projects/{project}/runs/{uid}/notifications", + status_code=HTTPStatus.OK.value, +) +async def set_run_notifications( + project: str, + uid: str, + set_notifications_request: mlrun.common.schemas.SetNotificationRequest = Body(...), + auth_info: mlrun.common.schemas.AuthInfo = Depends( + mlrun.api.api.deps.authenticate_request + ), + db_session: Session = Depends(mlrun.api.api.deps.get_db_session), +): + await run_in_threadpool( + mlrun.api.utils.singletons.project_member.get_project_member().ensure_project, + db_session, + project, + auth_info=auth_info, + ) + + # check permission per object type + await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( + mlrun.common.schemas.AuthorizationResourceTypes.run, + project, + resource_name=uid, + action=mlrun.common.schemas.AuthorizationAction.update, + auth_info=auth_info, + ) + + await run_in_threadpool( + mlrun.api.crud.Notifications().set_object_notifications, + db_session, + auth_info, + project, + set_notifications_request.notifications, + mlrun.common.schemas.RunIdentifier(uid=uid), + ) + return Response(status_code=HTTPStatus.OK.value) diff --git a/mlrun/api/api/endpoints/runtime_resources.py 
b/mlrun/api/api/endpoints/runtime_resources.py index d3ae6bde759a..3586e9ee406f 100644 --- a/mlrun/api/api/endpoints/runtime_resources.py +++ b/mlrun/api/api/endpoints/runtime_resources.py @@ -23,18 +23,18 @@ import mlrun import mlrun.api.api.deps import mlrun.api.crud -import mlrun.api.schemas import mlrun.api.utils.auth.verifier +import mlrun.common.schemas -router = fastapi.APIRouter() +router = fastapi.APIRouter(prefix="/projects/{project}/runtime-resources") @router.get( - "/projects/{project}/runtime-resources", + "", response_model=typing.Union[ - mlrun.api.schemas.RuntimeResourcesOutput, - mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, - mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, + mlrun.common.schemas.RuntimeResourcesOutput, + mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput, + mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, ], ) async def list_runtime_resources( @@ -43,9 +43,9 @@ async def list_runtime_resources( kind: typing.Optional[str] = None, object_id: typing.Optional[str] = fastapi.Query(None, alias="object-id"), group_by: typing.Optional[ - mlrun.api.schemas.ListRuntimeResourcesGroupByField + mlrun.common.schemas.ListRuntimeResourcesGroupByField ] = fastapi.Query(None, alias="group-by"), - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), ): @@ -55,8 +55,8 @@ async def list_runtime_resources( @router.delete( - "/projects/{project}/runtime-resources", - response_model=mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, + "", + response_model=mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, ) async def delete_runtime_resources( project: str, @@ -67,7 +67,7 @@ async def delete_runtime_resources( grace_period: int = fastapi.Query( mlrun.mlconf.runtime_resources_deletion_grace_period, alias="grace-period" ), - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + auth_info: 
mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), db_session: sqlalchemy.orm.Session = fastapi.Depends( @@ -88,7 +88,7 @@ async def delete_runtime_resources( async def _delete_runtime_resources( db_session: sqlalchemy.orm.Session, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, project: str, label_selector: typing.Optional[str] = None, kind: typing.Optional[str] = None, @@ -97,7 +97,7 @@ async def _delete_runtime_resources( grace_period: int = mlrun.mlconf.runtime_resources_deletion_grace_period, return_body: bool = True, ) -> typing.Union[ - mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, fastapi.Response + mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, fastapi.Response ]: ( allowed_projects, @@ -110,7 +110,7 @@ async def _delete_runtime_resources( label_selector, kind, object_id, - mlrun.api.schemas.AuthorizationAction.delete, + mlrun.common.schemas.AuthorizationAction.delete, ) # TODO: once we have more granular permissions, we should check if the user is allowed to delete the specific @@ -162,7 +162,7 @@ async def _delete_runtime_resources( return mlrun.api.crud.RuntimeResources().filter_and_format_grouped_by_project_runtime_resources_output( grouped_by_project_runtime_resources_output, filtered_projects, - mlrun.api.schemas.ListRuntimeResourcesGroupByField.project, + mlrun.common.schemas.ListRuntimeResourcesGroupByField.project, ) else: return fastapi.Response(status_code=http.HTTPStatus.NO_CONTENT.value) @@ -170,17 +170,17 @@ async def _delete_runtime_resources( async def _list_runtime_resources( project: str, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, label_selector: typing.Optional[str] = None, group_by: typing.Optional[ - mlrun.api.schemas.ListRuntimeResourcesGroupByField + mlrun.common.schemas.ListRuntimeResourcesGroupByField ] = None, kind_filter: typing.Optional[str] = None, object_id: typing.Optional[str] = 
None, ) -> typing.Union[ - mlrun.api.schemas.RuntimeResourcesOutput, - mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, - mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, + mlrun.common.schemas.RuntimeResourcesOutput, + mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput, + mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, ]: ( allowed_projects, @@ -199,31 +199,31 @@ async def _list_runtime_resources( async def _get_runtime_resources_allowed_projects( project: str, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, label_selector: typing.Optional[str] = None, kind: typing.Optional[str] = None, object_id: typing.Optional[str] = None, - action: mlrun.api.schemas.AuthorizationAction = mlrun.api.schemas.AuthorizationAction.read, + action: mlrun.common.schemas.AuthorizationAction = mlrun.common.schemas.AuthorizationAction.read, ) -> typing.Tuple[ typing.List[str], - mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, + mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, bool, bool, ]: if project != "*": await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( project, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) - grouped_by_project_runtime_resources_output: mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput + grouped_by_project_runtime_resources_output: mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput grouped_by_project_runtime_resources_output = await run_in_threadpool( mlrun.api.crud.RuntimeResources().list_runtime_resources, project, kind, object_id, label_selector, - mlrun.api.schemas.ListRuntimeResourcesGroupByField.project, + mlrun.common.schemas.ListRuntimeResourcesGroupByField.project, ) projects = [] @@ -237,7 +237,7 @@ async def _get_runtime_resources_allowed_projects( continue projects.append(project) allowed_projects = await 
mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.runtime_resource, + mlrun.common.schemas.AuthorizationResourceTypes.runtime_resource, projects, lambda project: ( project, diff --git a/mlrun/api/api/endpoints/schedules.py b/mlrun/api/api/endpoints/schedules.py index 94594c01d2c8..6e408f4e94a0 100644 --- a/mlrun/api/api/endpoints/schedules.py +++ b/mlrun/api/api/endpoints/schedules.py @@ -20,23 +20,24 @@ from sqlalchemy.orm import Session import mlrun.api.api.utils +import mlrun.api.crud import mlrun.api.utils.auth.verifier import mlrun.api.utils.clients.chief import mlrun.api.utils.singletons.project_member -from mlrun.api import schemas +import mlrun.common.schemas from mlrun.api.api import deps from mlrun.api.utils.singletons.scheduler import get_scheduler from mlrun.utils import logger -router = APIRouter() +router = APIRouter(prefix="/projects/{project}/schedules") -@router.post("/projects/{project}/schedules") +@router.post("") async def create_schedule( project: str, - schedule: schemas.ScheduleInput, + schedule: mlrun.common.schemas.ScheduleInput, request: fastapi.Request, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await run_in_threadpool( @@ -46,16 +47,16 @@ async def create_schedule( auth_info=auth_info, ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.schedule, + mlrun.common.schemas.AuthorizationResourceTypes.schedule, project, schedule.name, - mlrun.api.schemas.AuthorizationAction.create, + mlrun.common.schemas.AuthorizationAction.create, auth_info, ) # to reduce redundant load on the chief, we re-route the request only if the user has permissions if ( mlrun.mlconf.httpdb.clusterization.role - != 
mlrun.api.schemas.ClusterizationRole.chief + != mlrun.common.schemas.ClusterizationRole.chief ): logger.info( "Requesting to create schedule, re-routing to chief", @@ -86,26 +87,26 @@ async def create_schedule( return Response(status_code=HTTPStatus.CREATED.value) -@router.put("/projects/{project}/schedules/{name}") +@router.put("/{name}") async def update_schedule( project: str, name: str, - schedule: schemas.ScheduleUpdate, + schedule: mlrun.common.schemas.ScheduleUpdate, request: fastapi.Request, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.schedule, + mlrun.common.schemas.AuthorizationResourceTypes.schedule, project, name, - mlrun.api.schemas.AuthorizationAction.update, + mlrun.common.schemas.AuthorizationAction.update, auth_info, ) # to reduce redundant load on the chief, we re-route the request only if the user has permissions if ( mlrun.mlconf.httpdb.clusterization.role - != mlrun.api.schemas.ClusterizationRole.chief + != mlrun.common.schemas.ClusterizationRole.chief ): logger.info( "Requesting to update schedule, re-routing to chief", @@ -136,20 +137,20 @@ async def update_schedule( return Response(status_code=HTTPStatus.OK.value) -@router.get("/projects/{project}/schedules", response_model=schemas.SchedulesOutput) +@router.get("", response_model=mlrun.common.schemas.SchedulesOutput) async def list_schedules( project: str, name: str = None, labels: str = None, - kind: schemas.ScheduleKinds = None, + kind: mlrun.common.schemas.ScheduleKinds = None, include_last_run: bool = False, include_credentials: bool = fastapi.Query(False, alias="include-credentials"), - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: 
mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( project, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) schedules = await run_in_threadpool( @@ -163,7 +164,7 @@ async def list_schedules( include_credentials, ) filtered_schedules = await mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.schedule, + mlrun.common.schemas.AuthorizationResourceTypes.schedule, schedules.schedules, lambda schedule: ( schedule.project, @@ -176,14 +177,15 @@ async def list_schedules( @router.get( - "/projects/{project}/schedules/{name}", response_model=schemas.ScheduleOutput + "/{name}", + response_model=mlrun.common.schemas.ScheduleOutput, ) async def get_schedule( project: str, name: str, include_last_run: bool = False, include_credentials: bool = fastapi.Query(False, alias="include-credentials"), - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): schedule = await run_in_threadpool( @@ -195,34 +197,34 @@ async def get_schedule( include_credentials, ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.schedule, + mlrun.common.schemas.AuthorizationResourceTypes.schedule, project, name, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) return schedule -@router.post("/projects/{project}/schedules/{name}/invoke") +@router.post("/{name}/invoke") async def invoke_schedule( project: str, name: str, request: fastapi.Request, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + 
auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.schedule, + mlrun.common.schemas.AuthorizationResourceTypes.schedule, project, name, - mlrun.api.schemas.AuthorizationAction.update, + mlrun.common.schemas.AuthorizationAction.update, auth_info, ) # to reduce redundant load on the chief, we re-route the request only if the user has permissions if ( mlrun.mlconf.httpdb.clusterization.role - != mlrun.api.schemas.ClusterizationRole.chief + != mlrun.common.schemas.ClusterizationRole.chief ): logger.info( "Requesting to invoke schedule, re-routing to chief", @@ -237,27 +239,25 @@ async def invoke_schedule( return await get_scheduler().invoke_schedule(db_session, auth_info, project, name) -@router.delete( - "/projects/{project}/schedules/{name}", status_code=HTTPStatus.NO_CONTENT.value -) +@router.delete("/{name}", status_code=HTTPStatus.NO_CONTENT.value) async def delete_schedule( project: str, name: str, request: fastapi.Request, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.schedule, + mlrun.common.schemas.AuthorizationResourceTypes.schedule, project, name, - mlrun.api.schemas.AuthorizationAction.delete, + mlrun.common.schemas.AuthorizationAction.delete, auth_info, ) # to reduce redundant load on the chief, we re-route the request only if the user has permissions if ( mlrun.mlconf.httpdb.clusterization.role - != mlrun.api.schemas.ClusterizationRole.chief + != mlrun.common.schemas.ClusterizationRole.chief ): logger.info( "Requesting to delete schedule, 
re-routing to chief", @@ -273,11 +273,11 @@ async def delete_schedule( return Response(status_code=HTTPStatus.NO_CONTENT.value) -@router.delete("/projects/{project}/schedules", status_code=HTTPStatus.NO_CONTENT.value) +@router.delete("", status_code=HTTPStatus.NO_CONTENT.value) async def delete_schedules( project: str, request: fastapi.Request, - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), ): schedules = await run_in_threadpool( @@ -286,16 +286,16 @@ async def delete_schedules( project, ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resources_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.schedule, + mlrun.common.schemas.AuthorizationResourceTypes.schedule, schedules.schedules, lambda schedule: (schedule.project, schedule.name), - mlrun.api.schemas.AuthorizationAction.delete, + mlrun.common.schemas.AuthorizationAction.delete, auth_info, ) # to reduce redundant load on the chief, we re-route the request only if the user has permissions if ( mlrun.mlconf.httpdb.clusterization.role - != mlrun.api.schemas.ClusterizationRole.chief + != mlrun.common.schemas.ClusterizationRole.chief ): logger.info( "Requesting to delete all project schedules, re-routing to chief", @@ -306,3 +306,60 @@ async def delete_schedules( await run_in_threadpool(get_scheduler().delete_schedules, db_session, project) return Response(status_code=HTTPStatus.NO_CONTENT.value) + + +@router.put("/{name}/notifications", status_code=HTTPStatus.OK.value) +async def set_schedule_notifications( + project: str, + name: str, + request: fastapi.Request, + set_notifications_request: mlrun.common.schemas.SetNotificationRequest = fastapi.Body( + ... 
+ ), + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( + mlrun.api.api.deps.authenticate_request + ), + db_session: Session = fastapi.Depends(mlrun.api.api.deps.get_db_session), +): + await fastapi.concurrency.run_in_threadpool( + mlrun.api.utils.singletons.project_member.get_project_member().ensure_project, + db_session, + project, + auth_info=auth_info, + ) + + # check permission per object type + await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( + mlrun.common.schemas.AuthorizationResourceTypes.schedule, + project, + resource_name=name, + action=mlrun.common.schemas.AuthorizationAction.update, + auth_info=auth_info, + ) + + if ( + mlrun.mlconf.httpdb.clusterization.role + != mlrun.common.schemas.ClusterizationRole.chief + ): + logger.info( + "Requesting to set schedule notifications, re-routing to chief", + project=project, + schedule=set_notifications_request.dict(), + ) + chief_client = mlrun.api.utils.clients.chief.Client() + return await chief_client.set_schedule_notifications( + project=project, + schedule_name=name, + request=request, + json=set_notifications_request.dict(), + ) + + await fastapi.concurrency.run_in_threadpool( + mlrun.api.crud.Notifications().set_object_notifications, + db_session, + auth_info, + project, + set_notifications_request.notifications, + mlrun.common.schemas.ScheduleIdentifier(name=name), + ) + return fastapi.Response(status_code=HTTPStatus.OK.value) diff --git a/mlrun/api/api/endpoints/secrets.py b/mlrun/api/api/endpoints/secrets.py index 138939ff6f35..3eaaf361ebd2 100644 --- a/mlrun/api/api/endpoints/secrets.py +++ b/mlrun/api/api/endpoints/secrets.py @@ -23,9 +23,8 @@ import mlrun.api.crud import mlrun.api.utils.auth.verifier import mlrun.api.utils.singletons.project_member +import mlrun.common.schemas import mlrun.errors -from mlrun.api import schemas -from mlrun.utils.vault import add_vault_user_secrets router = fastapi.APIRouter() @@ -33,8 +32,8 @@ 
@router.post("/projects/{project}/secrets", status_code=HTTPStatus.CREATED.value) async def store_project_secrets( project: str, - secrets: schemas.SecretsData, - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + secrets: mlrun.common.schemas.SecretsData, + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), db_session: Session = fastapi.Depends(mlrun.api.api.deps.get_db_session), @@ -50,10 +49,10 @@ async def store_project_secrets( ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.secret, + mlrun.common.schemas.AuthorizationResourceTypes.secret, project, secrets.provider, - mlrun.api.schemas.AuthorizationAction.create, + mlrun.common.schemas.AuthorizationAction.create, auth_info, ) await run_in_threadpool( @@ -66,9 +65,9 @@ async def store_project_secrets( @router.delete("/projects/{project}/secrets", status_code=HTTPStatus.NO_CONTENT.value) async def delete_project_secrets( project: str, - provider: schemas.SecretProviderName = schemas.SecretProviderName.kubernetes, + provider: mlrun.common.schemas.SecretProviderName = mlrun.common.schemas.SecretProviderName.kubernetes, secrets: List[str] = fastapi.Query(None, alias="secret"), - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), db_session: Session = fastapi.Depends(mlrun.api.api.deps.get_db_session), @@ -81,10 +80,10 @@ async def delete_project_secrets( ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.secret, + mlrun.common.schemas.AuthorizationResourceTypes.secret, project, provider, - mlrun.api.schemas.AuthorizationAction.delete, + mlrun.common.schemas.AuthorizationAction.delete, auth_info, ) await run_in_threadpool( @@ -94,12 +93,17 @@ async def delete_project_secrets( 
return fastapi.Response(status_code=HTTPStatus.NO_CONTENT.value) -@router.get("/projects/{project}/secret-keys", response_model=schemas.SecretKeysData) +@router.get( + "/projects/{project}/secret-keys", + response_model=mlrun.common.schemas.SecretKeysData, +) async def list_project_secret_keys( project: str, - provider: schemas.SecretProviderName = schemas.SecretProviderName.kubernetes, - token: str = fastapi.Header(None, alias=schemas.HeaderNames.secret_store_token), - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + provider: mlrun.common.schemas.SecretProviderName = mlrun.common.schemas.SecretProviderName.kubernetes, + token: str = fastapi.Header( + None, alias=mlrun.common.schemas.HeaderNames.secret_store_token + ), + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), db_session: Session = fastapi.Depends(mlrun.api.api.deps.get_db_session), @@ -111,10 +115,10 @@ async def list_project_secret_keys( auth_info.session, ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.secret, + mlrun.common.schemas.AuthorizationResourceTypes.secret, project, provider, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) return await run_in_threadpool( @@ -122,13 +126,17 @@ async def list_project_secret_keys( ) -@router.get("/projects/{project}/secrets", response_model=schemas.SecretsData) +@router.get( + "/projects/{project}/secrets", response_model=mlrun.common.schemas.SecretsData +) async def list_project_secrets( project: str, secrets: List[str] = fastapi.Query(None, alias="secret"), - provider: schemas.SecretProviderName = schemas.SecretProviderName.kubernetes, - token: str = fastapi.Header(None, alias=schemas.HeaderNames.secret_store_token), - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + provider: mlrun.common.schemas.SecretProviderName = 
mlrun.common.schemas.SecretProviderName.kubernetes, + token: str = fastapi.Header( + None, alias=mlrun.common.schemas.HeaderNames.secret_store_token + ), + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), db_session: Session = fastapi.Depends(mlrun.api.api.deps.get_db_session), @@ -140,10 +148,10 @@ async def list_project_secrets( auth_info.session, ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.secret, + mlrun.common.schemas.AuthorizationResourceTypes.secret, project, provider, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) return await run_in_threadpool( @@ -153,13 +161,10 @@ async def list_project_secrets( @router.post("/user-secrets", status_code=HTTPStatus.CREATED.value) def add_user_secrets( - secrets: schemas.UserSecretCreationRequest, + secrets: mlrun.common.schemas.UserSecretCreationRequest, ): - if secrets.provider != schemas.SecretProviderName.vault: - return fastapi.Response( - status_code=HTTPStatus.BAD_REQUEST.vault, - content=f"Invalid secrets provider {secrets.provider}", - ) - - add_vault_user_secrets(secrets.user, secrets.secrets) - return fastapi.Response(status_code=HTTPStatus.CREATED.value) + # vault is not used + return fastapi.Response( + status_code=HTTPStatus.BAD_REQUEST.value, + content=f"Invalid secrets provider {secrets.provider}", + ) diff --git a/mlrun/api/api/endpoints/submit.py b/mlrun/api/api/endpoints/submit.py index fce34fa7107f..f9eb9af19901 100644 --- a/mlrun/api/api/endpoints/submit.py +++ b/mlrun/api/api/endpoints/submit.py @@ -20,10 +20,10 @@ from sqlalchemy.orm import Session import mlrun.api.api.utils -import mlrun.api.schemas import mlrun.api.utils.auth.verifier import mlrun.api.utils.clients.chief import mlrun.api.utils.singletons.project_member +import mlrun.common.schemas import mlrun.utils.helpers from mlrun.api.api 
import deps from mlrun.utils import logger @@ -38,13 +38,13 @@ async def submit_job( request: Request, username: Optional[str] = Header(None, alias="x-remote-user"), - auth_info: mlrun.api.schemas.AuthInfo = Depends(deps.authenticate_request), + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), db_session: Session = Depends(deps.get_db_session), client_version: Optional[str] = Header( - None, alias=mlrun.api.schemas.HeaderNames.client_version + None, alias=mlrun.common.schemas.HeaderNames.client_version ), client_python_version: Optional[str] = Header( - None, alias=mlrun.api.schemas.HeaderNames.python_version + None, alias=mlrun.common.schemas.HeaderNames.python_version ), ): data = None @@ -70,18 +70,18 @@ async def submit_job( _, ) = mlrun.utils.helpers.parse_versioned_object_uri(function_url) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.function, + mlrun.common.schemas.AuthorizationResourceTypes.function, function_project, function_name, - mlrun.api.schemas.AuthorizationAction.read, + mlrun.common.schemas.AuthorizationAction.read, auth_info, ) if data.get("schedule"): await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.schedule, + mlrun.common.schemas.AuthorizationResourceTypes.schedule, data["task"]["metadata"]["project"], data["task"]["metadata"]["name"], - mlrun.api.schemas.AuthorizationAction.create, + mlrun.common.schemas.AuthorizationAction.create, auth_info, ) # schedules are meant to be run solely by the chief, then if run is configured to run as scheduled @@ -89,7 +89,7 @@ async def submit_job( # to reduce redundant load on the chief, we re-route the request only if the user has permissions if ( mlrun.mlconf.httpdb.clusterization.role - != mlrun.api.schemas.ClusterizationRole.chief + != mlrun.common.schemas.ClusterizationRole.chief ): logger.info( 
"Requesting to submit job with schedules, re-routing to chief", @@ -102,10 +102,10 @@ async def submit_job( else: await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - mlrun.api.schemas.AuthorizationResourceTypes.run, + mlrun.common.schemas.AuthorizationResourceTypes.run, data["task"]["metadata"]["project"], "", - mlrun.api.schemas.AuthorizationAction.create, + mlrun.common.schemas.AuthorizationAction.create, auth_info, ) diff --git a/mlrun/api/api/endpoints/tags.py b/mlrun/api/api/endpoints/tags.py index b90b472024f4..a625955954ac 100644 --- a/mlrun/api/api/endpoints/tags.py +++ b/mlrun/api/api/endpoints/tags.py @@ -20,20 +20,20 @@ import mlrun.api.api.deps import mlrun.api.crud.tags -import mlrun.api.schemas import mlrun.api.utils.auth.verifier import mlrun.api.utils.singletons.project_member +import mlrun.common.schemas from mlrun.utils.helpers import tag_name_regex_as_string -router = fastapi.APIRouter() +router = fastapi.APIRouter(prefix="/projects/{project}/tags") -@router.post("/projects/{project}/tags/{tag}", response_model=mlrun.api.schemas.Tag) +@router.post("/{tag}", response_model=mlrun.common.schemas.Tag) async def overwrite_object_tags_with_tag( project: str, tag: str = fastapi.Path(..., regex=tag_name_regex_as_string()), - tag_objects: mlrun.api.schemas.TagObjects = fastapi.Body(...), - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + tag_objects: mlrun.common.schemas.TagObjects = fastapi.Body(...), + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), db_session: sqlalchemy.orm.Session = fastapi.Depends( @@ -49,11 +49,11 @@ async def overwrite_object_tags_with_tag( # check permission per object type await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - getattr(mlrun.api.schemas.AuthorizationResourceTypes, tag_objects.kind), + getattr(mlrun.common.schemas.AuthorizationResourceTypes, tag_objects.kind), project, 
resource_name="", # not actually overwriting objects, just overwriting the objects tags - action=mlrun.api.schemas.AuthorizationAction.update, + action=mlrun.common.schemas.AuthorizationAction.update, auth_info=auth_info, ) @@ -64,15 +64,15 @@ async def overwrite_object_tags_with_tag( tag, tag_objects, ) - return mlrun.api.schemas.Tag(name=tag, project=project) + return mlrun.common.schemas.Tag(name=tag, project=project) -@router.put("/projects/{project}/tags/{tag}", response_model=mlrun.api.schemas.Tag) +@router.put("/{tag}", response_model=mlrun.common.schemas.Tag) async def append_tag_to_objects( project: str, tag: str = fastapi.Path(..., regex=tag_name_regex_as_string()), - tag_objects: mlrun.api.schemas.TagObjects = fastapi.Body(...), - auth_info: mlrun.api.schemas.AuthInfo = fastapi.Depends( + tag_objects: mlrun.common.schemas.TagObjects = fastapi.Body(...), + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), db_session: sqlalchemy.orm.Session = fastapi.Depends( @@ -87,10 +87,10 @@ async def append_tag_to_objects( ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - getattr(mlrun.api.schemas.AuthorizationResourceTypes, tag_objects.kind), + getattr(mlrun.common.schemas.AuthorizationResourceTypes, tag_objects.kind), project, resource_name="", - action=mlrun.api.schemas.AuthorizationAction.update, + action=mlrun.common.schemas.AuthorizationAction.update, auth_info=auth_info, ) @@ -101,17 +101,15 @@ async def append_tag_to_objects( tag, tag_objects, ) - return mlrun.api.schemas.Tag(name=tag, project=project) + return mlrun.common.schemas.Tag(name=tag, project=project) -@router.delete( - "/projects/{project}/tags/{tag}", status_code=http.HTTPStatus.NO_CONTENT.value -) +@router.delete("/{tag}", status_code=http.HTTPStatus.NO_CONTENT.value) async def delete_tag_from_objects( project: str, tag: str, - tag_objects: mlrun.api.schemas.TagObjects, - auth_info: 
mlrun.api.schemas.AuthInfo = fastapi.Depends( + tag_objects: mlrun.common.schemas.TagObjects, + auth_info: mlrun.common.schemas.AuthInfo = fastapi.Depends( mlrun.api.api.deps.authenticate_request ), db_session: sqlalchemy.orm.Session = fastapi.Depends( @@ -126,11 +124,11 @@ async def delete_tag_from_objects( ) await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( - getattr(mlrun.api.schemas.AuthorizationResourceTypes, tag_objects.kind), + getattr(mlrun.common.schemas.AuthorizationResourceTypes, tag_objects.kind), project, resource_name="", # not actually deleting objects, just deleting the objects tags - action=mlrun.api.schemas.AuthorizationAction.update, + action=mlrun.common.schemas.AuthorizationAction.update, auth_info=auth_info, ) diff --git a/mlrun/api/api/utils.py b/mlrun/api/api/utils.py index f937e6cfb35c..dcbcef0c6649 100644 --- a/mlrun/api/api/utils.py +++ b/mlrun/api/api/utils.py @@ -31,19 +31,17 @@ import mlrun.api.utils.auth.verifier import mlrun.api.utils.clients.iguazio import mlrun.api.utils.singletons.k8s +import mlrun.common.schemas import mlrun.errors import mlrun.runtimes.pod import mlrun.utils.helpers -from mlrun.api import schemas from mlrun.api.db.sqldb.db import SQLDB -from mlrun.api.schemas import SecretProviderName, SecurityContextEnrichmentModes from mlrun.api.utils.singletons.db import get_db from mlrun.api.utils.singletons.logs_dir import get_logs_dir from mlrun.api.utils.singletons.scheduler import get_scheduler from mlrun.config import config from mlrun.db.sqldb import SQLDB as SQLRunDB from mlrun.errors import err_to_str -from mlrun.k8s_utils import get_k8s_helper from mlrun.run import import_function, new_function from mlrun.runtimes.utils import enrich_function_from_dict from mlrun.utils import get_in, logger, parse_versioned_object_uri @@ -122,7 +120,9 @@ def get_allowed_path_prefixes_list() -> typing.List[str]: return allowed_paths_list -def get_secrets(auth_info: mlrun.api.schemas.AuthInfo): 
+def get_secrets( + auth_info: mlrun.common.schemas.AuthInfo, +): return { "V3IO_ACCESS_KEY": auth_info.data_session, } @@ -155,7 +155,7 @@ def parse_submit_run_body(data): def _generate_function_and_task_from_submit_run_body( - db_session: Session, auth_info: mlrun.api.schemas.AuthInfo, data + db_session: Session, auth_info: mlrun.common.schemas.AuthInfo, data ): function_dict, function_url, task = parse_submit_run_body(data) # TODO: block exec for function["kind"] in ["", "local] (must be a @@ -192,7 +192,9 @@ def _generate_function_and_task_from_submit_run_body( return function, task -async def submit_run(db_session: Session, auth_info: mlrun.api.schemas.AuthInfo, data): +async def submit_run( + db_session: Session, auth_info: mlrun.common.schemas.AuthInfo, data +): _, _, _, response = await run_in_threadpool( submit_run_sync, db_session, auth_info, data ) @@ -209,25 +211,31 @@ def mask_notification_params_on_task(task): run_uid = get_in(task, "metadata.uid") project = get_in(task, "metadata.project") notifications = task.get("spec", {}).get("notifications", []) + masked_notifications = [] if notifications: for notification in notifications: notification_object = mlrun.model.Notification.from_dict(notification) - mask_notification_params_with_secret(project, run_uid, notification_object) + masked_notifications.append( + mask_notification_params_with_secret( + project, run_uid, notification_object + ).to_dict() + ) + task.setdefault("spec", {})["notifications"] = masked_notifications def mask_notification_params_with_secret( - project: str, run_uid: str, notification_object: mlrun.model.Notification + project: str, parent: str, notification_object: mlrun.model.Notification ) -> mlrun.model.Notification: if notification_object.params and "secret" not in notification_object.params: secret_key = mlrun.api.crud.Secrets().generate_client_project_secret_key( mlrun.api.crud.SecretsClientType.notifications, - run_uid, + parent, notification_object.name, ) 
mlrun.api.crud.Secrets().store_project_secrets( project, - mlrun.api.schemas.SecretsData( - provider=mlrun.api.schemas.SecretProviderName.kubernetes, + mlrun.common.schemas.SecretsData( + provider=mlrun.common.schemas.SecretProviderName.kubernetes, secrets={secret_key: json.dumps(notification_object.params)}, ), allow_internal_secrets=True, @@ -256,7 +264,7 @@ def unmask_notification_params_secret( if not params_secret: return notification_object - k8s = mlrun.api.utils.singletons.k8s.get_k8s() + k8s = mlrun.api.utils.singletons.k8s.get_k8s_helper() if not k8s: raise mlrun.errors.MLRunRuntimeError( "Not running in k8s environment, cannot load notification params secret" @@ -265,7 +273,7 @@ def unmask_notification_params_secret( notification_object.params = json.loads( mlrun.api.crud.Secrets().get_project_secret( project, - mlrun.api.schemas.SecretProviderName.kubernetes, + mlrun.common.schemas.SecretProviderName.kubernetes, secret_key=params_secret, allow_internal_secrets=True, allow_secrets_from_k8s=True, @@ -283,7 +291,7 @@ def delete_notification_params_secret( if not params_secret: return - k8s = mlrun.api.utils.singletons.k8s.get_k8s() + k8s = mlrun.api.utils.singletons.k8s.get_k8s_helper() if not k8s: raise mlrun.errors.MLRunRuntimeError( "Not running in k8s environment, cannot delete notification params secret" @@ -291,16 +299,61 @@ def delete_notification_params_secret( mlrun.api.crud.Secrets().delete_project_secret( project, - mlrun.api.schemas.SecretProviderName.kubernetes, + mlrun.common.schemas.SecretProviderName.kubernetes, secret_key=params_secret, allow_internal_secrets=True, allow_secrets_from_k8s=True, ) +def validate_and_mask_notification_list( + notifications: typing.List[ + typing.Union[mlrun.model.Notification, mlrun.common.schemas.Notification, dict] + ], + parent: str, + project: str, +) -> typing.List[mlrun.model.Notification]: + """ + Validates notification schema, uniqueness and masks notification params with secret if needed. 
+ If at least one of the validation steps fails, the function will raise an exception and cause the API to return + an error response. + :param notifications: list of notification objects + :param parent: parent identifier + :param project: project name + :return: list of validated and masked notification objects + """ + notification_objects = [] + + for notification in notifications: + if isinstance(notification, dict): + notification_object = mlrun.model.Notification.from_dict(notification) + elif isinstance(notification, mlrun.common.schemas.Notification): + notification_object = mlrun.model.Notification.from_dict( + notification.dict() + ) + elif isinstance(notification, mlrun.model.Notification): + notification_object = notification + else: + raise mlrun.errors.MLRunInvalidArgumentError( + "notification must be a dict or a Notification object" + ) + + # validate notification schema + mlrun.common.schemas.Notification(**notification_object.to_dict()) + + notification_objects.append(notification_object) + + mlrun.model.Notification.validate_notification_uniqueness(notification_objects) + + return [ + mask_notification_params_with_secret(project, parent, notification_object) + for notification_object in notification_objects + ] + + def apply_enrichment_and_validation_on_function( function, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, ensure_auth: bool = True, perform_auto_mount: bool = True, validate_service_account: bool = True, @@ -340,14 +393,14 @@ def apply_enrichment_and_validation_on_function( def ensure_function_auth_and_sensitive_data_is_masked( function, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, allow_empty_access_key: bool = False, ): ensure_function_has_auth_set(function, auth_info, allow_empty_access_key) mask_function_sensitive_data(function, auth_info) -def mask_function_sensitive_data(function, auth_info: mlrun.api.schemas.AuthInfo): +def 
mask_function_sensitive_data(function, auth_info: mlrun.common.schemas.AuthInfo): if not mlrun.runtimes.RuntimeKinds.is_local_runtime(function.kind): _mask_v3io_access_key_env_var(function, auth_info) _mask_v3io_volume_credentials(function) @@ -431,8 +484,8 @@ def _mask_v3io_volume_credentials(function: mlrun.runtimes.pod.KubeResource): if not username: continue secret_name = mlrun.api.crud.Secrets().store_auth_secret( - mlrun.api.schemas.AuthSecretData( - provider=mlrun.api.schemas.SecretProviderName.kubernetes, + mlrun.common.schemas.AuthSecretData( + provider=mlrun.common.schemas.SecretProviderName.kubernetes, username=username, access_key=access_key, ) @@ -494,7 +547,7 @@ def _resolve_v3io_fuse_volume_access_key_matching_username( def _mask_v3io_access_key_env_var( - function: mlrun.runtimes.pod.KubeResource, auth_info: mlrun.api.schemas.AuthInfo + function: mlrun.runtimes.pod.KubeResource, auth_info: mlrun.common.schemas.AuthInfo ): v3io_access_key = function.get_env("V3IO_ACCESS_KEY") # if it's already a V1EnvVarSource or dict instance, it's already been masked @@ -521,14 +574,14 @@ def _mask_v3io_access_key_env_var( ) return secret_name = mlrun.api.crud.Secrets().store_auth_secret( - mlrun.api.schemas.AuthSecretData( - provider=mlrun.api.schemas.SecretProviderName.kubernetes, + mlrun.common.schemas.AuthSecretData( + provider=mlrun.common.schemas.SecretProviderName.kubernetes, username=username, access_key=v3io_access_key, ) ) - access_key_secret_key = mlrun.api.schemas.AuthSecretData.get_field_secret_key( - "access_key" + access_key_secret_key = ( + mlrun.common.schemas.AuthSecretData.get_field_secret_key("access_key") ) function.set_env_from_secret( "V3IO_ACCESS_KEY", secret_name, access_key_secret_key @@ -537,7 +590,7 @@ def _mask_v3io_access_key_env_var( def ensure_function_has_auth_set( function: mlrun.runtimes.BaseRuntime, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, allow_empty_access_key: bool = False, ): """ @@ 
-584,8 +637,8 @@ def ensure_function_has_auth_set( "Username is missing from auth info" ) secret_name = mlrun.api.crud.Secrets().store_auth_secret( - mlrun.api.schemas.AuthSecretData( - provider=mlrun.api.schemas.SecretProviderName.kubernetes, + mlrun.common.schemas.AuthSecretData( + provider=mlrun.common.schemas.SecretProviderName.kubernetes, username=auth_info.username, access_key=function.metadata.credentials.access_key, ) @@ -598,8 +651,8 @@ def ensure_function_has_auth_set( mlrun.model.Credentials.secret_reference_prefix ) - access_key_secret_key = mlrun.api.schemas.AuthSecretData.get_field_secret_key( - "access_key" + access_key_secret_key = ( + mlrun.common.schemas.AuthSecretData.get_field_secret_key("access_key") ) auth_env_vars = { mlrun.runtimes.constants.FunctionEnvironmentVariables.auth_session: ( @@ -611,7 +664,7 @@ def ensure_function_has_auth_set( function.set_env_from_secret(env_key, secret_name, secret_key) -def try_perform_auto_mount(function, auth_info: mlrun.api.schemas.AuthInfo): +def try_perform_auto_mount(function, auth_info: mlrun.common.schemas.AuthInfo): if ( mlrun.runtimes.RuntimeKinds.is_local_runtime(function.kind) or function.spec.disable_auto_mount @@ -629,7 +682,9 @@ def try_perform_auto_mount(function, auth_info: mlrun.api.schemas.AuthInfo): def process_function_service_account(function): # If we're not running inside k8s, skip this check as it's not relevant. 
- if not get_k8s_helper(silent=True).is_running_inside_kubernetes_cluster(): + if not mlrun.api.utils.singletons.k8s.get_k8s_helper( + silent=True + ).is_running_inside_kubernetes_cluster(): return ( @@ -645,7 +700,7 @@ def process_function_service_account(function): def resolve_project_default_service_account(project_name: str): allowed_service_accounts = mlrun.api.crud.secrets.Secrets().get_project_secret( project_name, - SecretProviderName.kubernetes, + mlrun.common.schemas.SecretProviderName.kubernetes, mlrun.api.crud.secrets.Secrets().generate_client_project_secret_key( mlrun.api.crud.secrets.SecretsClientType.service_accounts, "allowed" ), @@ -660,7 +715,7 @@ def resolve_project_default_service_account(project_name: str): default_service_account = mlrun.api.crud.secrets.Secrets().get_project_secret( project_name, - SecretProviderName.kubernetes, + mlrun.common.schemas.SecretProviderName.kubernetes, mlrun.api.crud.secrets.Secrets().generate_client_project_secret_key( mlrun.api.crud.secrets.SecretsClientType.service_accounts, "default" ), @@ -687,7 +742,9 @@ def resolve_project_default_service_account(project_name: str): return allowed_service_accounts, default_service_account -def ensure_function_security_context(function, auth_info: mlrun.api.schemas.AuthInfo): +def ensure_function_security_context( + function, auth_info: mlrun.common.schemas.AuthInfo +): """ For iguazio we enforce that pods run with user id and group id depending on mlrun.mlconf.function.spec.security_context.enrichment_mode @@ -698,7 +755,7 @@ def ensure_function_security_context(function, auth_info: mlrun.api.schemas.Auth # security context is not yet supported with spark runtime since it requires spark 3.2+ if ( mlrun.mlconf.function.spec.security_context.enrichment_mode - == SecurityContextEnrichmentModes.disabled.value + == mlrun.common.schemas.SecurityContextEnrichmentModes.disabled.value or mlrun.runtimes.RuntimeKinds.is_local_runtime(function.kind) or function.kind == 
mlrun.runtimes.RuntimeKinds.spark # remote spark image currently requires running with user 1000 or root @@ -714,7 +771,7 @@ def ensure_function_security_context(function, auth_info: mlrun.api.schemas.Auth # Enrichment with retain enrichment mode should occur on function creation only. if ( mlrun.mlconf.function.spec.security_context.enrichment_mode - == SecurityContextEnrichmentModes.retain.value + == mlrun.common.schemas.SecurityContextEnrichmentModes.retain.value and function.spec.security_context is not None and function.spec.security_context.run_as_user is not None and function.spec.security_context.run_as_group is not None @@ -727,8 +784,8 @@ def ensure_function_security_context(function, auth_info: mlrun.api.schemas.Auth return if mlrun.mlconf.function.spec.security_context.enrichment_mode in [ - SecurityContextEnrichmentModes.override.value, - SecurityContextEnrichmentModes.retain.value, + mlrun.common.schemas.SecurityContextEnrichmentModes.override.value, + mlrun.common.schemas.SecurityContextEnrichmentModes.retain.value, ]: # before iguazio 3.6 the user unix id is not passed in the session verification response headers @@ -784,7 +841,7 @@ def ensure_function_security_context(function, auth_info: mlrun.api.schemas.Auth def submit_run_sync( - db_session: Session, auth_info: mlrun.api.schemas.AuthInfo, data + db_session: Session, auth_info: mlrun.common.schemas.AuthInfo, data ) -> typing.Tuple[str, str, str, typing.Dict]: """ :return: Tuple with: @@ -814,10 +871,19 @@ def submit_run_sync( if schedule: cron_trigger = schedule if isinstance(cron_trigger, dict): - cron_trigger = schemas.ScheduleCronTrigger(**cron_trigger) + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(**cron_trigger) schedule_labels = task["metadata"].get("labels") created = False + # if the task is pointing to a remote function (hub://), we need to save it to the db + # and update the task to point to the saved function, so that the scheduler will be able to + # access the db 
version of the function, and not the remote one (which can be changed between runs) + if "://" in task["spec"]["function"]: + function_uri = fn.save(versioned=True) + data.pop("function", None) + data.pop("function_url", None) + task["spec"]["function"] = function_uri.replace("db://", "") + try: get_scheduler().update_schedule( db_session, @@ -829,13 +895,17 @@ def submit_run_sync( schedule_labels, ) except mlrun.errors.MLRunNotFoundError: - logger.debug("No existing schedule found, creating a new one") + logger.debug( + "No existing schedule found, creating a new one", + project=task["metadata"]["project"], + name=task["metadata"]["name"], + ) get_scheduler().create_schedule( db_session, auth_info, task["metadata"]["project"], task["metadata"]["name"], - schemas.ScheduleKinds.job, + mlrun.common.schemas.ScheduleKinds.job, data, cron_trigger, schedule_labels, @@ -857,7 +927,7 @@ def submit_run_sync( mlrun.api.crud.Secrets() .list_project_secrets( task["metadata"]["project"], - mlrun.api.schemas.SecretProviderName.kubernetes, + mlrun.common.schemas.SecretProviderName.kubernetes, allow_secrets_from_k8s=True, ) .secrets diff --git a/mlrun/api/crud/__init__.py b/mlrun/api/crud/__init__.py index 0fdd2c760de5..00126c862c57 100644 --- a/mlrun/api/crud/__init__.py +++ b/mlrun/api/crud/__init__.py @@ -12,18 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -from .artifacts import Artifacts # noqa: F401 -from .client_spec import ClientSpec # noqa: F401 -from .clusterization_spec import ClusterizationSpec # noqa: F401 -from .feature_store import FeatureStore # noqa: F401 -from .functions import Functions # noqa: F401 -from .logs import Logs # noqa: F401 -from .marketplace import Marketplace # noqa: F401 -from .model_monitoring import ModelEndpoints, ModelEndpointStoreType # noqa: F401 -from .notifications import Notifications # noqa: F401 -from .pipelines import Pipelines # noqa: F401 -from .projects import Projects # noqa: F401 -from .runs import Runs # noqa: F401 -from .runtime_resources import RuntimeResources # noqa: F401 -from .secrets import Secrets, SecretsClientType # noqa: F401 -from .tags import Tags # noqa: F401 +# flake8: noqa: F401 - this is until we take care of the F401 violations with respect to __all__ & sphinx + +from .artifacts import Artifacts +from .client_spec import ClientSpec +from .clusterization_spec import ClusterizationSpec +from .feature_store import FeatureStore +from .functions import Functions +from .hub import Hub +from .logs import Logs +from .model_monitoring import ModelEndpoints +from .notifications import Notifications +from .pipelines import Pipelines +from .projects import Projects +from .runs import Runs +from .runtime_resources import RuntimeResources +from .secrets import Secrets, SecretsClientType +from .tags import Tags diff --git a/mlrun/api/crud/artifacts.py b/mlrun/api/crud/artifacts.py index c95127331c5f..33ca4e4013fd 100644 --- a/mlrun/api/crud/artifacts.py +++ b/mlrun/api/crud/artifacts.py @@ -16,14 +16,14 @@ import sqlalchemy.orm -import mlrun.api.schemas import mlrun.api.utils.projects.remotes.follower import mlrun.api.utils.singletons.db import mlrun.api.utils.singletons.project_member +import mlrun.common.schemas +import mlrun.common.schemas.artifact import mlrun.config import mlrun.errors import mlrun.utils.singleton -from mlrun.api.schemas.artifact import 
ArtifactsFormat class Artifacts( @@ -66,7 +66,7 @@ def get_artifact( tag: str = "latest", iter: int = 0, project: str = mlrun.mlconf.default_project, - format_: ArtifactsFormat = ArtifactsFormat.full, + format_: mlrun.common.schemas.artifact.ArtifactsFormat = mlrun.common.schemas.artifact.ArtifactsFormat.full, ) -> dict: project = project or mlrun.mlconf.default_project artifact = mlrun.api.utils.singletons.db.get_db().read_artifact( @@ -76,7 +76,7 @@ def get_artifact( iter, project, ) - if format_ == ArtifactsFormat.legacy: + if format_ == mlrun.common.schemas.artifact.ArtifactsFormat.legacy: return _transform_artifact_struct_to_legacy_format(artifact) return artifact @@ -90,10 +90,10 @@ def list_artifacts( since=None, until=None, kind: typing.Optional[str] = None, - category: typing.Optional[mlrun.api.schemas.ArtifactCategories] = None, + category: typing.Optional[mlrun.common.schemas.ArtifactCategories] = None, iter: typing.Optional[int] = None, best_iteration: bool = False, - format_: ArtifactsFormat = ArtifactsFormat.full, + format_: mlrun.common.schemas.artifact.ArtifactsFormat = mlrun.common.schemas.artifact.ArtifactsFormat.full, ) -> typing.List: project = project or mlrun.mlconf.default_project if labels is None: @@ -111,7 +111,7 @@ def list_artifacts( iter, best_iteration, ) - if format_ != ArtifactsFormat.legacy: + if format_ != mlrun.common.schemas.artifact.ArtifactsFormat.legacy: return artifacts return [ _transform_artifact_struct_to_legacy_format(artifact) @@ -122,7 +122,7 @@ def list_artifact_tags( self, db_session: sqlalchemy.orm.Session, project: str = mlrun.mlconf.default_project, - category: mlrun.api.schemas.ArtifactCategories = None, + category: mlrun.common.schemas.ArtifactCategories = None, ): project = project or mlrun.mlconf.default_project return mlrun.api.utils.singletons.db.get_db().list_artifact_tags( @@ -148,7 +148,7 @@ def delete_artifacts( name: str = "", tag: str = "latest", labels: typing.List[str] = None, - auth_info: 
mlrun.api.schemas.AuthInfo = mlrun.api.schemas.AuthInfo(), + auth_info: mlrun.common.schemas.AuthInfo = mlrun.common.schemas.AuthInfo(), ): project = project or mlrun.mlconf.default_project mlrun.api.utils.singletons.db.get_db().del_artifacts( diff --git a/mlrun/api/crud/client_spec.py b/mlrun/api/crud/client_spec.py index 02e6567870d3..10928eadab42 100644 --- a/mlrun/api/crud/client_spec.py +++ b/mlrun/api/crud/client_spec.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import mlrun.api.schemas +import mlrun.common.schemas import mlrun.utils.singleton from mlrun.config import Config, config, default_config from mlrun.runtimes.utils import resolve_mpijob_crd_version, resolve_nuclio_version @@ -24,8 +24,8 @@ class ClientSpec( def get_client_spec( self, client_version: str = None, client_python_version: str = None ): - mpijob_crd_version = resolve_mpijob_crd_version(api_context=True) - return mlrun.api.schemas.ClientSpec( + mpijob_crd_version = resolve_mpijob_crd_version() + return mlrun.common.schemas.ClientSpec( version=config.version, namespace=config.namespace, docker_registry=config.httpdb.builder.docker_registry, diff --git a/mlrun/api/crud/clusterization_spec.py b/mlrun/api/crud/clusterization_spec.py index 83b14659a814..ed1831ab770e 100644 --- a/mlrun/api/crud/clusterization_spec.py +++ b/mlrun/api/crud/clusterization_spec.py @@ -13,7 +13,7 @@ # limitations under the License. 
# import mlrun -import mlrun.api.schemas +import mlrun.common.schemas import mlrun.utils.singleton @@ -23,7 +23,7 @@ class ClusterizationSpec( @staticmethod def get_clusterization_spec(): is_chief = mlrun.mlconf.httpdb.clusterization.role == "chief" - return mlrun.api.schemas.ClusterizationSpec( + return mlrun.common.schemas.ClusterizationSpec( chief_api_state=mlrun.mlconf.httpdb.state if is_chief else None, chief_version=mlrun.mlconf.version if is_chief else None, ) diff --git a/mlrun/api/crud/feature_store.py b/mlrun/api/crud/feature_store.py index 9e26c07be90b..3b0478263383 100644 --- a/mlrun/api/crud/feature_store.py +++ b/mlrun/api/crud/feature_store.py @@ -16,10 +16,10 @@ import sqlalchemy.orm -import mlrun.api.schemas import mlrun.api.utils.projects.remotes.follower import mlrun.api.utils.singletons.db import mlrun.api.utils.singletons.project_member +import mlrun.common.schemas import mlrun.config import mlrun.errors import mlrun.utils.singleton @@ -32,7 +32,7 @@ def create_feature_set( self, db_session: sqlalchemy.orm.Session, project: str, - feature_set: mlrun.api.schemas.FeatureSet, + feature_set: mlrun.common.schemas.FeatureSet, versioned: bool = True, ) -> str: return self._create_object( @@ -47,7 +47,7 @@ def store_feature_set( db_session: sqlalchemy.orm.Session, project: str, name: str, - feature_set: mlrun.api.schemas.FeatureSet, + feature_set: mlrun.common.schemas.FeatureSet, tag: typing.Optional[str] = None, uid: typing.Optional[str] = None, versioned: bool = True, @@ -70,11 +70,11 @@ def patch_feature_set( feature_set_patch: dict, tag: typing.Optional[str] = None, uid: typing.Optional[str] = None, - patch_mode: mlrun.api.schemas.PatchMode = mlrun.api.schemas.PatchMode.replace, + patch_mode: mlrun.common.schemas.PatchMode = mlrun.common.schemas.PatchMode.replace, ) -> str: return self._patch_object( db_session, - mlrun.api.schemas.FeatureSet, + mlrun.common.schemas.FeatureSet, project, name, feature_set_patch, @@ -90,9 +90,9 @@ def 
get_feature_set( name: str, tag: typing.Optional[str] = None, uid: typing.Optional[str] = None, - ) -> mlrun.api.schemas.FeatureSet: + ) -> mlrun.common.schemas.FeatureSet: return self._get_object( - db_session, mlrun.api.schemas.FeatureSet, project, name, tag, uid + db_session, mlrun.common.schemas.FeatureSet, project, name, tag, uid ) def list_feature_sets_tags( @@ -104,7 +104,7 @@ def list_feature_sets_tags( :return: a list of Tuple of (project, feature_set.name, tag) """ return self._list_object_type_tags( - db_session, mlrun.api.schemas.FeatureSet, project + db_session, mlrun.common.schemas.FeatureSet, project ) def list_feature_sets( @@ -117,11 +117,11 @@ def list_feature_sets( entities: typing.List[str] = None, features: typing.List[str] = None, labels: typing.List[str] = None, - partition_by: mlrun.api.schemas.FeatureStorePartitionByField = None, + partition_by: mlrun.common.schemas.FeatureStorePartitionByField = None, rows_per_partition: int = 1, - partition_sort_by: mlrun.api.schemas.SortField = None, - partition_order: mlrun.api.schemas.OrderType = mlrun.api.schemas.OrderType.desc, - ) -> mlrun.api.schemas.FeatureSetsOutput: + partition_sort_by: mlrun.common.schemas.SortField = None, + partition_order: mlrun.common.schemas.OrderType = mlrun.common.schemas.OrderType.desc, + ) -> mlrun.common.schemas.FeatureSetsOutput: project = project or mlrun.mlconf.default_project return mlrun.api.utils.singletons.db.get_db().list_feature_sets( db_session, @@ -148,7 +148,7 @@ def delete_feature_set( ): self._delete_object( db_session, - mlrun.api.schemas.FeatureSet, + mlrun.common.schemas.FeatureSet, project, name, tag, @@ -163,7 +163,7 @@ def list_features( tag: typing.Optional[str] = None, entities: typing.List[str] = None, labels: typing.List[str] = None, - ) -> mlrun.api.schemas.FeaturesOutput: + ) -> mlrun.common.schemas.FeaturesOutput: project = project or mlrun.mlconf.default_project return mlrun.api.utils.singletons.db.get_db().list_features( db_session, @@ 
-181,7 +181,7 @@ def list_entities( name: str, tag: typing.Optional[str] = None, labels: typing.List[str] = None, - ) -> mlrun.api.schemas.EntitiesOutput: + ) -> mlrun.common.schemas.EntitiesOutput: project = project or mlrun.mlconf.default_project return mlrun.api.utils.singletons.db.get_db().list_entities( db_session, @@ -195,7 +195,7 @@ def create_feature_vector( self, db_session: sqlalchemy.orm.Session, project: str, - feature_vector: mlrun.api.schemas.FeatureVector, + feature_vector: mlrun.common.schemas.FeatureVector, versioned: bool = True, ) -> str: return self._create_object(db_session, project, feature_vector, versioned) @@ -205,7 +205,7 @@ def store_feature_vector( db_session: sqlalchemy.orm.Session, project: str, name: str, - feature_vector: mlrun.api.schemas.FeatureVector, + feature_vector: mlrun.common.schemas.FeatureVector, tag: typing.Optional[str] = None, uid: typing.Optional[str] = None, versioned: bool = True, @@ -228,11 +228,11 @@ def patch_feature_vector( feature_vector_patch: dict, tag: typing.Optional[str] = None, uid: typing.Optional[str] = None, - patch_mode: mlrun.api.schemas.PatchMode = mlrun.api.schemas.PatchMode.replace, + patch_mode: mlrun.common.schemas.PatchMode = mlrun.common.schemas.PatchMode.replace, ) -> str: return self._patch_object( db_session, - mlrun.api.schemas.FeatureVector, + mlrun.common.schemas.FeatureVector, project, name, feature_vector_patch, @@ -248,10 +248,10 @@ def get_feature_vector( name: str, tag: typing.Optional[str] = None, uid: typing.Optional[str] = None, - ) -> mlrun.api.schemas.FeatureVector: + ) -> mlrun.common.schemas.FeatureVector: return self._get_object( db_session, - mlrun.api.schemas.FeatureVector, + mlrun.common.schemas.FeatureVector, project, name, tag, @@ -267,7 +267,7 @@ def list_feature_vectors_tags( :return: a list of Tuple of (project, feature_vector.name, tag) """ return self._list_object_type_tags( - db_session, mlrun.api.schemas.FeatureVector, project + db_session, 
mlrun.common.schemas.FeatureVector, project ) def list_feature_vectors( @@ -278,11 +278,11 @@ def list_feature_vectors( tag: typing.Optional[str] = None, state: str = None, labels: typing.List[str] = None, - partition_by: mlrun.api.schemas.FeatureStorePartitionByField = None, + partition_by: mlrun.common.schemas.FeatureStorePartitionByField = None, rows_per_partition: int = 1, - partition_sort_by: mlrun.api.schemas.SortField = None, - partition_order: mlrun.api.schemas.OrderType = mlrun.api.schemas.OrderType.desc, - ) -> mlrun.api.schemas.FeatureVectorsOutput: + partition_sort_by: mlrun.common.schemas.SortField = None, + partition_order: mlrun.common.schemas.OrderType = mlrun.common.schemas.OrderType.desc, + ) -> mlrun.common.schemas.FeatureVectorsOutput: project = project or mlrun.mlconf.default_project return mlrun.api.utils.singletons.db.get_db().list_feature_vectors( db_session, @@ -307,7 +307,7 @@ def delete_feature_vector( ): self._delete_object( db_session, - mlrun.api.schemas.FeatureVector, + mlrun.common.schemas.FeatureVector, project, name, tag, @@ -319,17 +319,17 @@ def _create_object( db_session: sqlalchemy.orm.Session, project: str, object_: typing.Union[ - mlrun.api.schemas.FeatureSet, mlrun.api.schemas.FeatureVector + mlrun.common.schemas.FeatureSet, mlrun.common.schemas.FeatureVector ], versioned: bool = True, ) -> str: project = project or mlrun.mlconf.default_project self._validate_and_enrich_identity_for_object_creation(project, object_) - if isinstance(object_, mlrun.api.schemas.FeatureSet): + if isinstance(object_, mlrun.common.schemas.FeatureSet): return mlrun.api.utils.singletons.db.get_db().create_feature_set( db_session, project, object_, versioned ) - elif isinstance(object_, mlrun.api.schemas.FeatureVector): + elif isinstance(object_, mlrun.common.schemas.FeatureVector): return mlrun.api.utils.singletons.db.get_db().create_feature_vector( db_session, project, object_, versioned ) @@ -344,7 +344,7 @@ def _store_object( project: str, name: 
str, object_: typing.Union[ - mlrun.api.schemas.FeatureSet, mlrun.api.schemas.FeatureVector + mlrun.common.schemas.FeatureSet, mlrun.common.schemas.FeatureVector ], tag: typing.Optional[str] = None, uid: typing.Optional[str] = None, @@ -354,7 +354,7 @@ def _store_object( self._validate_and_enrich_identity_for_object_store( object_, project, name, tag, uid ) - if isinstance(object_, mlrun.api.schemas.FeatureSet): + if isinstance(object_, mlrun.common.schemas.FeatureSet): return mlrun.api.utils.singletons.db.get_db().store_feature_set( db_session, project, @@ -364,7 +364,7 @@ def _store_object( uid, versioned, ) - elif isinstance(object_, mlrun.api.schemas.FeatureVector): + elif isinstance(object_, mlrun.common.schemas.FeatureVector): return mlrun.api.utils.singletons.db.get_db().store_feature_vector( db_session, project, @@ -388,7 +388,7 @@ def _patch_object( object_patch: dict, tag: typing.Optional[str] = None, uid: typing.Optional[str] = None, - patch_mode: mlrun.api.schemas.PatchMode = mlrun.api.schemas.PatchMode.replace, + patch_mode: mlrun.common.schemas.PatchMode = mlrun.common.schemas.PatchMode.replace, ) -> str: project = project or mlrun.mlconf.default_project self._validate_identity_for_object_patch( @@ -399,7 +399,7 @@ def _patch_object( tag, uid, ) - if object_schema.__name__ == mlrun.api.schemas.FeatureSet.__name__: + if object_schema.__name__ == mlrun.common.schemas.FeatureSet.__name__: return mlrun.api.utils.singletons.db.get_db().patch_feature_set( db_session, project, @@ -409,7 +409,7 @@ def _patch_object( uid, patch_mode, ) - elif object_schema.__name__ == mlrun.api.schemas.FeatureVector.__name__: + elif object_schema.__name__ == mlrun.common.schemas.FeatureVector.__name__: return mlrun.api.utils.singletons.db.get_db().patch_feature_vector( db_session, project, @@ -432,13 +432,15 @@ def _get_object( name: str, tag: typing.Optional[str] = None, uid: typing.Optional[str] = None, - ) -> typing.Union[mlrun.api.schemas.FeatureSet, 
mlrun.api.schemas.FeatureVector]: + ) -> typing.Union[ + mlrun.common.schemas.FeatureSet, mlrun.common.schemas.FeatureVector + ]: project = project or mlrun.mlconf.default_project - if object_schema.__name__ == mlrun.api.schemas.FeatureSet.__name__: + if object_schema.__name__ == mlrun.common.schemas.FeatureSet.__name__: return mlrun.api.utils.singletons.db.get_db().get_feature_set( db_session, project, name, tag, uid ) - elif object_schema.__name__ == mlrun.api.schemas.FeatureVector.__name__: + elif object_schema.__name__ == mlrun.common.schemas.FeatureVector.__name__: return mlrun.api.utils.singletons.db.get_db().get_feature_vector( db_session, project, name, tag, uid ) @@ -454,11 +456,11 @@ def _list_object_type_tags( project: str, ) -> typing.List[typing.Tuple[str, str, str]]: project = project or mlrun.mlconf.default_project - if object_schema.__name__ == mlrun.api.schemas.FeatureSet.__name__: + if object_schema.__name__ == mlrun.common.schemas.FeatureSet.__name__: return mlrun.api.utils.singletons.db.get_db().list_feature_sets_tags( db_session, project ) - elif object_schema.__name__ == mlrun.api.schemas.FeatureVector.__name__: + elif object_schema.__name__ == mlrun.common.schemas.FeatureVector.__name__: return mlrun.api.utils.singletons.db.get_db().list_feature_vectors_tags( db_session, project ) @@ -477,11 +479,11 @@ def _delete_object( uid: typing.Optional[str] = None, ): project = project or mlrun.mlconf.default_project - if object_schema.__name__ == mlrun.api.schemas.FeatureSet.__name__: + if object_schema.__name__ == mlrun.common.schemas.FeatureSet.__name__: mlrun.api.utils.singletons.db.get_db().delete_feature_set( db_session, project, name, tag, uid ) - elif object_schema.__name__ == mlrun.api.schemas.FeatureVector.__name__: + elif object_schema.__name__ == mlrun.common.schemas.FeatureVector.__name__: mlrun.api.utils.singletons.db.get_db().delete_feature_vector( db_session, project, name, tag, uid ) @@ -519,7 +521,7 @@ def 
_validate_identity_for_object_patch( @staticmethod def _validate_and_enrich_identity_for_object_store( object_: typing.Union[ - mlrun.api.schemas.FeatureSet, mlrun.api.schemas.FeatureVector + mlrun.common.schemas.FeatureSet, mlrun.common.schemas.FeatureVector ], project: str, name: str, @@ -550,7 +552,7 @@ def _validate_and_enrich_identity_for_object_store( def _validate_and_enrich_identity_for_object_creation( project: str, object_: typing.Union[ - mlrun.api.schemas.FeatureSet, mlrun.api.schemas.FeatureVector + mlrun.common.schemas.FeatureSet, mlrun.common.schemas.FeatureVector ], ): object_type = object_.__class__.__name__ diff --git a/mlrun/api/crud/functions.py b/mlrun/api/crud/functions.py index ef08a48e08fa..1583cdd0261a 100644 --- a/mlrun/api/crud/functions.py +++ b/mlrun/api/crud/functions.py @@ -17,10 +17,10 @@ import sqlalchemy.orm import mlrun.api.api.utils -import mlrun.api.schemas import mlrun.api.utils.projects.remotes.follower import mlrun.api.utils.singletons.db import mlrun.api.utils.singletons.project_member +import mlrun.common.schemas import mlrun.config import mlrun.errors import mlrun.utils.singleton @@ -37,7 +37,7 @@ def store_function( project: str = mlrun.mlconf.default_project, tag: str = "", versioned: bool = False, - auth_info: mlrun.api.schemas.AuthInfo = None, + auth_info: mlrun.common.schemas.AuthInfo = None, ) -> str: project = project or mlrun.mlconf.default_project if auth_info: diff --git a/mlrun/api/crud/marketplace.py b/mlrun/api/crud/hub.py similarity index 67% rename from mlrun/api/crud/marketplace.py rename to mlrun/api/crud/hub.py index d521f667ccd7..042c30e7a8cb 100644 --- a/mlrun/api/crud/marketplace.py +++ b/mlrun/api/crud/hub.py @@ -13,37 +13,30 @@ # limitations under the License. 
# import json -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple +import mlrun.api.utils.singletons.k8s +import mlrun.common.schemas +import mlrun.common.schemas.hub import mlrun.errors import mlrun.utils.singleton -from mlrun.api.schemas.marketplace import ( - MarketplaceCatalog, - MarketplaceItem, - MarketplaceItemMetadata, - MarketplaceItemSpec, - MarketplaceSource, - ObjectStatus, -) -from mlrun.api.utils.singletons.k8s import get_k8s from mlrun.config import config from mlrun.datastore import store_manager -from ..schemas import SecretProviderName from .secrets import Secrets, SecretsClientType # Using a complex separator, as it's less likely someone will use it in a real secret name secret_name_separator = "-__-" -class Marketplace(metaclass=mlrun.utils.singleton.Singleton): +class Hub(metaclass=mlrun.utils.singleton.Singleton): def __init__(self): - self._internal_project_name = config.marketplace.k8s_secrets_project_name + self._internal_project_name = config.hub.k8s_secrets_project_name self._catalogs = {} @staticmethod def _in_k8s(): - k8s_helper = get_k8s() + k8s_helper = mlrun.api.utils.singletons.k8s.get_k8s_helper() return ( k8s_helper is not None and k8s_helper.is_running_inside_kubernetes_cluster() ) @@ -52,10 +45,10 @@ def _in_k8s(): def _generate_credentials_secret_key(source, key=""): full_key = source + secret_name_separator + key return Secrets().generate_client_project_secret_key( - SecretsClientType.marketplace, full_key + SecretsClientType.hub, full_key ) - def add_source(self, source: MarketplaceSource): + def add_source(self, source: mlrun.common.schemas.hub.HubSource): source_name = source.metadata.name credentials = source.spec.credentials if credentials: @@ -75,7 +68,7 @@ def remove_source(self, source_name): ] Secrets().delete_project_secrets( self._internal_project_name, - SecretProviderName.kubernetes, + mlrun.common.schemas.SecretProviderName.kubernetes, secrets_to_delete, 
allow_internal_secrets=True, ) @@ -83,7 +76,7 @@ def remove_source(self, source_name): def _store_source_credentials(self, source_name, credentials: dict): if not self._in_k8s(): raise mlrun.errors.MLRunInvalidArgumentError( - "MLRun is not configured with k8s, marketplace source credentials cannot be stored securely" + "MLRun is not configured with k8s, hub source credentials cannot be stored securely" ) adjusted_credentials = { @@ -92,8 +85,9 @@ def _store_source_credentials(self, source_name, credentials: dict): } Secrets().store_project_secrets( self._internal_project_name, - mlrun.api.schemas.SecretsData( - provider=SecretProviderName.kubernetes, secrets=adjusted_credentials + mlrun.common.schemas.SecretsData( + provider=mlrun.common.schemas.SecretProviderName.kubernetes, + secrets=adjusted_credentials, ), allow_internal_secrets=True, ) @@ -107,7 +101,7 @@ def _get_source_credentials(self, source_name): Secrets() .list_project_secrets( self._internal_project_name, - SecretProviderName.kubernetes, + mlrun.common.schemas.SecretProviderName.kubernetes, allow_secrets_from_k8s=True, allow_internal_secrets=True, ) @@ -121,33 +115,61 @@ def _get_source_credentials(self, source_name): return source_secrets + @staticmethod + def _get_asset_full_path( + source: mlrun.common.schemas.hub.HubSource, + item: mlrun.common.schemas.hub.HubItem, + asset: str, + ): + """ + Combining the item path with the asset path. + + :param source: Hub source object. + :param item: The relevant item to get the asset from. + :param asset: The asset name + :return: Full path to the asset, relative to the item directory. + """ + asset_path = item.spec.assets.get(asset, None) + if not asset_path: + raise mlrun.errors.MLRunNotFoundError( + f"Asset={asset} not found. 
" + f"item={item.metadata.name}, version={item.metadata.version}, tag={item.metadata.tag}" + ) + item_path = item.metadata.get_relative_path() + return source.get_full_uri(item_path + asset_path) + @staticmethod def _transform_catalog_dict_to_schema( - source: MarketplaceSource, catalog_dict: Dict[str, Any] - ) -> MarketplaceCatalog: + source: mlrun.common.schemas.hub.HubSource, catalog_dict: Dict[str, Any] + ) -> mlrun.common.schemas.hub.HubCatalog: """ - Transforms catalog dictionary to MarketplaceCatalog schema - :param source: Marketplace source object. + Transforms catalog dictionary to HubCatalog schema + :param source: Hub source object. :param catalog_dict: raw catalog dict, top level keys are item names, second level keys are version tags ("latest, "1.1.0", ...) and bottom level keys include spec as a dict and all the rest is considered as metadata. :return: catalog object """ - catalog = MarketplaceCatalog(catalog=[], channel=source.spec.channel) + catalog = mlrun.common.schemas.hub.HubCatalog( + catalog=[], channel=source.spec.channel + ) # Loop over objects, then over object versions. 
for object_name, object_dict in catalog_dict.items(): for version_tag, version_dict in object_dict.items(): object_details_dict = version_dict.copy() spec_dict = object_details_dict.pop("spec", {}) - metadata = MarketplaceItemMetadata( + assets = object_details_dict.pop("assets", {}) + metadata = mlrun.common.schemas.hub.HubItemMetadata( tag=version_tag, **object_details_dict ) item_uri = source.get_full_uri(metadata.get_relative_path()) - spec = MarketplaceItemSpec(item_uri=item_uri, **spec_dict) - item = MarketplaceItem( + spec = mlrun.common.schemas.hub.HubItemSpec( + item_uri=item_uri, assets=assets, **spec_dict + ) + item = mlrun.common.schemas.hub.HubItem( metadata=metadata, spec=spec, - status=ObjectStatus(), + status=mlrun.common.schemas.ObjectStatus(), ) catalog.catalog.append(item) @@ -155,16 +177,16 @@ def _transform_catalog_dict_to_schema( def get_source_catalog( self, - source: MarketplaceSource, + source: mlrun.common.schemas.hub.HubSource, version: Optional[str] = None, tag: Optional[str] = None, force_refresh: bool = False, - ) -> MarketplaceCatalog: + ) -> mlrun.common.schemas.hub.HubCatalog: """ Getting the catalog object by source. If version and/or tag are given, the catalog will be filtered accordingly. - :param source: Marketplace source object. + :param source: Hub source object. 
:param version: version of items to filter by :param tag: tag of items to filter by :param force_refresh: if True, the catalog will be loaded from source always, @@ -182,7 +204,9 @@ def get_source_catalog( else: catalog = self._catalogs[source_name] - result_catalog = MarketplaceCatalog(catalog=[], channel=source.spec.channel) + result_catalog = mlrun.common.schemas.hub.HubCatalog( + catalog=[], channel=source.spec.channel + ) for item in catalog.catalog: # Because tag and version are optionals, # we filter the catalog by one of them with priority to tag @@ -195,23 +219,23 @@ def get_source_catalog( def get_item( self, - source: MarketplaceSource, + source: mlrun.common.schemas.hub.HubSource, item_name: str, version: Optional[str] = None, tag: Optional[str] = None, force_refresh: bool = False, - ) -> MarketplaceItem: + ) -> mlrun.common.schemas.hub.HubItem: """ Retrieve item from source. The item is filtered by tag and version. - :param source: Marketplace source object + :param source: Hub source object :param item_name: name of the item to retrieve :param version: version of the item :param tag: tag of the item :param force_refresh: if True, the catalog will be loaded from source always, otherwise will be pulled from db (if loaded before) - :return: marketplace item object + :return: hub item object :raise if the number of collected items from catalog is not exactly one. 
""" @@ -232,9 +256,9 @@ def get_item( @staticmethod def _get_catalog_items_filtered_by_name( - catalog: List[MarketplaceItem], + catalog: List[mlrun.common.schemas.hub.HubItem], item_name: str, - ) -> List[MarketplaceItem]: + ) -> List[mlrun.common.schemas.hub.HubItem]: """ Retrieve items from catalog filtered by name @@ -245,7 +269,9 @@ def _get_catalog_items_filtered_by_name( """ return [item for item in catalog if item.metadata.name == item_name] - def get_item_object_using_source_credentials(self, source: MarketplaceSource, url): + def get_item_object_using_source_credentials( + self, source: mlrun.common.schemas.hub.HubSource, url + ): credentials = self._get_source_credentials(source.metadata.name) if not url.startswith(source.spec.path): @@ -262,3 +288,25 @@ def get_item_object_using_source_credentials(self, source: MarketplaceSource, ur else: catalog_data = mlrun.run.get_object(url=url, secrets=credentials) return catalog_data + + def get_asset( + self, + source: mlrun.common.schemas.hub.HubSource, + item: mlrun.common.schemas.hub.HubItem, + asset_name: str, + ) -> Tuple[bytes, str]: + """ + Retrieve asset object from hub source. + + :param source: hub source + :param item: hub item which contains the assets + :param asset_name: asset name, like source, example, etc. 
+ + :return: tuple of asset as bytes and url of asset + """ + credentials = self._get_source_credentials(source.metadata.name) + asset_path = self._get_asset_full_path(source, item, asset_name) + return ( + mlrun.run.get_object(url=asset_path, secrets=credentials), + asset_path, + ) diff --git a/mlrun/api/crud/logs.py b/mlrun/api/crud/logs.py index aaaf2f07e31c..20a9a3e19359 100644 --- a/mlrun/api/crud/logs.py +++ b/mlrun/api/crud/logs.py @@ -21,13 +21,13 @@ from fastapi.concurrency import run_in_threadpool from sqlalchemy.orm import Session -import mlrun.api.schemas import mlrun.api.utils.clients.log_collector as log_collector +import mlrun.api.utils.singletons.k8s +import mlrun.common.schemas import mlrun.utils.singleton from mlrun.api.api.utils import log_and_raise, log_path, project_logs_path from mlrun.api.constants import LogSources from mlrun.api.utils.singletons.db import get_db -from mlrun.api.utils.singletons.k8s import get_k8s from mlrun.runtimes.constants import PodPhases from mlrun.utils import logger @@ -85,7 +85,7 @@ async def get_logs( log_stream = None if ( mlrun.mlconf.log_collector.mode - == mlrun.api.schemas.LogsCollectorMode.best_effort + == mlrun.common.schemas.LogsCollectorMode.best_effort and source == LogSources.AUTO ): try: @@ -112,7 +112,7 @@ async def get_logs( ) elif ( mlrun.mlconf.log_collector.mode - == mlrun.api.schemas.LogsCollectorMode.sidecar + == mlrun.common.schemas.LogsCollectorMode.sidecar and source == LogSources.AUTO ): log_stream = self._get_logs_from_logs_collector( @@ -123,7 +123,7 @@ async def get_logs( ) elif ( mlrun.mlconf.log_collector.mode - == mlrun.api.schemas.LogsCollectorMode.legacy + == mlrun.common.schemas.LogsCollectorMode.legacy or source != LogSources.AUTO ): log_stream = self._get_logs_legacy_method_generator_wrapper( @@ -178,10 +178,12 @@ def _get_logs_legacy_method( fp.seek(offset) log_contents = fp.read(size) elif source in [LogSources.AUTO, LogSources.K8S]: - k8s = get_k8s() + k8s = 
mlrun.api.utils.singletons.k8s.get_k8s_helper() if k8s and k8s.is_running_inside_kubernetes_cluster(): run_kind = run.get("metadata", {}).get("labels", {}).get("kind") - pods = get_k8s().get_logger_pods(project, uid, run_kind) + pods = mlrun.api.utils.singletons.k8s.get_k8s_helper().get_logger_pods( + project, uid, run_kind + ) if pods: if len(pods) > 1: @@ -195,7 +197,7 @@ def _get_logs_legacy_method( ) pod, pod_phase = list(pods.items())[0] if pod_phase != PodPhases.pending: - resp = get_k8s().logs(pod) + resp = mlrun.api.utils.singletons.k8s.get_k8s_helper().logs(pod) if resp: if size == -1: log_contents = resp.encode()[offset:] @@ -242,10 +244,7 @@ def get_log_mtime(self, project: str, uid: str) -> int: def log_file_exists_for_run_uid(project: str, uid: str) -> (bool, pathlib.Path): """ Checks if the log file exists for the given project and uid - There could be two types of log files: - 1. Log file which was created by the legacy logger with the following file format - project/) - 2. 
Log file which was created by the new logger with the following file format- /project/- - Therefore, we check if the log file exists for both formats + A Run's log file path is: /mlrun/logs/{project}/{uid} :param project: project name :param uid: run uid :return: True if the log file exists, False otherwise, and the log file path @@ -253,9 +252,10 @@ def log_file_exists_for_run_uid(project: str, uid: str) -> (bool, pathlib.Path): project_logs_dir = project_logs_path(project) if not project_logs_dir.exists(): return False, None - for file in os.listdir(str(project_logs_dir)): - if file.startswith(uid): - return True, project_logs_dir / file + + log_file = log_path(project, uid) + if log_file.exists(): + return True, log_file return False, None diff --git a/mlrun/api/crud/model_monitoring/__init__.py b/mlrun/api/crud/model_monitoring/__init__.py index 62b0bf17478b..11c0e215715a 100644 --- a/mlrun/api/crud/model_monitoring/__init__.py +++ b/mlrun/api/crud/model_monitoring/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # +# flake8: noqa: F401 - this is until we take care of the F401 violations with respect to __all__ & sphinx -from .model_endpoint_store import ModelEndpointStoreType # noqa: F401 -from .model_endpoints import ModelEndpoints # noqa: F401 +from .model_endpoints import ModelEndpoints diff --git a/mlrun/api/crud/model_monitoring/grafana.py b/mlrun/api/crud/model_monitoring/grafana.py new file mode 100644 index 000000000000..bcb2c7e94019 --- /dev/null +++ b/mlrun/api/crud/model_monitoring/grafana.py @@ -0,0 +1,427 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, Dict, List, Optional, Set + +import numpy as np +import pandas as pd +from fastapi.concurrency import run_in_threadpool +from sqlalchemy.orm import Session + +import mlrun.api.crud +import mlrun.api.utils.auth.verifier +import mlrun.common.model_monitoring +import mlrun.common.schemas +from mlrun.api.utils.singletons.project_member import get_project_member +from mlrun.errors import MLRunBadRequestError +from mlrun.utils import config, logger +from mlrun.utils.model_monitoring import parse_model_endpoint_store_prefix +from mlrun.utils.v3io_clients import get_frames_client + + +def grafana_list_projects( + db_session: Session, + auth_info: mlrun.common.schemas.AuthInfo, + query_parameters: Dict[str, str], +) -> List[str]: + """ + List available project names. Will be used as a filter in each grafana dashboard. + + :param db_session: A session that manages the current dialog with the database. + :param auth_info: The auth info of the request. + :param query_parameters: Dictionary of query parameters attached to the request. Note that this parameter is + required by the API even though it is not being used in this function. + + :return: List of available project names. 
+ """ + + projects_output = get_project_member().list_projects( + db_session, + format_=mlrun.common.schemas.ProjectsFormat.name_only, + leader_session=auth_info.session, + ) + return projects_output.projects + + +# TODO: remove in 1.5.0 the following functions: grafana_list_endpoints, grafana_individual_feature_analysis, +# grafana_overall_feature_analysis, grafana_income_features, parse_query_parameters, drop_grafana_escape_chars, + + +async def grafana_list_endpoints( + body: Dict[str, Any], + query_parameters: Dict[str, str], + auth_info: mlrun.common.schemas.AuthInfo, +) -> List[mlrun.common.schemas.GrafanaTable]: + project = query_parameters.get("project") + + # Filters + model = query_parameters.get("model", None) + function = query_parameters.get("function", None) + labels = query_parameters.get("labels", "") + labels = labels.split(",") if labels else [] + + # Metrics to include + metrics = query_parameters.get("metrics", "") + metrics = metrics.split(",") if metrics else [] + + # Time range for metrics + start = body.get("rangeRaw", {}).get("start", "now-1h") + end = body.get("rangeRaw", {}).get("end", "now") + + if project: + await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions( + project, + mlrun.common.schemas.AuthorizationAction.read, + auth_info, + ) + endpoint_list = await run_in_threadpool( + mlrun.api.crud.ModelEndpoints().list_model_endpoints, + auth_info=auth_info, + project=project, + model=model, + function=function, + labels=labels, + metrics=metrics, + start=start, + end=end, + ) + allowed_endpoints = await mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions( + mlrun.common.schemas.AuthorizationResourceTypes.model_endpoint, + endpoint_list.endpoints, + lambda _endpoint: ( + _endpoint.metadata.project, + _endpoint.metadata.uid, + ), + auth_info, + ) + endpoint_list.endpoints = allowed_endpoints + + columns = [ + mlrun.common.schemas.GrafanaColumn(text="endpoint_id", type="string"), + 
mlrun.common.schemas.GrafanaColumn(text="endpoint_function", type="string"), + mlrun.common.schemas.GrafanaColumn(text="endpoint_model", type="string"), + mlrun.common.schemas.GrafanaColumn(text="endpoint_model_class", type="string"), + mlrun.common.schemas.GrafanaColumn(text="first_request", type="time"), + mlrun.common.schemas.GrafanaColumn(text="last_request", type="time"), + mlrun.common.schemas.GrafanaColumn(text="accuracy", type="number"), + mlrun.common.schemas.GrafanaColumn(text="error_count", type="number"), + mlrun.common.schemas.GrafanaColumn(text="drift_status", type="number"), + mlrun.common.schemas.GrafanaColumn( + text="predictions_per_second", type="number" + ), + mlrun.common.schemas.GrafanaColumn(text="latency_avg_1h", type="number"), + ] + + table = mlrun.common.schemas.GrafanaTable(columns=columns) + for endpoint in endpoint_list.endpoints: + row = [ + endpoint.metadata.uid, + endpoint.spec.function_uri, + endpoint.spec.model, + endpoint.spec.model_class, + endpoint.status.first_request, + endpoint.status.last_request, + "N/A", # Leaving here for backwards compatibility + endpoint.status.error_count, + endpoint.status.drift_status, + ] + + if ( + endpoint.status.metrics + and mlrun.common.model_monitoring.EventKeyMetrics.GENERIC + in endpoint.status.metrics + ): + row.extend( + [ + endpoint.status.metrics[ + mlrun.common.model_monitoring.EventKeyMetrics.GENERIC + ][ + mlrun.common.model_monitoring.EventLiveStats.PREDICTIONS_PER_SECOND + ], + endpoint.status.metrics[ + mlrun.common.model_monitoring.EventKeyMetrics.GENERIC + ][mlrun.common.model_monitoring.EventLiveStats.LATENCY_AVG_1H], + ] + ) + + table.add_row(*row) + + return [table] + + +async def grafana_individual_feature_analysis( + body: Dict[str, Any], + query_parameters: Dict[str, str], + auth_info: mlrun.common.schemas.AuthInfo, +): + endpoint_id = query_parameters.get("endpoint_id") + project = query_parameters.get("project") + await 
mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( + mlrun.common.schemas.AuthorizationResourceTypes.model_endpoint, + project, + endpoint_id, + mlrun.common.schemas.AuthorizationAction.read, + auth_info, + ) + + endpoint = await run_in_threadpool( + mlrun.api.crud.ModelEndpoints().get_model_endpoint, + auth_info=auth_info, + project=project, + endpoint_id=endpoint_id, + feature_analysis=True, + ) + + # Load JSON data from KV, make sure not to fail if a field is missing + feature_stats = endpoint.status.feature_stats or {} + current_stats = endpoint.status.current_stats or {} + drift_measures = endpoint.status.drift_measures or {} + + table = mlrun.common.schemas.GrafanaTable( + columns=[ + mlrun.common.schemas.GrafanaColumn(text="feature_name", type="string"), + mlrun.common.schemas.GrafanaColumn(text="actual_min", type="number"), + mlrun.common.schemas.GrafanaColumn(text="actual_mean", type="number"), + mlrun.common.schemas.GrafanaColumn(text="actual_max", type="number"), + mlrun.common.schemas.GrafanaColumn(text="expected_min", type="number"), + mlrun.common.schemas.GrafanaColumn(text="expected_mean", type="number"), + mlrun.common.schemas.GrafanaColumn(text="expected_max", type="number"), + mlrun.common.schemas.GrafanaColumn(text="tvd", type="number"), + mlrun.common.schemas.GrafanaColumn(text="hellinger", type="number"), + mlrun.common.schemas.GrafanaColumn(text="kld", type="number"), + ] + ) + + for feature, base_stat in feature_stats.items(): + current_stat = current_stats.get(feature, {}) + drift_measure = drift_measures.get(feature, {}) + + table.add_row( + feature, + current_stat.get("min"), + current_stat.get("mean"), + current_stat.get("max"), + base_stat.get("min"), + base_stat.get("mean"), + base_stat.get("max"), + drift_measure.get("tvd"), + drift_measure.get("hellinger"), + drift_measure.get("kld"), + ) + + return [table] + + +async def grafana_overall_feature_analysis( + body: Dict[str, Any], + query_parameters: 
Dict[str, str], + auth_info: mlrun.common.schemas.AuthInfo, +): + endpoint_id = query_parameters.get("endpoint_id") + project = query_parameters.get("project") + await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( + mlrun.common.schemas.AuthorizationResourceTypes.model_endpoint, + project, + endpoint_id, + mlrun.common.schemas.AuthorizationAction.read, + auth_info, + ) + endpoint = await run_in_threadpool( + mlrun.api.crud.ModelEndpoints().get_model_endpoint, + auth_info=auth_info, + project=project, + endpoint_id=endpoint_id, + feature_analysis=True, + ) + + table = mlrun.common.schemas.GrafanaTable( + columns=[ + mlrun.common.schemas.GrafanaNumberColumn(text="tvd_sum"), + mlrun.common.schemas.GrafanaNumberColumn(text="tvd_mean"), + mlrun.common.schemas.GrafanaNumberColumn(text="hellinger_sum"), + mlrun.common.schemas.GrafanaNumberColumn(text="hellinger_mean"), + mlrun.common.schemas.GrafanaNumberColumn(text="kld_sum"), + mlrun.common.schemas.GrafanaNumberColumn(text="kld_mean"), + ] + ) + + if endpoint.status.drift_measures: + table.add_row( + endpoint.status.drift_measures.get("tvd_sum"), + endpoint.status.drift_measures.get("tvd_mean"), + endpoint.status.drift_measures.get("hellinger_sum"), + endpoint.status.drift_measures.get("hellinger_mean"), + endpoint.status.drift_measures.get("kld_sum"), + endpoint.status.drift_measures.get("kld_mean"), + ) + + return [table] + + +async def grafana_incoming_features( + body: Dict[str, Any], + query_parameters: Dict[str, str], + auth_info: mlrun.common.schemas.AuthInfo, +): + endpoint_id = query_parameters.get("endpoint_id") + project = query_parameters.get("project") + start = body.get("rangeRaw", {}).get("from", "now-1h") + end = body.get("rangeRaw", {}).get("to", "now") + + await mlrun.api.utils.auth.verifier.AuthVerifier().query_project_resource_permissions( + mlrun.common.schemas.AuthorizationResourceTypes.model_endpoint, + project, + endpoint_id, + 
mlrun.common.schemas.AuthorizationAction.read, + auth_info, + ) + + endpoint = await run_in_threadpool( + mlrun.api.crud.ModelEndpoints().get_model_endpoint, + auth_info=auth_info, + project=project, + endpoint_id=endpoint_id, + ) + + time_series = [] + + feature_names = endpoint.spec.feature_names + + if not feature_names: + logger.warn( + "'feature_names' is either missing or not initialized in endpoint record", + endpoint_id=endpoint.metadata.uid, + ) + return time_series + + path = config.model_endpoint_monitoring.store_prefixes.default.format( + project=project, kind=mlrun.common.schemas.ModelMonitoringStoreKinds.EVENTS + ) + _, container, path = parse_model_endpoint_store_prefix(path) + + client = get_frames_client( + token=auth_info.data_session, + address=config.v3io_framesd, + container=container, + ) + + data: pd.DataFrame = await run_in_threadpool( + client.read, + backend="tsdb", + table=path, + columns=feature_names, + filter=f"endpoint_id=='{endpoint_id}'", + start=start, + end=end, + ) + + data.drop(["endpoint_id"], axis=1, inplace=True, errors="ignore") + data.index = data.index.astype(np.int64) // 10**6 + + for feature, indexed_values in data.to_dict().items(): + target = mlrun.common.schemas.GrafanaTimeSeriesTarget(target=feature) + for index, value in indexed_values.items(): + data_point = mlrun.common.schemas.GrafanaDataPoint( + value=float(value), timestamp=index + ) + target.add_data_point(data_point) + time_series.append(target) + + return time_series + + +def parse_query_parameters(request_body: Dict[str, Any]) -> Dict[str, str]: + """ + This function searches for the target field in Grafana's SimpleJson json. Once located, the target string is + parsed by splitting on semi-colons (;). Each part in the resulting list is then split by an equal sign (=) to be + read as key-value pairs. 
+ """ + + # Try to get the target + targets = request_body.get("targets", []) + + if len(targets) > 1: + logger.warn( + f"The 'targets' list contains more than one element ({len(targets)}), all targets except the first one are " + f"ignored." + ) + + target_obj = targets[0] if targets else {} + target_query = target_obj.get("target") if target_obj else "" + + if not target_query: + raise MLRunBadRequestError(f"Target missing in request body:\n {request_body}") + + parameters = _parse_parameters(target_query) + + return parameters + + +def parse_search_parameters(request_body: Dict[str, Any]) -> Dict[str, str]: + """ + This function searches for the target field in Grafana's SimpleJson json. Once located, the target string is + parsed by splitting on semi-colons (;). Each part in the resulting list is then split by an equal sign (=) to be + read as key-value pairs. + """ + + # Try to get the target + target = request_body.get("target") + + if not target: + raise MLRunBadRequestError(f"Target missing in request body:\n {request_body}") + + parameters = _parse_parameters(target) + + return parameters + + +def _parse_parameters(target_query): + parameters = {} + for query in filter(lambda q: q, target_query.split(";")): + query_parts = query.split("=") + if len(query_parts) < 2: + raise MLRunBadRequestError( + f"Query must contain both query key and query value. Expected query_key=query_value, found {query} " + f"instead." 
+ ) + parameters[query_parts[0]] = query_parts[1] + return parameters + + +def drop_grafana_escape_chars(query_parameters: Dict[str, str]): + query_parameters = dict(query_parameters) + endpoint_id = query_parameters.get("endpoint_id") + if endpoint_id is not None: + query_parameters["endpoint_id"] = endpoint_id.replace("\\", "") + return query_parameters + + +def validate_query_parameters( + query_parameters: Dict[str, str], supported_endpoints: Optional[Set[str]] = None +): + """Validates the parameters sent via Grafana's SimpleJson query""" + if "target_endpoint" not in query_parameters: + raise MLRunBadRequestError( + f"Expected 'target_endpoint' field in query, found {query_parameters} instead" + ) + + if ( + supported_endpoints is not None + and query_parameters["target_endpoint"] not in supported_endpoints + ): + raise MLRunBadRequestError( + f"{query_parameters['target_endpoint']} unsupported in query parameters: {query_parameters}. " + f"Currently supports: {','.join(supported_endpoints)}" + ) diff --git a/mlrun/api/crud/model_monitoring/model_endpoint_store.py b/mlrun/api/crud/model_monitoring/model_endpoint_store.py deleted file mode 100644 index 54f6e30718f6..000000000000 --- a/mlrun/api/crud/model_monitoring/model_endpoint_store.py +++ /dev/null @@ -1,847 +0,0 @@ -# Copyright 2018 Iguazio -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -import enum -import json -import typing -from abc import ABC, abstractmethod - -import v3io.dataplane -import v3io_frames - -import mlrun -import mlrun.api.schemas -import mlrun.model_monitoring.constants as model_monitoring_constants -import mlrun.utils.model_monitoring -import mlrun.utils.v3io_clients -from mlrun.utils import logger - - -class _ModelEndpointStore(ABC): - """ - An abstract class to handle the model endpoint in the DB target. - """ - - def __init__(self, project: str): - """ - Initialize a new model endpoint target. - - :param project: The name of the project. - """ - self.project = project - - @abstractmethod - def write_model_endpoint(self, endpoint: mlrun.api.schemas.ModelEndpoint): - """ - Create a new endpoint record in the DB table. - - :param endpoint: ModelEndpoint object that will be written into the DB. - """ - pass - - @abstractmethod - def update_model_endpoint(self, endpoint_id: str, attributes: dict): - """ - Update a model endpoint record with a given attributes. - - :param endpoint_id: The unique id of the model endpoint. - :param attributes: Dictionary of attributes that will be used for update the model endpoint. Note that the keys - of the attributes dictionary should exist in the KV table. - - """ - pass - - @abstractmethod - def delete_model_endpoint(self, endpoint_id: str): - """ - Deletes the record of a given model endpoint id. - - :param endpoint_id: The unique id of the model endpoint. - """ - pass - - @abstractmethod - def delete_model_endpoints_resources( - self, endpoints: mlrun.api.schemas.model_endpoints.ModelEndpointList - ): - """ - Delete all model endpoints resources. - - :param endpoints: An object of ModelEndpointList which is literally a list of model endpoints along with some - metadata. To get a standard list of model endpoints use ModelEndpointList.endpoints. 
- """ - pass - - @abstractmethod - def get_model_endpoint( - self, - metrics: typing.List[str] = None, - start: str = "now-1h", - end: str = "now", - feature_analysis: bool = False, - endpoint_id: str = None, - ) -> mlrun.api.schemas.ModelEndpoint: - """ - Get a single model endpoint object. You can apply different time series metrics that will be added to the - result. - - :param endpoint_id: The unique id of the model endpoint. - :param start: The start time of the metrics. Can be represented by a string containing an RFC 3339 - time, a Unix timestamp in milliseconds, a relative time (`'now'` or - `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the - earliest time. - :param end: The end time of the metrics. Can be represented by a string containing an RFC 3339 - time, a Unix timestamp in milliseconds, a relative time (`'now'` or - `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the - earliest time. - :param metrics: A list of metrics to return for the model endpoint. There are pre-defined metrics for - model endpoints such as predictions_per_second and latency_avg_5m but also custom - metrics defined by the user. Please note that these metrics are stored in the time - series DB and the results will be appeared under model_endpoint.spec.metrics. - :param feature_analysis: When True, the base feature statistics and current feature statistics will be added to - the output of the resulting object. - - :return: A ModelEndpoint object. - """ - pass - - @abstractmethod - def list_model_endpoints( - self, model: str, function: str, labels: typing.List, top_level: bool - ): - """ - Returns a list of endpoint unique ids, supports filtering by model, function, - labels or top level. By default, when no filters are applied, all available endpoint ids for the given project - will be listed. - - :param model: The name of the model to filter by. - :param function: The name of the function to filter by. 
- :param labels: A list of labels to filter by. Label filters work by either filtering a specific value - of a label (i.e. list("key==value")) or by looking for the existence of a given - key (i.e. "key"). - :param top_level: If True will return only routers and endpoint that are NOT children of any router. - - :return: List of model endpoints unique ids. - """ - pass - - -class _ModelEndpointKVStore(_ModelEndpointStore): - """ - Handles the DB operations when the DB target is from type KV. For the KV operations, we use an instance of V3IO - client and usually the KV table can be found under v3io:///users/pipelines/project-name/model-endpoints/endpoints/. - """ - - def __init__(self, access_key: str, project: str): - super().__init__(project=project) - # Initialize a V3IO client instance - self.access_key = access_key - self.client = mlrun.utils.v3io_clients.get_v3io_client( - endpoint=mlrun.mlconf.v3io_api, access_key=self.access_key - ) - # Get the KV table path and container - self.path, self.container = self._get_path_and_container() - - def write_model_endpoint(self, endpoint: mlrun.api.schemas.ModelEndpoint): - """ - Create a new endpoint record in the KV table. - - :param endpoint: ModelEndpoint object that will be written into the DB. - """ - - # Flatten the model endpoint structure in order to write it into the DB table. - # More details about the model endpoint available attributes can be found under - # :py:class:`~mlrun.api.schemas.ModelEndpoint`.` - attributes = self.flatten_model_endpoint_attributes(endpoint) - - # Create or update the model endpoint record - self.client.kv.put( - container=self.container, - table_path=self.path, - key=endpoint.metadata.uid, - attributes=attributes, - ) - - def update_model_endpoint(self, endpoint_id: str, attributes: dict): - """ - Update a model endpoint record with a given attributes. - - :param endpoint_id: The unique id of the model endpoint. 
- :param attributes: Dictionary of attributes that will be used for update the model endpoint. Note that the keys - of the attributes dictionary should exist in the KV table. More details about the model - endpoint available attributes can be found under - :py:class:`~mlrun.api.schemas.ModelEndpoint`. - - """ - - self.client.kv.update( - container=self.container, - table_path=self.path, - key=endpoint_id, - attributes=attributes, - ) - - logger.info("Model endpoint table updated", endpoint_id=endpoint_id) - - def delete_model_endpoint( - self, - endpoint_id: str, - ): - """ - Deletes the KV record of a given model endpoint id. - - :param endpoint_id: The unique id of the model endpoint. - """ - - self.client.kv.delete( - container=self.container, - table_path=self.path, - key=endpoint_id, - ) - - logger.info("Model endpoint table cleared", endpoint_id=endpoint_id) - - def get_model_endpoint( - self, - endpoint_id: str = None, - start: str = "now-1h", - end: str = "now", - metrics: typing.List[str] = None, - feature_analysis: bool = False, - ) -> mlrun.api.schemas.ModelEndpoint: - """ - Get a single model endpoint object. You can apply different time series metrics that will be added to the - result. - - :param endpoint_id: The unique id of the model endpoint. - :param start: The start time of the metrics. Can be represented by a string containing an RFC 3339 - time, a Unix timestamp in milliseconds, a relative time (`'now'` or - `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the - earliest time. - :param end: The end time of the metrics. Can be represented by a string containing an RFC 3339 - time, a Unix timestamp in milliseconds, a relative time (`'now'` or - `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the - earliest time. - :param metrics: A list of metrics to return for the model endpoint. 
There are pre-defined metrics for - model endpoints such as predictions_per_second and latency_avg_5m but also custom - metrics defined by the user. Please note that these metrics are stored in the time - series DB and the results will be appeared under model_endpoint.spec.metrics. - :param feature_analysis: When True, the base feature statistics and current feature statistics will be added to - the output of the resulting object. - - :return: A ModelEndpoint object. - """ - logger.info( - "Getting model endpoint record from kv", - endpoint_id=endpoint_id, - ) - - # Getting the raw data from the KV table - endpoint = self.client.kv.get( - container=self.container, - table_path=self.path, - key=endpoint_id, - raise_for_status=v3io.dataplane.RaiseForStatus.never, - access_key=self.access_key, - ) - endpoint = endpoint.output.item - - if not endpoint: - raise mlrun.errors.MLRunNotFoundError(f"Endpoint {endpoint_id} not found") - - # Generate a model endpoint object from the model endpoint KV record - endpoint_obj = self._convert_into_model_endpoint_object( - endpoint, start, end, metrics, feature_analysis - ) - - return endpoint_obj - - def _convert_into_model_endpoint_object( - self, endpoint, start, end, metrics, feature_analysis - ): - """ - Create a ModelEndpoint object according to a provided endpoint record from the DB. - - :param endpoint: KV record of model endpoint which need to be converted into a valid ModelEndpoint - object. - :param start: The start time of the metrics. Can be represented by a string containing an RFC 3339 - time, a Unix timestamp in milliseconds, a relative time (`'now'` or - `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the - earliest time. - :param end: The end time of the metrics. 
Can be represented by a string containing an RFC 3339 - time, a Unix timestamp in milliseconds, a relative time (`'now'` or - `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the - earliest time. - :param metrics: A list of metrics to return for the model endpoint. There are pre-defined metrics for - model endpoints such as predictions_per_second and latency_avg_5m but also custom - metrics defined by the user. Please note that these metrics are stored in the time - series DB and the results will be appeared under model_endpoint.spec.metrics. - :param feature_analysis: When True, the base feature statistics and current feature statistics will be added to - the output of the resulting object. - - :return: A ModelEndpoint object. - """ - - # Parse JSON values into a dictionary - feature_names = self._json_loads_if_not_none(endpoint.get("feature_names")) - label_names = self._json_loads_if_not_none(endpoint.get("label_names")) - feature_stats = self._json_loads_if_not_none(endpoint.get("feature_stats")) - current_stats = self._json_loads_if_not_none(endpoint.get("current_stats")) - children = self._json_loads_if_not_none(endpoint.get("children")) - monitor_configuration = self._json_loads_if_not_none( - endpoint.get("monitor_configuration") - ) - endpoint_type = self._json_loads_if_not_none(endpoint.get("endpoint_type")) - children_uids = self._json_loads_if_not_none(endpoint.get("children_uids")) - labels = self._json_loads_if_not_none(endpoint.get("labels")) - - # Convert into model endpoint object - endpoint_obj = mlrun.api.schemas.ModelEndpoint( - metadata=mlrun.api.schemas.ModelEndpointMetadata( - project=endpoint.get("project"), - labels=labels, - uid=endpoint.get("endpoint_id"), - ), - spec=mlrun.api.schemas.ModelEndpointSpec( - function_uri=endpoint.get("function_uri"), - model=endpoint.get("model"), - model_class=endpoint.get("model_class"), - model_uri=endpoint.get("model_uri"), - feature_names=feature_names or None, - 
label_names=label_names or None, - stream_path=endpoint.get("stream_path"), - algorithm=endpoint.get("algorithm"), - monitor_configuration=monitor_configuration or None, - active=endpoint.get("active"), - monitoring_mode=endpoint.get("monitoring_mode"), - ), - status=mlrun.api.schemas.ModelEndpointStatus( - state=endpoint.get("state") or None, - feature_stats=feature_stats or None, - current_stats=current_stats or None, - children=children or None, - first_request=endpoint.get("first_request"), - last_request=endpoint.get("last_request"), - accuracy=endpoint.get("accuracy"), - error_count=endpoint.get("error_count"), - drift_status=endpoint.get("drift_status"), - endpoint_type=endpoint_type or None, - children_uids=children_uids or None, - monitoring_feature_set_uri=endpoint.get("monitoring_feature_set_uri") - or None, - ), - ) - - # If feature analysis was applied, add feature stats and current stats to the model endpoint result - if feature_analysis and feature_names: - endpoint_features = self.get_endpoint_features( - feature_names=feature_names, - feature_stats=feature_stats, - current_stats=current_stats, - ) - if endpoint_features: - endpoint_obj.status.features = endpoint_features - # Add the latest drift measures results (calculated by the model monitoring batch) - drift_measures = self._json_loads_if_not_none( - endpoint.get("drift_measures") - ) - endpoint_obj.status.drift_measures = drift_measures - - # If time metrics were provided, retrieve the results from the time series DB - if metrics: - endpoint_metrics = self.get_endpoint_metrics( - endpoint_id=endpoint_obj.metadata.uid, - start=start, - end=end, - metrics=metrics, - ) - if endpoint_metrics: - endpoint_obj.status.metrics = endpoint_metrics - - return endpoint_obj - - def _get_path_and_container(self): - """Getting path and container based on the model monitoring configurations""" - path = mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default.format( - project=self.project, - 
kind=mlrun.api.schemas.ModelMonitoringStoreKinds.ENDPOINTS, - ) - ( - _, - container, - path, - ) = mlrun.utils.model_monitoring.parse_model_endpoint_store_prefix(path) - return path, container - - def list_model_endpoints( - self, model: str, function: str, labels: typing.List, top_level: bool - ): - """ - Returns a list of endpoint unique ids, supports filtering by model, function, - labels or top level. By default, when no filters are applied, all available endpoint ids for the given project - will be listed. - - :param model: The name of the model to filter by. - :param function: The name of the function to filter by. - :param labels: A list of labels to filter by. Label filters work by either filtering a specific value - of a label (i.e. list("key==value")) or by looking for the existence of a given - key (i.e. "key"). - :param top_level: If True will return only routers and endpoint that are NOT children of any router. - - :return: List of model endpoints unique ids. - """ - - # Retrieve the raw data from the KV table and get the endpoint ids - cursor = self.client.kv.new_cursor( - container=self.container, - table_path=self.path, - filter_expression=self.build_kv_cursor_filter_expression( - self.project, - function, - model, - labels, - top_level, - ), - attribute_names=["endpoint_id"], - raise_for_status=v3io.dataplane.RaiseForStatus.never, - ) - try: - items = cursor.all() - except Exception: - return [] - - # Create a list of model endpoints unique ids - uids = [item["endpoint_id"] for item in items] - - return uids - - def delete_model_endpoints_resources( - self, endpoints: mlrun.api.schemas.model_endpoints.ModelEndpointList - ): - """ - Delete all model endpoints resources in both KV and the time series DB. - - :param endpoints: An object of ModelEndpointList which is literally a list of model endpoints along with some - metadata. To get a standard list of model endpoints use ModelEndpointList.endpoints. 
- """ - - # Delete model endpoint record from KV table - for endpoint in endpoints.endpoints: - self.delete_model_endpoint( - endpoint.metadata.uid, - ) - - # Delete remain records in the KV - all_records = self.client.kv.new_cursor( - container=self.container, - table_path=self.path, - raise_for_status=v3io.dataplane.RaiseForStatus.never, - ).all() - - all_records = [r["__name"] for r in all_records] - - # Cleanup KV - for record in all_records: - self.client.kv.delete( - container=self.container, - table_path=self.path, - key=record, - raise_for_status=v3io.dataplane.RaiseForStatus.never, - ) - - # Cleanup TSDB - frames = mlrun.utils.v3io_clients.get_frames_client( - token=self.access_key, - address=mlrun.mlconf.v3io_framesd, - container=self.container, - ) - - # Generate the required tsdb paths - tsdb_path, filtered_path = self._generate_tsdb_paths() - - # Delete time series DB resources - try: - frames.delete( - backend=model_monitoring_constants.StoreTarget.TSDB, - table=filtered_path, - if_missing=v3io_frames.frames_pb2.IGNORE, - ) - except v3io_frames.errors.CreateError: - # Frames might raise an exception if schema file does not exist. 
- pass - - # Final cleanup of tsdb path - tsdb_path.replace("://u", ":///u") - store, _ = mlrun.store_manager.get_or_create_store(tsdb_path) - store.rm(tsdb_path, recursive=True) - - def _generate_tsdb_paths(self) -> typing.Tuple[str, str]: - """Generate a short path to the TSDB resources and a filtered path for the frames object - - :return: A tuple of: - [0] = Short path to the TSDB resources - [1] = Filtered path to TSDB events without schema and container - """ - # Full path for the time series DB events - full_path = ( - mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default.format( - project=self.project, - kind=mlrun.api.schemas.ModelMonitoringStoreKinds.EVENTS, - ) - ) - - # Generate the main directory with the TSDB resources - tsdb_path = mlrun.utils.model_monitoring.parse_model_endpoint_project_prefix( - full_path, self.project - ) - - # Generate filtered path without schema and container as required by the frames object - ( - _, - _, - filtered_path, - ) = mlrun.utils.model_monitoring.parse_model_endpoint_store_prefix(full_path) - return tsdb_path, filtered_path - - @staticmethod - def build_kv_cursor_filter_expression( - project: str, - function: str = None, - model: str = None, - labels: typing.List[str] = None, - top_level: bool = False, - ) -> str: - """ - Convert the provided filters into a valid filter expression. The expected filter expression includes different - conditions, divided by ' AND '. - - :param project: The name of the project. - :param model: The name of the model to filter by. - :param function: The name of the function to filter by. - :param labels: A list of labels to filter by. Label filters work by either filtering a specific value of - a label (i.e. list("key==value")) or by looking for the existence of a given - key (i.e. "key"). - :param top_level: If True will return only routers and endpoint that are NOT children of any router. - - :return: A valid filter expression as a string. 
- """ - - if not project: - raise mlrun.errors.MLRunInvalidArgumentError("project can't be empty") - - # Add project filter - filter_expression = [f"project=='{project}'"] - - # Add function and model filters - if function: - filter_expression.append(f"function=='{function}'") - if model: - filter_expression.append(f"model=='{model}'") - - # Add labels filters - if labels: - for label in labels: - - if not label.startswith("_"): - label = f"_{label}" - - if "=" in label: - lbl, value = list(map(lambda x: x.strip(), label.split("="))) - filter_expression.append(f"{lbl}=='{value}'") - else: - filter_expression.append(f"exists({label})") - - # Apply top_level filter (remove endpoints that considered a child of a router) - if top_level: - filter_expression.append( - f"(endpoint_type=='{str(mlrun.utils.model_monitoring.EndpointType.NODE_EP.value)}' " - f"OR endpoint_type=='{str(mlrun.utils.model_monitoring.EndpointType.ROUTER.value)}')" - ) - - return " AND ".join(filter_expression) - - @staticmethod - def flatten_model_endpoint_attributes( - endpoint: mlrun.api.schemas.ModelEndpoint, - ) -> typing.Dict: - """ - Retrieving flatten structure of the model endpoint object. - - :param endpoint: ModelEndpoint object that will be used for getting the attributes. - - :return: A flat dictionary of attributes. - """ - - # Prepare the data for the attributes dictionary - labels = endpoint.metadata.labels or {} - searchable_labels = {f"_{k}": v for k, v in labels.items()} - feature_names = endpoint.spec.feature_names or [] - label_names = endpoint.spec.label_names or [] - feature_stats = endpoint.status.feature_stats or {} - current_stats = endpoint.status.current_stats or {} - children = endpoint.status.children or [] - endpoint_type = endpoint.status.endpoint_type or None - children_uids = endpoint.status.children_uids or [] - - # Fill the data. 
Note that because it is a flat dictionary, we use json.dumps() for encoding hierarchies - # such as current_stats or label_names - attributes = { - "endpoint_id": endpoint.metadata.uid, - "project": endpoint.metadata.project, - "function_uri": endpoint.spec.function_uri, - "model": endpoint.spec.model, - "model_class": endpoint.spec.model_class or "", - "labels": json.dumps(labels), - "model_uri": endpoint.spec.model_uri or "", - "stream_path": endpoint.spec.stream_path or "", - "active": endpoint.spec.active or "", - "monitoring_feature_set_uri": endpoint.status.monitoring_feature_set_uri - or "", - "monitoring_mode": endpoint.spec.monitoring_mode or "", - "state": endpoint.status.state or "", - "feature_stats": json.dumps(feature_stats), - "current_stats": json.dumps(current_stats), - "feature_names": json.dumps(feature_names), - "children": json.dumps(children), - "label_names": json.dumps(label_names), - "endpoint_type": json.dumps(endpoint_type), - "children_uids": json.dumps(children_uids), - **searchable_labels, - } - return attributes - - @staticmethod - def _json_loads_if_not_none(field: typing.Any) -> typing.Any: - return json.loads(field) if field is not None else None - - @staticmethod - def get_endpoint_features( - feature_names: typing.List[str], - feature_stats: dict = None, - current_stats: dict = None, - ) -> typing.List[mlrun.api.schemas.Features]: - """ - Getting a new list of features that exist in feature_names along with their expected (feature_stats) and - actual (current_stats) stats. The expected stats were calculated during the creation of the model endpoint, - usually based on the data from the Model Artifact. The actual stats are based on the results from the latest - model monitoring batch job. - - param feature_names: List of feature names. - param feature_stats: Dictionary of feature stats that were stored during the creation of the model endpoint - object. 
- param current_stats: Dictionary of the latest stats that were stored during the last run of the model monitoring - batch job. - - return: List of feature objects. Each feature has a name, weight, expected values, and actual values. More info - can be found under mlrun.api.schemas.Features. - """ - - # Initialize feature and current stats dictionaries - safe_feature_stats = feature_stats or {} - safe_current_stats = current_stats or {} - - # Create feature object and add it to a general features list - features = [] - for name in feature_names: - if feature_stats is not None and name not in feature_stats: - logger.warn("Feature missing from 'feature_stats'", name=name) - if current_stats is not None and name not in current_stats: - logger.warn("Feature missing from 'current_stats'", name=name) - f = mlrun.api.schemas.Features.new( - name, safe_feature_stats.get(name), safe_current_stats.get(name) - ) - features.append(f) - return features - - def get_endpoint_metrics( - self, - endpoint_id: str, - metrics: typing.List[str], - start: str = "now-1h", - end: str = "now", - ) -> typing.Dict[str, mlrun.api.schemas.Metric]: - """ - Getting metrics from the time series DB. There are pre-defined metrics for model endpoints such as - predictions_per_second and latency_avg_5m but also custom metrics defined by the user. - - :param endpoint_id: The unique id of the model endpoint. - :param metrics: A list of metrics to return for the model endpoint. - :param start: The start time of the metrics. Can be represented by a string containing an RFC 3339 - time, a Unix timestamp in milliseconds, a relative time (`'now'` or - `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the - earliest time. - :param end: The end time of the metrics. 
Can be represented by a string containing an RFC 3339 - time, a Unix timestamp in milliseconds, a relative time (`'now'` or - `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the - earliest time. - - :return: A dictionary of metrics in which the key is a metric name and the value is a Metric object that also - includes the relevant timestamp. More details about the Metric object can be found under - mlrun.api.schemas.Metric. - """ - - if not metrics: - raise mlrun.errors.MLRunInvalidArgumentError( - "Metric names must be provided" - ) - - # Initialize metrics mapping dictionary - metrics_mapping = {} - - # Getting the path for the time series DB - events_path = ( - mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default.format( - project=self.project, - kind=mlrun.api.schemas.ModelMonitoringStoreKinds.EVENTS, - ) - ) - ( - _, - _, - events_path, - ) = mlrun.utils.model_monitoring.parse_model_endpoint_store_prefix(events_path) - - # Retrieve the raw data from the time series DB based on the provided metrics and time ranges - frames_client = mlrun.utils.v3io_clients.get_frames_client( - token=self.access_key, - address=mlrun.mlconf.v3io_framesd, - container=self.container, - ) - - try: - data = frames_client.read( - backend=model_monitoring_constants.StoreTarget.TSDB, - table=events_path, - columns=["endpoint_id", *metrics], - filter=f"endpoint_id=='{endpoint_id}'", - start=start, - end=end, - ) - - # Fill the metrics mapping dictionary with the metric name and values - data_dict = data.to_dict() - for metric in metrics: - metric_data = data_dict.get(metric) - if metric_data is None: - continue - - values = [ - (str(timestamp), value) for timestamp, value in metric_data.items() - ] - metrics_mapping[metric] = mlrun.api.schemas.Metric( - name=metric, values=values - ) - except v3io_frames.errors.ReadError: - logger.warn("Failed to read tsdb", endpoint=endpoint_id) - return metrics_mapping - - -class 
_ModelEndpointSQLStore(_ModelEndpointStore): - def write_model_endpoint(self, endpoint, update=True): - raise NotImplementedError - - def update_model_endpoint(self, endpoint_id, attributes): - raise NotImplementedError - - def delete_model_endpoint(self, endpoint_id): - raise NotImplementedError - - def delete_model_endpoints_resources( - self, endpoints: mlrun.api.schemas.model_endpoints.ModelEndpointList - ): - raise NotImplementedError - - def get_model_endpoint( - self, - metrics: typing.List[str] = None, - start: str = "now-1h", - end: str = "now", - feature_analysis: bool = False, - endpoint_id: str = None, - ): - raise NotImplementedError - - def list_model_endpoints( - self, model: str, function: str, labels: typing.List, top_level: bool - ): - raise NotImplementedError - - -class ModelEndpointStoreType(enum.Enum): - """Enum class to handle the different store type values for saving a model endpoint record.""" - - kv = "kv" - sql = "sql" - - def to_endpoint_target( - self, project: str, access_key: str = None - ) -> _ModelEndpointStore: - """ - Return a ModelEndpointStore object based on the provided enum value. - - :param project: The name of the project. - :param access_key: Access key with permission to the DB table. Note that if access key is None and the - endpoint target is from type KV then the access key will be retrieved from the environment - variable. - - :return: ModelEndpointStore object. - - """ - - if self.value == ModelEndpointStoreType.kv.value: - - # Get V3IO access key from env - access_key = ( - mlrun.mlconf.get_v3io_access_key() if access_key is None else access_key - ) - - return _ModelEndpointKVStore(project=project, access_key=access_key) - - # Assuming SQL store target if store type is not KV. - # Update these lines once there are more than two store target types. - return _ModelEndpointSQLStore(project=project) - - @classmethod - def _missing_(cls, value: typing.Any): - """A lookup function to handle an invalid value. 
- :param value: Provided enum (invalid) value. - """ - valid_values = list(cls.__members__.keys()) - raise mlrun.errors.MLRunInvalidArgumentError( - "%r is not a valid %s, please choose a valid value: %s." - % (value, cls.__name__, valid_values) - ) - - -def get_model_endpoint_target( - project: str, access_key: str = None -) -> _ModelEndpointStore: - """ - Getting the DB target type based on mlrun.config.model_endpoint_monitoring.store_type. - - :param project: The name of the project. - :param access_key: Access key with permission to the DB table. - - :return: ModelEndpointStore object. Using this object, the user can apply different operations on the - model endpoint record such as write, update, get and delete. - """ - - # Get store type value from ModelEndpointStoreType enum class - model_endpoint_store_type = ModelEndpointStoreType( - mlrun.mlconf.model_endpoint_monitoring.store_type - ) - - # Convert into model endpoint store target object - return model_endpoint_store_type.to_endpoint_target(project, access_key) diff --git a/mlrun/api/crud/model_monitoring/model_endpoints.py b/mlrun/api/crud/model_monitoring/model_endpoints.py index 4b3b5365aab3..086e3253ec12 100644 --- a/mlrun/api/crud/model_monitoring/model_endpoints.py +++ b/mlrun/api/crud/model_monitoring/model_endpoints.py @@ -12,33 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - - +import json import os import typing +import warnings import sqlalchemy.orm import mlrun.api.api.endpoints.functions import mlrun.api.api.utils -import mlrun.api.schemas -import mlrun.api.schemas.model_endpoints +import mlrun.api.crud.runtimes.nuclio.function import mlrun.api.utils.singletons.k8s import mlrun.artifacts +import mlrun.common.model_monitoring as model_monitoring_constants +import mlrun.common.schemas +import mlrun.common.schemas.model_endpoints import mlrun.config import mlrun.datastore.store_resources import mlrun.errors import mlrun.feature_store -import mlrun.model_monitoring.constants as model_monitoring_constants import mlrun.model_monitoring.helpers -import mlrun.runtimes.function import mlrun.utils.helpers import mlrun.utils.model_monitoring import mlrun.utils.v3io_clients +from mlrun.model_monitoring.stores import get_model_endpoint_store from mlrun.utils import logger -from .model_endpoint_store import get_model_endpoint_target - class ModelEndpoints: """Provide different methods for handling model endpoints such as listing, writing and deleting""" @@ -47,12 +46,17 @@ def create_or_patch( self, db_session: sqlalchemy.orm.Session, access_key: str, - model_endpoint: mlrun.api.schemas.ModelEndpoint, - auth_info: mlrun.api.schemas.AuthInfo = mlrun.api.schemas.AuthInfo(), - ) -> mlrun.api.schemas.ModelEndpoint: - # TODO: deprecated, remove in 1.5.0. + model_endpoint: mlrun.common.schemas.ModelEndpoint, + auth_info: mlrun.common.schemas.AuthInfo = mlrun.common.schemas.AuthInfo(), + ) -> mlrun.common.schemas.ModelEndpoint: + # TODO: deprecated in 1.3.0, remove in 1.5.0. + warnings.warn( + "This is deprecated in 1.3.0, and will be removed in 1.5.0." + "Please use create_model_endpoint() for create or patch_model_endpoint() for update", + FutureWarning, + ) """ - Either create or updates the record of a given ModelEndpoint object. + Either create or updates the record of a given `ModelEndpoint` object. 
Leaving here for backwards compatibility, remove in 1.5.0. :param db_session: A session that manages the current dialog with the database @@ -60,7 +64,7 @@ def create_or_patch( :param model_endpoint: Model endpoint object to update :param auth_info: The auth info of the request - :return: Model endpoint object. + :return: `ModelEndpoint` object. """ return self.create_model_endpoint( @@ -70,16 +74,16 @@ def create_or_patch( def create_model_endpoint( self, db_session: sqlalchemy.orm.Session, - model_endpoint: mlrun.api.schemas.ModelEndpoint, - ) -> mlrun.api.schemas.ModelEndpoint: + model_endpoint: mlrun.common.schemas.ModelEndpoint, + ) -> mlrun.common.schemas.ModelEndpoint: """ Creates model endpoint record in DB. The DB target type is defined under - mlrun.config.model_endpoint_monitoring.store_type (KV by default). + `mlrun.config.model_endpoint_monitoring.store_type` (V3IO-NOSQL by default). :param db_session: A session that manages the current dialog with the database. :param model_endpoint: Model endpoint object to update. - :return: Model endpoint object. + :return: `ModelEndpoint` object. 
""" if model_endpoint.spec.model_uri or model_endpoint.status.feature_stats: @@ -107,23 +111,22 @@ def create_model_endpoint( if not model_endpoint.status.feature_stats and hasattr( model_obj, "feature_stats" ): - model_endpoint.status.feature_stats = model_obj.feature_stats - + model_endpoint.status.feature_stats = model_obj.spec.feature_stats # Get labels from model object if not found in model endpoint object - if not model_endpoint.spec.label_names and hasattr(model_obj, "outputs"): + if not model_endpoint.spec.label_names and model_obj.spec.outputs: model_label_names = [ - self._clean_feature_name(f.name) for f in model_obj.outputs + self._clean_feature_name(f.name) for f in model_obj.spec.outputs ] model_endpoint.spec.label_names = model_label_names # Get algorithm from model object if not found in model endpoint object - if not model_endpoint.spec.algorithm and hasattr(model_obj, "algorithm"): - model_endpoint.spec.algorithm = model_obj.algorithm + if not model_endpoint.spec.algorithm and model_obj.spec.algorithm: + model_endpoint.spec.algorithm = model_obj.spec.algorithm # Create monitoring feature set if monitoring found in model endpoint object if ( model_endpoint.spec.monitoring_mode - == mlrun.api.schemas.ModelMonitoringMode.enabled.value + == mlrun.common.model_monitoring.ModelMonitoringMode.enabled.value ): monitoring_feature_set = self.create_monitoring_feature_set( model_endpoint, model_obj, db_session, run_db @@ -158,18 +161,18 @@ def create_model_endpoint( logger.info("Creating model endpoint", endpoint_id=model_endpoint.metadata.uid) # Write the new model endpoint - model_endpoint_target = get_model_endpoint_target( + model_endpoint_store = get_model_endpoint_store( project=model_endpoint.metadata.project, ) - model_endpoint_target.write_model_endpoint(endpoint=model_endpoint) + model_endpoint_store.write_model_endpoint(endpoint=model_endpoint.flat_dict()) logger.info("Model endpoint created", endpoint_id=model_endpoint.metadata.uid) return 
model_endpoint - @staticmethod def create_monitoring_feature_set( - model_endpoint: mlrun.api.schemas.ModelEndpoint, + self, + model_endpoint: mlrun.common.schemas.ModelEndpoint, model_obj: mlrun.artifacts.ModelArtifact, db_session: sqlalchemy.orm.Session, run_db: mlrun.db.sqldb.SQLDB, @@ -195,29 +198,29 @@ def create_monitoring_feature_set( feature_set = mlrun.feature_store.FeatureSet( f"monitoring-{serving_function_name}-{model_name}", - entities=["endpoint_id"], - timestamp_key="timestamp", + entities=[model_monitoring_constants.EventFieldType.ENDPOINT_ID], + timestamp_key=model_monitoring_constants.EventFieldType.TIMESTAMP, description=f"Monitoring feature set for endpoint: {model_endpoint.spec.model}", ) feature_set.metadata.project = model_endpoint.metadata.project feature_set.metadata.labels = { - "endpoint_id": model_endpoint.metadata.uid, - "model_class": model_endpoint.spec.model_class, + model_monitoring_constants.EventFieldType.ENDPOINT_ID: model_endpoint.metadata.uid, + model_monitoring_constants.EventFieldType.MODEL_CLASS: model_endpoint.spec.model_class, } # Add features to the feature set according to the model object - if model_obj.inputs.values(): - for feature in model_obj.inputs.values(): + if model_obj.spec.inputs: + for feature in model_obj.spec.inputs: feature_set.add_feature( mlrun.feature_store.Feature( name=feature.name, value_type=feature.value_type ) ) # Check if features can be found within the feature vector - elif model_obj.feature_vector: + elif model_obj.spec.feature_vector: _, name, _, tag, _ = mlrun.utils.helpers.parse_artifact_uri( - model_obj.feature_vector + model_obj.spec.feature_vector ) fv = run_db.get_feature_vector( name=name, project=model_endpoint.metadata.project, tag=tag @@ -236,16 +239,22 @@ def create_monitoring_feature_set( # Define parquet target for this feature set parquet_path = ( - f"v3io:///projects/{model_endpoint.metadata.project}" - f"/model-endpoints/parquet/key={model_endpoint.metadata.uid}" + 
self._get_monitoring_parquet_path( + db_session=db_session, project=model_endpoint.metadata.project + ) + + f"/key={model_endpoint.metadata.uid}" + ) + + parquet_target = mlrun.datastore.targets.ParquetTarget( + model_monitoring_constants.FileTargetKind.PARQUET, parquet_path ) - parquet_target = mlrun.datastore.targets.ParquetTarget("parquet", parquet_path) driver = mlrun.datastore.targets.get_target_driver(parquet_target, feature_set) - driver.update_resource_status("created") + feature_set.set_targets( [mlrun.datastore.targets.ParquetTarget(path=parquet_path)], with_defaults=False, ) + driver.update_resource_status("created") # Save the new feature set feature_set._override_run_db(db_session) @@ -258,10 +267,39 @@ def create_monitoring_feature_set( return feature_set + @staticmethod + def _get_monitoring_parquet_path( + db_session: sqlalchemy.orm.Session, project: str + ) -> str: + """Getting model monitoring parquet target for the current project. The parquet target path is based on the + project artifact path. If project artifact path is not defined, the parquet target path will be based on MLRun + artifact path. + + :param db_session: A session that manages the current dialog with the database. Will be used in this function + to get the project record from DB. + :param project: Project name. + + :return: Monitoring parquet target path. 
+ """ + + # Get the artifact path from the project record that was stored in the DB + project_obj = mlrun.api.crud.projects.Projects().get_project( + session=db_session, name=project + ) + artifact_path = project_obj.spec.artifact_path + # Generate monitoring parquet path value + parquet_path = mlrun.mlconf.get_model_monitoring_file_target_path( + project=project, + kind=model_monitoring_constants.FileTargetKind.PARQUET, + target="offline", + artifact_path=artifact_path, + ) + return parquet_path + @staticmethod def _validate_length_features_and_labels(model_endpoint): """ - Validate that the length of feature_stats is equal to the length of feature_names and label_names + Validate that the length of feature_stats is equal to the length of `feature_names` and `label_names` :param model_endpoint: An object representing the model endpoint. """ @@ -288,8 +326,8 @@ def _adjust_feature_names_and_stats( self, model_endpoint ) -> typing.Tuple[typing.Dict, typing.List]: """ - Create a clean matching version of feature names for both feature_stats and feature_names. Please note that - label names exist only in feature_stats and label_names. + Create a clean matching version of feature names for both `feature_stats` and `feature_names`. Please note that + label names exist only in `feature_stats` and `label_names`. :param model_endpoint: An object representing the model endpoint. :return: A tuple of: @@ -312,36 +350,42 @@ def _adjust_feature_names_and_stats( clean_feature_names.append(clean_name) return clean_feature_stats, clean_feature_names - @staticmethod def patch_model_endpoint( + self, project: str, endpoint_id: str, attributes: dict, - ) -> mlrun.api.schemas.ModelEndpoint: + ) -> mlrun.common.schemas.ModelEndpoint: """ Update a model endpoint record with a given attributes. :param project: The name of the project. :param endpoint_id: The unique id of the model endpoint. :param attributes: Dictionary of attributes that will be used for update the model endpoint. 
Note that the keys - of the attributes dictionary should exist in the KV table. More details about the model + of the attributes dictionary should exist in the DB table. More details about the model endpoint available attributes can be found under - :py:class:`~mlrun.api.schemas.ModelEndpoint`. + :py:class:`~mlrun.common.schemas.ModelEndpoint`. - :return: A patched ModelEndpoint object. + :return: A patched `ModelEndpoint` object. """ - model_endpoint_target = get_model_endpoint_target( + # Generate a model endpoint store object and apply the update process + model_endpoint_store = get_model_endpoint_store( project=project, ) - model_endpoint_target.update_model_endpoint( + model_endpoint_store.update_model_endpoint( endpoint_id=endpoint_id, attributes=attributes ) - return model_endpoint_target.get_model_endpoint( - endpoint_id=endpoint_id, start="now-1h", end="now" + logger.info("Model endpoint table updated", endpoint_id=endpoint_id) + + # Get the patched model endpoint record + model_endpoint_record = model_endpoint_store.get_model_endpoint( + endpoint_id=endpoint_id, ) + return self._convert_into_model_endpoint_object(endpoint=model_endpoint_record) + @staticmethod def delete_model_endpoint( project: str, @@ -353,59 +397,83 @@ def delete_model_endpoint( :param project: The name of the project. :param endpoint_id: The id of the endpoint. 
""" - model_endpoint_target = get_model_endpoint_target( + model_endpoint_store = get_model_endpoint_store( project=project, ) - model_endpoint_target.delete_model_endpoint(endpoint_id=endpoint_id) - @staticmethod + model_endpoint_store.delete_model_endpoint(endpoint_id=endpoint_id) + + logger.info("Model endpoint table cleared", endpoint_id=endpoint_id) + def get_model_endpoint( - auth_info: mlrun.api.schemas.AuthInfo, + self, + auth_info: mlrun.common.schemas.AuthInfo, project: str, endpoint_id: str, metrics: typing.List[str] = None, start: str = "now-1h", end: str = "now", feature_analysis: bool = False, - ) -> mlrun.api.schemas.ModelEndpoint: + ) -> mlrun.common.schemas.ModelEndpoint: """Get a single model endpoint object. You can apply different time series metrics that will be added to the result. - :param auth_info: The auth info of the request - :param project: The name of the project - :param endpoint_id: The unique id of the model endpoint. - :param metrics: A list of metrics to return for the model endpoint. There are pre-defined metrics for - model endpoints such as predictions_per_second and latency_avg_5m but also custom - metrics defined by the user. Please note that these metrics are stored in the time - series DB and the results will be appeared under model_endpoint.spec.metrics. - :param start: The start time of the metrics. Can be represented by a string containing an RFC 3339 - time, a Unix timestamp in milliseconds, a relative time (`'now'` or - `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = - days), or 0 for the earliest time. - :param end: The end time of the metrics. Can be represented by a string containing an RFC 3339 - time, a Unix timestamp in milliseconds, a relative time (`'now'` or - `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = - days), or 0 for the earliest time. 
- :param feature_analysis: When True, the base feature statistics and current feature statistics will be added to - the output of the resulting object. - - :return: A ModelEndpoint object. + :param auth_info: The auth info of the request + :param project: The name of the project + :param endpoint_id: The unique id of the model endpoint. + :param metrics: A list of metrics to return for the model endpoint. There are pre-defined + metrics for model endpoints such as predictions_per_second and + latency_avg_5m but also custom metrics defined by the user. Please note that + these metrics are stored in the time series DB and the results will be + appeared under `model_endpoint.spec.metrics`. + :param start: The start time of the metrics. Can be represented by a string containing an + RFC 3339 time, a Unix timestamp in milliseconds, a relative time (`'now'` or + `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or + 0 for the earliest time. + :param end: The end time of the metrics. Can be represented by a string containing an + RFC 3339 time, a Unix timestamp in milliseconds, a relative time (`'now'` or + `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or + 0 for the earliest time. + :param feature_analysis: When True, the base feature statistics and current feature statistics will + be added to the output of the resulting object. + + :return: A `ModelEndpoint` object. 
""" - model_endpoint_target = get_model_endpoint_target( + logger.info( + "Getting model endpoint record from DB", + endpoint_id=endpoint_id, + ) + + # Generate a model endpoint store object and get the model endpoint record as a dictionary + model_endpoint_store = get_model_endpoint_store( project=project, access_key=auth_info.data_session ) - return model_endpoint_target.get_model_endpoint( + + model_endpoint_record = model_endpoint_store.get_model_endpoint( endpoint_id=endpoint_id, - metrics=metrics, - start=start, - end=end, - feature_analysis=feature_analysis, ) - @staticmethod + # Convert to `ModelEndpoint` object + model_endpoint_object = self._convert_into_model_endpoint_object( + endpoint=model_endpoint_record, feature_analysis=feature_analysis + ) + + # If time metrics were provided, retrieve the results from the time series DB + if metrics: + self._add_real_time_metrics( + model_endpoint_store=model_endpoint_store, + model_endpoint_object=model_endpoint_object, + metrics=metrics, + start=start, + end=end, + ) + + return model_endpoint_object + def list_model_endpoints( - auth_info: mlrun.api.schemas.AuthInfo, + self, + auth_info: mlrun.common.schemas.AuthInfo, project: str, model: str = None, function: str = None, @@ -415,10 +483,11 @@ def list_model_endpoints( end: str = "now", top_level: bool = False, uids: typing.List[str] = None, - ) -> mlrun.api.schemas.model_endpoints.ModelEndpointList: + ) -> mlrun.common.schemas.ModelEndpointList: """ - Returns a list of ModelEndpointState objects. Each object represents the current state of a model endpoint. - This functions supports filtering by the following parameters: + Returns a list of `ModelEndpoint` objects, wrapped in `ModelEndpointList` object. Each `ModelEndpoint` + object represents the current state of a model endpoint. 
This function supports filtering by the following + parameters: 1) model 2) function 3) labels @@ -435,22 +504,22 @@ def list_model_endpoints( :param model: The name of the model to filter by. :param function: The name of the function to filter by. :param labels: A list of labels to filter by. Label filters work by either filtering a specific value of a - label (i.e. list("key==value")) or by looking for the existence of a given key (i.e. "key"). + label (i.e. list("key=value")) or by looking for the existence of a given key (i.e. "key"). :param metrics: A list of metrics to return for each endpoint. There are pre-defined metrics for model - endpoints such as predictions_per_second and latency_avg_5m but also custom metrics defined - by the user. Please note that these metrics are stored in the time series DB and the results - will be appeared under model_endpoint.spec.metrics of each endpoint. + endpoints such as `predictions_per_second` and `latency_avg_5m` but also custom metrics + defined by the user. Please note that these metrics are stored in the time series DB and the + results will appear under model_endpoint.spec.metrics of each endpoint. :param start: The start time of the metrics. Can be represented by a string containing an RFC 3339 time, a Unix timestamp in milliseconds, a relative time (`'now'` or `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the earliest time. :param end: The end time of the metrics. Can be represented by a string containing an RFC 3339 time, a Unix timestamp in milliseconds, a relative time (`'now'` or `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the earliest time. - :param top_level: If True will return only routers and endpoint that are NOT children of any router. - :param uids: Will return ModelEndpointList of endpoints with uid in uids. + :param top_level: If True, return only routers and endpoints that are NOT children of any router. 
+ :param uids: List of model endpoint unique ids to include in the result. - :return: An object of ModelEndpointList which is literally a list of model endpoints along with some metadata. - To get a standard list of model endpoints use ModelEndpointList.endpoints. + :return: An object of `ModelEndpointList` which is literally a list of model endpoints along with some metadata. + To get a standard list of model endpoints use `ModelEndpointList.endpoints`. """ logger.info( @@ -466,39 +535,181 @@ def list_model_endpoints( uids=uids, ) - endpoint_target = get_model_endpoint_target( + # Initialize an empty model endpoints list + endpoint_list = mlrun.common.schemas.model_endpoints.ModelEndpointList( + endpoints=[] + ) + + # Generate a model endpoint store object and get a list of model endpoint dictionaries + endpoint_store = get_model_endpoint_store( access_key=auth_info.data_session, project=project ) - # Initialize an empty model endpoints list - endpoint_list = mlrun.api.schemas.model_endpoints.ModelEndpointList( - endpoints=[] + endpoint_dictionary_list = endpoint_store.list_model_endpoints( + function=function, + model=model, + labels=labels, + top_level=top_level, + uids=uids, ) - # If list of model endpoint ids was not provided, retrieve it from the DB - if uids is None: - uids = endpoint_target.list_model_endpoints( - function=function, model=model, labels=labels, top_level=top_level - ) + for endpoint_dict in endpoint_dictionary_list: - # Add each relevant model endpoint to the model endpoints list - for endpoint_id in uids: - endpoint = endpoint_target.get_model_endpoint( - metrics=metrics, - endpoint_id=endpoint_id, - start=start, - end=end, + # Convert to `ModelEndpoint` object + endpoint_obj = self._convert_into_model_endpoint_object( + endpoint=endpoint_dict ) - endpoint_list.endpoints.append(endpoint) + + # If time metrics were provided, retrieve the results from the time series DB + if metrics: + self._add_real_time_metrics( + 
model_endpoint_store=endpoint_store, + model_endpoint_object=endpoint_obj, + metrics=metrics, + start=start, + end=end, + ) + + # Add the `ModelEndpoint` object into the model endpoints list + endpoint_list.endpoints.append(endpoint_obj) return endpoint_list + @staticmethod + def _add_real_time_metrics( + model_endpoint_store: mlrun.model_monitoring.stores.ModelEndpointStore, + model_endpoint_object: mlrun.common.schemas.ModelEndpoint, + metrics: typing.List[str] = None, + start: str = "now-1h", + end: str = "now", + ) -> mlrun.common.schemas.ModelEndpoint: + """Add real time metrics from the time series DB to a provided `ModelEndpoint` object. The real time metrics + will be stored under `ModelEndpoint.status.metrics.real_time` + + :param model_endpoint_store: `ModelEndpointStore` object that will be used for communicating with the database + and querying the required metrics. + :param model_endpoint_object: `ModelEndpoint` object that will be filled with the relevant + real time metrics. + :param metrics: A list of metrics to return for each endpoint. There are pre-defined metrics for + model endpoints such as `predictions_per_second` and `latency_avg_5m` but also + custom metrics defined by the user. Please note that these metrics are stored in + the time series DB and the results will be appeared under + model_endpoint.spec.metrics of each endpoint. + :param start: The start time of the metrics. Can be represented by a string containing an RFC + 3339 time, a Unix timestamp in milliseconds, a relative time (`'now'` or + `'now-[0-9]+[mhd]'`, where `m`= minutes, `h` = hours, and `'d'` = days), or 0 + for the earliest time. + :param end: The end time of the metrics. Can be represented by a string containing an RFC + 3339 time, a Unix timestamp in milliseconds, a relative time (`'now'` or + `'now-[0-9]+[mhd]'`, where `m`= minutes, `h` = hours, and `'d'` = days), or 0 + for the earliest time. 
+ + """ + if model_endpoint_object.status.metrics is None: + model_endpoint_object.status.metrics = {} + + endpoint_metrics = model_endpoint_store.get_endpoint_real_time_metrics( + endpoint_id=model_endpoint_object.metadata.uid, + start=start, + end=end, + metrics=metrics, + ) + if endpoint_metrics: + model_endpoint_object.status.metrics[ + model_monitoring_constants.EventKeyMetrics.REAL_TIME + ] = endpoint_metrics + return model_endpoint_object + + def _convert_into_model_endpoint_object( + self, endpoint: typing.Dict[str, typing.Any], feature_analysis: bool = False + ) -> mlrun.common.schemas.ModelEndpoint: + """ + Create a `ModelEndpoint` object according to a provided model endpoint dictionary. + + :param endpoint: Dictionary that represents a DB record of a model endpoint which needs to be converted + into a valid `ModelEndpoint` object. + :param feature_analysis: When True, the base feature statistics and current feature statistics will be added to + the output of the resulting object. + + :return: A `ModelEndpoint` object. 
+ """ + + # Convert into `ModelEndpoint` object + endpoint_obj = mlrun.common.schemas.ModelEndpoint().from_flat_dict(endpoint) + + # If feature analysis was applied, add feature stats and current stats to the model endpoint result + if feature_analysis and endpoint_obj.spec.feature_names: + + endpoint_features = self.get_endpoint_features( + feature_names=endpoint_obj.spec.feature_names, + feature_stats=endpoint_obj.status.feature_stats, + current_stats=endpoint_obj.status.current_stats, + ) + if endpoint_features: + endpoint_obj.status.features = endpoint_features + # Add the latest drift measures results (calculated by the model monitoring batch) + drift_measures = self._json_loads_if_not_none( + endpoint.get( + model_monitoring_constants.EventFieldType.DRIFT_MEASURES + ) + ) + endpoint_obj.status.drift_measures = drift_measures + + return endpoint_obj + + @staticmethod + def get_endpoint_features( + feature_names: typing.List[str], + feature_stats: dict = None, + current_stats: dict = None, + ) -> typing.List[mlrun.common.schemas.Features]: + """ + Getting a new list of features that exist in feature_names along with their expected (feature_stats) and + actual (current_stats) stats. The expected stats were calculated during the creation of the model endpoint, + usually based on the data from the Model Artifact. The actual stats are based on the results from the latest + model monitoring batch job. + + :param feature_names: List of feature names. + :param feature_stats: Dictionary of feature stats that were stored during the creation of the model endpoint + object. + :param current_stats: Dictionary of the latest stats that were stored during the last run of the model monitoring + batch job. + + :return: List of feature objects. Each feature has a name, weight, expected values, and actual values. More info + can be found under `mlrun.common.schemas.Features`. 
+ """ + + # Initialize feature and current stats dictionaries + safe_feature_stats = feature_stats or {} + safe_current_stats = current_stats or {} + + # Create feature object and add it to a general features list + features = [] + for name in feature_names: + if feature_stats is not None and name not in feature_stats: + logger.warn("Feature missing from 'feature_stats'", name=name) + if current_stats is not None and name not in current_stats: + logger.warn("Feature missing from 'current_stats'", name=name) + f = mlrun.common.schemas.Features.new( + name, safe_feature_stats.get(name), safe_current_stats.get(name) + ) + features.append(f) + return features + + @staticmethod + def _json_loads_if_not_none(field: typing.Any) -> typing.Any: + return ( + json.loads(field) + if field and field != "null" and field is not None + else None + ) + def deploy_monitoring_functions( self, project: str, model_monitoring_access_key: str, db_session: sqlalchemy.orm.Session, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, tracking_policy: mlrun.utils.model_monitoring.TrackingPolicy, ): """ @@ -526,7 +737,7 @@ def deploy_monitoring_functions( ) def verify_project_has_no_model_endpoints(self, project_name: str): - auth_info = mlrun.api.schemas.AuthInfo( + auth_info = mlrun.common.schemas.AuthInfo( data_session=os.getenv("V3IO_ACCESS_KEY") ) @@ -539,13 +750,14 @@ def verify_project_has_no_model_endpoints(self, project_name: str): f"Project {project_name} can not be deleted since related resources found: model endpoints" ) - def delete_model_endpoints_resources(self, project_name: str): + @staticmethod + def delete_model_endpoints_resources(project_name: str): """ Delete all model endpoints resources. :param project_name: The name of the project. 
""" - auth_info = mlrun.api.schemas.AuthInfo( + auth_info = mlrun.common.schemas.AuthInfo( data_session=os.getenv("V3IO_ACCESS_KEY") ) @@ -554,19 +766,21 @@ def delete_model_endpoints_resources(self, project_name: str): if not mlrun.mlconf.igz_version or not mlrun.mlconf.v3io_api: return - endpoints = self.list_model_endpoints(auth_info, project_name) - - endpoint_target = get_model_endpoint_target( + # Generate a model endpoint store object and get a list of model endpoint dictionaries + endpoint_store = get_model_endpoint_store( access_key=auth_info.data_session, project=project_name ) - endpoint_target.delete_model_endpoints_resources(endpoints) + endpoints = endpoint_store.list_model_endpoints() + + # Delete model endpoints resources from databases using the model endpoint store object + endpoint_store.delete_model_endpoints_resources(endpoints) - @staticmethod def deploy_model_monitoring_stream_processing( + self, project: str, model_monitoring_access_key: str, db_session: sqlalchemy.orm.Session, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, tracking_policy: mlrun.utils.model_monitoring.TrackingPolicy, ): """ @@ -587,7 +801,7 @@ def deploy_model_monitoring_stream_processing( ) try: # validate that the model monitoring stream has not yet been deployed - mlrun.runtimes.function.get_nuclio_deploy_status( + mlrun.api.crud.runtimes.nuclio.function.get_nuclio_deploy_status( name="model-monitoring-stream", project=project, tag="", @@ -603,8 +817,17 @@ def deploy_model_monitoring_stream_processing( "Deploying model monitoring stream processing function", project=project ) + # Get parquet target value for model monitoring stream function + parquet_target = self._get_monitoring_parquet_path( + db_session=db_session, project=project + ) + fn = mlrun.model_monitoring.helpers.initial_model_monitoring_stream_processing_function( - project, model_monitoring_access_key, db_session, tracking_policy + project=project, + 
model_monitoring_access_key=model_monitoring_access_key, + tracking_policy=tracking_policy, + auth_info=auth_info, + parquet_target=parquet_target, ) mlrun.api.api.endpoints.functions._build_function( @@ -616,7 +839,7 @@ def deploy_model_monitoring_batch_processing( project: str, model_monitoring_access_key: str, db_session: sqlalchemy.orm.Session, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, tracking_policy: mlrun.utils.model_monitoring.TrackingPolicy, ): """ @@ -700,7 +923,7 @@ def _clean_feature_name(feature_name): return feature_name.replace(" ", "_").replace("(", "").replace(")", "") @staticmethod - def get_access_key(auth_info: mlrun.api.schemas.AuthInfo): + def get_access_key(auth_info: mlrun.common.schemas.AuthInfo): """ Getting access key from the current data session. This method is usually used to verify that the session is valid and contains an access key. @@ -739,7 +962,7 @@ def _get_batching_interval_param(intervals_list: typing.List): @staticmethod def _convert_to_cron_string( - cron_trigger: mlrun.api.schemas.schedule.ScheduleCronTrigger, + cron_trigger: mlrun.common.schemas.schedule.ScheduleCronTrigger, ): """Converting the batch interval `ScheduleCronTrigger` into a cron trigger expression""" return "{} {} {} * *".format( diff --git a/mlrun/api/crud/notifications.py b/mlrun/api/crud/notifications.py index 5ff367986bb2..e3658244fd30 100644 --- a/mlrun/api/crud/notifications.py +++ b/mlrun/api/crud/notifications.py @@ -17,7 +17,11 @@ import sqlalchemy.orm import mlrun.api.api.utils +import mlrun.api.db.sqldb.db +import mlrun.api.utils.scheduler import mlrun.api.utils.singletons.db +import mlrun.api.utils.singletons.scheduler +import mlrun.common.schemas import mlrun.utils.singleton @@ -32,13 +36,11 @@ def store_run_notifications( project: str = None, ): project = project or mlrun.mlconf.default_project - notification_objects_to_store = [] - for notification_object in notification_objects: - 
notification_objects_to_store.append( - mlrun.api.api.utils.mask_notification_params_with_secret( - project, run_uid, notification_object - ) + notification_objects_to_store = ( + mlrun.api.api.utils.validate_and_mask_notification_list( + notification_objects, run_uid, project ) + ) mlrun.api.utils.singletons.db.get_db().store_run_notifications( session, notification_objects_to_store, run_uid, project @@ -78,3 +80,70 @@ def delete_run_notifications( mlrun.api.utils.singletons.db.get_db().delete_run_notifications( session, name, run_uid, project ) + + @staticmethod + def set_object_notifications( + db_session: sqlalchemy.orm.Session, + auth_info: mlrun.common.schemas.AuthInfo, + project: str, + notifications: typing.List[mlrun.common.schemas.Notification], + notification_parent: typing.Union[ + mlrun.common.schemas.RunIdentifier, mlrun.common.schemas.ScheduleIdentifier + ], + ): + """ + Sets notifications on given object (run or schedule, might be extended in the future). + This will replace any existing notifications. 
+ :param db_session: DB session + :param auth_info: Authorization info + :param project: Project name + :param notifications: List of notifications to set + :param notification_parent: Identifier of the object on which to set the notifications + """ + set_notification_methods = { + "run": { + "factory": mlrun.api.utils.singletons.db.get_db, + "method_name": mlrun.api.db.sqldb.db.SQLDB.set_run_notifications.__name__, + "identifier_key": "uid", + }, + "schedule": { + "factory": mlrun.api.utils.singletons.scheduler.get_scheduler, + "method_name": mlrun.api.utils.scheduler.Scheduler.set_schedule_notifications.__name__, + "identifier_key": "name", + }, + } + + set_notification_method = set_notification_methods.get( + notification_parent.kind, {} + ) + factory = set_notification_method.get("factory") + if not factory: + raise mlrun.errors.MLRunNotFoundError( + f"couldn't find factory for object kind: {notification_parent.kind}" + ) + set_func = set_notification_method.get("method_name") + if not set_func: + raise mlrun.errors.MLRunNotFoundError( + f"couldn't find set notification function for object kind: {notification_parent.kind}" + ) + identifier_key = set_notification_method.get("identifier_key") + if not identifier_key: + raise mlrun.errors.MLRunNotFoundError( + f"couldn't find identifier key for object kind: {notification_parent.kind}" + ) + + notification_objects_to_set = ( + mlrun.api.api.utils.validate_and_mask_notification_list( + notifications, + getattr(notification_parent, identifier_key), + project, + ) + ) + + getattr(factory(), set_func)( + session=db_session, + project=project, + notifications=notification_objects_to_set, + identifier=notification_parent, + auth_info=auth_info, + ) diff --git a/mlrun/api/crud/pipelines.py b/mlrun/api/crud/pipelines.py index f0b1c74e4ed1..53c7eccbfb4d 100644 --- a/mlrun/api/crud/pipelines.py +++ b/mlrun/api/crud/pipelines.py @@ -20,11 +20,12 @@ import typing import kfp +import kfp_server_api import sqlalchemy.orm import 
mlrun import mlrun.api.api.utils -import mlrun.api.schemas +import mlrun.common.schemas import mlrun.errors import mlrun.kfpops import mlrun.utils.helpers @@ -44,14 +45,14 @@ def list_pipelines( sort_by: str = "", page_token: str = "", filter_: str = "", - format_: mlrun.api.schemas.PipelinesFormat = mlrun.api.schemas.PipelinesFormat.metadata_only, + format_: mlrun.common.schemas.PipelinesFormat = mlrun.common.schemas.PipelinesFormat.metadata_only, page_size: typing.Optional[int] = None, ) -> typing.Tuple[int, typing.Optional[int], typing.List[dict]]: if project != "*" and (page_token or page_size): raise mlrun.errors.MLRunInvalidArgumentError( "Filtering by project can not be used together with pagination" ) - if format_ == mlrun.api.schemas.PipelinesFormat.summary: + if format_ == mlrun.common.schemas.PipelinesFormat.summary: # we don't support summary format in list pipelines since the returned runs doesn't include the workflow # manifest status that includes the nodes section we use to generate the DAG. 
# (There is a workflow manifest under the run's pipeline_spec field, but it doesn't include the status) @@ -72,7 +73,7 @@ def list_pipelines( # the filter that was used to create it) response = kfp_client._run_api.list_runs( page_token=page_token, - page_size=mlrun.api.schemas.PipelinesPagination.max_page_size, + page_size=mlrun.common.schemas.PipelinesPagination.max_page_size, sort_by=sort_by, filter=filter_ if page_token == "" else "", ) @@ -90,7 +91,7 @@ def list_pipelines( response = kfp_client._run_api.list_runs( page_token=page_token, page_size=page_size - or mlrun.api.schemas.PipelinesPagination.default_page_size, + or mlrun.common.schemas.PipelinesPagination.default_page_size, sort_by=sort_by, filter=filter_, ) @@ -107,7 +108,7 @@ def get_pipeline( run_id: str, project: typing.Optional[str] = None, namespace: typing.Optional[str] = None, - format_: mlrun.api.schemas.PipelinesFormat = mlrun.api.schemas.PipelinesFormat.summary, + format_: mlrun.common.schemas.PipelinesFormat = mlrun.common.schemas.PipelinesFormat.summary, ): kfp_url = mlrun.mlconf.resolve_kfp_url(namespace) if not kfp_url: @@ -123,13 +124,16 @@ def get_pipeline( if project and project != "*": run_project = self.resolve_project_from_pipeline(run) if run_project != project: - raise mlrun.errors.MLRunInvalidArgumentError( + raise mlrun.errors.MLRunNotFoundError( f"Pipeline run with id {run_id} is not of project {project}" ) run = self._format_run( db_session, run, format_, api_run_detail.to_dict() ) - + except kfp_server_api.ApiException as exc: + mlrun.errors.raise_for_status_code(int(exc.status), err_to_str(exc)) + except mlrun.errors.MLRunHTTPStatusError: + raise except Exception as exc: raise mlrun.errors.MLRunRuntimeError( f"Failed getting kfp run: {err_to_str(exc)}" @@ -159,7 +163,6 @@ def create_pipeline( ) logger.debug("Writing pipeline to temp file", content_type=content_type) - print(str(data)) pipeline_file = tempfile.NamedTemporaryFile(suffix=content_type) with 
open(pipeline_file.name, "wb") as fp: @@ -201,7 +204,7 @@ def _format_runs( self, db_session: sqlalchemy.orm.Session, runs: typing.List[dict], - format_: mlrun.api.schemas.PipelinesFormat = mlrun.api.schemas.PipelinesFormat.metadata_only, + format_: mlrun.common.schemas.PipelinesFormat = mlrun.common.schemas.PipelinesFormat.metadata_only, ) -> typing.List[dict]: formatted_runs = [] for run in runs: @@ -212,15 +215,15 @@ def _format_run( self, db_session: sqlalchemy.orm.Session, run: dict, - format_: mlrun.api.schemas.PipelinesFormat = mlrun.api.schemas.PipelinesFormat.metadata_only, + format_: mlrun.common.schemas.PipelinesFormat = mlrun.common.schemas.PipelinesFormat.metadata_only, api_run_detail: typing.Optional[dict] = None, ) -> dict: run["project"] = self.resolve_project_from_pipeline(run) - if format_ == mlrun.api.schemas.PipelinesFormat.full: + if format_ == mlrun.common.schemas.PipelinesFormat.full: return run - elif format_ == mlrun.api.schemas.PipelinesFormat.metadata_only: + elif format_ == mlrun.common.schemas.PipelinesFormat.metadata_only: return { - k: str(v) + k: str(v) if v is not None else v for k, v in run.items() if k in [ @@ -235,9 +238,9 @@ def _format_run( "description", ] } - elif format_ == mlrun.api.schemas.PipelinesFormat.name_only: + elif format_ == mlrun.common.schemas.PipelinesFormat.name_only: return run.get("name") - elif format_ == mlrun.api.schemas.PipelinesFormat.summary: + elif format_ == mlrun.common.schemas.PipelinesFormat.summary: if not api_run_detail: raise mlrun.errors.MLRunRuntimeError( "The full kfp api_run_detail object is needed to generate the summary format" diff --git a/mlrun/api/crud/projects.py b/mlrun/api/crud/projects.py index 99dc1eeae2f9..85c24915967c 100644 --- a/mlrun/api/crud/projects.py +++ b/mlrun/api/crud/projects.py @@ -23,11 +23,12 @@ import mlrun.api.crud import mlrun.api.db.session -import mlrun.api.schemas +import mlrun.api.utils.events.events_factory as events_factory import 
mlrun.api.utils.projects.remotes.follower import mlrun.api.utils.singletons.db import mlrun.api.utils.singletons.k8s import mlrun.api.utils.singletons.scheduler +import mlrun.common.schemas import mlrun.errors import mlrun.utils.singleton from mlrun.utils import logger @@ -44,18 +45,38 @@ def __init__(self) -> None: } def create_project( - self, session: sqlalchemy.orm.Session, project: mlrun.api.schemas.Project + self, session: sqlalchemy.orm.Session, project: mlrun.common.schemas.Project ): - logger.debug("Creating project", project=project) + logger.debug( + "Creating project", + name=project.metadata.name, + owner=project.spec.owner, + created_time=project.metadata.created, + desired_state=project.spec.desired_state, + state=project.status.state, + function_amount=len(project.spec.functions or []), + artifact_amount=len(project.spec.artifacts or []), + workflows_amount=len(project.spec.workflows or []), + ) mlrun.api.utils.singletons.db.get_db().create_project(session, project) def store_project( self, session: sqlalchemy.orm.Session, name: str, - project: mlrun.api.schemas.Project, + project: mlrun.common.schemas.Project, ): - logger.debug("Storing project", name=name, project=project) + logger.debug( + "Storing project", + name=project.metadata.name, + owner=project.spec.owner, + created_time=project.metadata.created, + desired_state=project.spec.desired_state, + state=project.status.state, + function_amount=len(project.spec.functions or []), + artifact_amount=len(project.spec.artifacts or []), + workflows_amount=len(project.spec.workflows or []), + ) mlrun.api.utils.singletons.db.get_db().store_project(session, name, project) def patch_project( @@ -63,7 +84,7 @@ def patch_project( session: sqlalchemy.orm.Session, name: str, project: dict, - patch_mode: mlrun.api.schemas.PatchMode = mlrun.api.schemas.PatchMode.replace, + patch_mode: mlrun.common.schemas.PatchMode = mlrun.common.schemas.PatchMode.replace, ): logger.debug( "Patching project", name=name, 
project=project, patch_mode=patch_mode @@ -76,12 +97,12 @@ def delete_project( self, session: sqlalchemy.orm.Session, name: str, - deletion_strategy: mlrun.api.schemas.DeletionStrategy = mlrun.api.schemas.DeletionStrategy.default(), + deletion_strategy: mlrun.common.schemas.DeletionStrategy = mlrun.common.schemas.DeletionStrategy.default(), ): logger.debug("Deleting project", name=name, deletion_strategy=deletion_strategy) if ( deletion_strategy.is_restricted() - or deletion_strategy == mlrun.api.schemas.DeletionStrategy.check + or deletion_strategy == mlrun.common.schemas.DeletionStrategy.check ): if not mlrun.api.utils.singletons.db.get_db().is_project_exists( session, name @@ -91,7 +112,7 @@ def delete_project( session, name ) self._verify_project_has_no_external_resources(name) - if deletion_strategy == mlrun.api.schemas.DeletionStrategy.check: + if deletion_strategy == mlrun.common.schemas.DeletionStrategy.check: return elif deletion_strategy.is_cascading(): self.delete_project_resources(session, name) @@ -114,7 +135,7 @@ def _verify_project_has_no_external_resources(self, project: str): # Therefore, this check should remain at the end of the verification flow. 
if ( mlrun.mlconf.is_api_running_on_k8s() - and mlrun.api.utils.singletons.k8s.get_k8s().get_project_secret_keys( + and mlrun.api.utils.singletons.k8s.get_k8s_helper().get_project_secret_keys( project ) ): @@ -142,7 +163,7 @@ def delete_project_resources( # log collector service will delete the logs, so we don't need to do it here if ( mlrun.mlconf.log_collector.mode - == mlrun.api.schemas.LogsCollectorMode.legacy + == mlrun.common.schemas.LogsCollectorMode.legacy ): mlrun.api.crud.Logs().delete_logs(name) @@ -156,22 +177,32 @@ def delete_project_resources( # delete project secrets - passing None will delete all secrets if mlrun.mlconf.is_api_running_on_k8s(): - mlrun.api.utils.singletons.k8s.get_k8s().delete_project_secrets(name, None) + secrets = None + ( + secret_name, + _, + ) = mlrun.api.utils.singletons.k8s.get_k8s_helper().delete_project_secrets( + name, secrets + ) + events_client = events_factory.EventsFactory().get_events_client() + events_client.emit( + events_client.generate_project_secret_deleted_event(name, secret_name) + ) def get_project( self, session: sqlalchemy.orm.Session, name: str - ) -> mlrun.api.schemas.Project: + ) -> mlrun.common.schemas.Project: return mlrun.api.utils.singletons.db.get_db().get_project(session, name) def list_projects( self, session: sqlalchemy.orm.Session, owner: str = None, - format_: mlrun.api.schemas.ProjectsFormat = mlrun.api.schemas.ProjectsFormat.full, + format_: mlrun.common.schemas.ProjectsFormat = mlrun.common.schemas.ProjectsFormat.full, labels: typing.List[str] = None, - state: mlrun.api.schemas.ProjectState = None, + state: mlrun.common.schemas.ProjectState = None, names: typing.Optional[typing.List[str]] = None, - ) -> mlrun.api.schemas.ProjectsOutput: + ) -> mlrun.common.schemas.ProjectsOutput: return mlrun.api.utils.singletons.db.get_db().list_projects( session, owner, format_, labels, state, names ) @@ -181,14 +212,14 @@ async def list_project_summaries( session: sqlalchemy.orm.Session, owner: str = None, 
labels: typing.List[str] = None, - state: mlrun.api.schemas.ProjectState = None, + state: mlrun.common.schemas.ProjectState = None, names: typing.Optional[typing.List[str]] = None, - ) -> mlrun.api.schemas.ProjectSummariesOutput: + ) -> mlrun.common.schemas.ProjectSummariesOutput: projects_output = await fastapi.concurrency.run_in_threadpool( self.list_projects, session, owner, - mlrun.api.schemas.ProjectsFormat.name_only, + mlrun.common.schemas.ProjectsFormat.name_only, labels, state, names, @@ -196,13 +227,13 @@ async def list_project_summaries( project_summaries = await self.generate_projects_summaries( projects_output.projects ) - return mlrun.api.schemas.ProjectSummariesOutput( + return mlrun.common.schemas.ProjectSummariesOutput( project_summaries=project_summaries ) async def get_project_summary( self, session: sqlalchemy.orm.Session, name: str - ) -> mlrun.api.schemas.ProjectSummary: + ) -> mlrun.common.schemas.ProjectSummary: # Call get project so we'll explode if project doesn't exists await fastapi.concurrency.run_in_threadpool(self.get_project, session, name) project_summaries = await self.generate_projects_summaries([name]) @@ -210,7 +241,7 @@ async def get_project_summary( async def generate_projects_summaries( self, projects: typing.List[str] - ) -> typing.List[mlrun.api.schemas.ProjectSummary]: + ) -> typing.List[mlrun.common.schemas.ProjectSummary]: ( project_to_files_count, project_to_schedule_count, @@ -223,7 +254,7 @@ async def generate_projects_summaries( project_summaries = [] for project in projects: project_summaries.append( - mlrun.api.schemas.ProjectSummary( + mlrun.common.schemas.ProjectSummary( name=project, files_count=project_to_files_count.get(project, 0), schedules_count=project_to_schedule_count.get(project, 0), @@ -294,7 +325,7 @@ async def _get_project_resources_counters( @staticmethod def _list_pipelines( session, - format_: mlrun.api.schemas.PipelinesFormat = mlrun.api.schemas.PipelinesFormat.metadata_only, + format_: 
mlrun.common.schemas.PipelinesFormat = mlrun.common.schemas.PipelinesFormat.metadata_only, ): return mlrun.api.crud.Pipelines().list_pipelines(session, "*", format_=format_) diff --git a/mlrun/api/crud/runs.py b/mlrun/api/crud/runs.py index e2b13972aba6..c99ce0f35023 100644 --- a/mlrun/api/crud/runs.py +++ b/mlrun/api/crud/runs.py @@ -16,10 +16,10 @@ import sqlalchemy.orm -import mlrun.api.schemas import mlrun.api.utils.projects.remotes.follower import mlrun.api.utils.singletons.db import mlrun.api.utils.singletons.project_member +import mlrun.common.schemas import mlrun.config import mlrun.errors import mlrun.lists @@ -41,7 +41,6 @@ def store_run( project: str = mlrun.mlconf.default_project, ): project = project or mlrun.mlconf.default_project - logger.info("Storing run", data=data) mlrun.api.utils.singletons.db.get_db().store_run( db_session, data, @@ -59,7 +58,7 @@ def update_run( data: dict, ): project = project or mlrun.mlconf.default_project - logger.debug("Updating run", project=project, uid=uid, iter=iter, data=data) + logger.debug("Updating run", project=project, uid=uid, iter=iter) # TODO: do some desired state for run, it doesn't make sense that API user changes the status in order to # trigger abortion if ( @@ -120,10 +119,10 @@ def list_runs( start_time_to=None, last_update_time_from=None, last_update_time_to=None, - partition_by: mlrun.api.schemas.RunPartitionByField = None, + partition_by: mlrun.common.schemas.RunPartitionByField = None, rows_per_partition: int = 1, - partition_sort_by: mlrun.api.schemas.SortField = None, - partition_order: mlrun.api.schemas.OrderType = mlrun.api.schemas.OrderType.desc, + partition_sort_by: mlrun.common.schemas.SortField = None, + partition_order: mlrun.common.schemas.OrderType = mlrun.common.schemas.OrderType.desc, max_partitions: int = 0, requested_logs: bool = None, return_as_run_structs: bool = True, diff --git a/mlrun/api/crud/runtime_resources.py b/mlrun/api/crud/runtime_resources.py index 
33ac0556d6b2..fd92a5c3e8f0 100644 --- a/mlrun/api/crud/runtime_resources.py +++ b/mlrun/api/crud/runtime_resources.py @@ -18,9 +18,9 @@ import sqlalchemy.orm import mlrun.api.api.utils -import mlrun.api.schemas import mlrun.api.utils.projects.remotes.follower import mlrun.api.utils.singletons.db +import mlrun.common.schemas import mlrun.config import mlrun.errors import mlrun.runtimes @@ -37,12 +37,12 @@ def list_runtime_resources( object_id: typing.Optional[str] = None, label_selector: typing.Optional[str] = None, group_by: typing.Optional[ - mlrun.api.schemas.ListRuntimeResourcesGroupByField + mlrun.common.schemas.ListRuntimeResourcesGroupByField ] = None, ) -> typing.Union[ - mlrun.api.schemas.RuntimeResourcesOutput, - mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, - mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, + mlrun.common.schemas.RuntimeResourcesOutput, + mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput, + mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, ]: response = [] if group_by is None else {} kinds = mlrun.runtimes.RuntimeKinds.runtime_with_handlers() @@ -56,7 +56,7 @@ def list_runtime_resources( ) if group_by is None: response.append( - mlrun.api.schemas.KindRuntimeResources( + mlrun.common.schemas.KindRuntimeResources( kind=kind, resources=resources ) ) @@ -66,15 +66,15 @@ def list_runtime_resources( def filter_and_format_grouped_by_project_runtime_resources_output( self, - grouped_by_project_runtime_resources_output: mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, + grouped_by_project_runtime_resources_output: mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, allowed_projects: typing.List[str], group_by: typing.Optional[ - mlrun.api.schemas.ListRuntimeResourcesGroupByField + mlrun.common.schemas.ListRuntimeResourcesGroupByField ] = None, ) -> typing.Union[ - mlrun.api.schemas.RuntimeResourcesOutput, - mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, - 
mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, + mlrun.common.schemas.RuntimeResourcesOutput, + mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput, + mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, ]: runtime_resources_by_kind = {} for ( @@ -94,7 +94,7 @@ def filter_and_format_grouped_by_project_runtime_resources_output( ) if group_by is None: runtimes_resources_output.append( - mlrun.api.schemas.KindRuntimeResources( + mlrun.common.schemas.KindRuntimeResources( kind=kind, resources=resources ) ) diff --git a/mlrun/api/db/filedb/__init__.py b/mlrun/api/crud/runtimes/__init__.py similarity index 100% rename from mlrun/api/db/filedb/__init__.py rename to mlrun/api/crud/runtimes/__init__.py diff --git a/tests/api/api/marketplace/__init__.py b/mlrun/api/crud/runtimes/nuclio/__init__.py similarity index 100% rename from tests/api/api/marketplace/__init__.py rename to mlrun/api/crud/runtimes/nuclio/__init__.py diff --git a/mlrun/api/crud/runtimes/nuclio/function.py b/mlrun/api/crud/runtimes/nuclio/function.py new file mode 100644 index 000000000000..1b0b3ca271b3 --- /dev/null +++ b/mlrun/api/crud/runtimes/nuclio/function.py @@ -0,0 +1,500 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import base64 +import shlex + +import nuclio +import nuclio.utils +import requests + +import mlrun +import mlrun.api.crud.runtimes.nuclio.helpers +import mlrun.api.utils.builder +import mlrun.api.utils.singletons.k8s +import mlrun.common.schemas +import mlrun.datastore +import mlrun.errors +import mlrun.runtimes.function +import mlrun.runtimes.pod +import mlrun.utils +from mlrun.utils import logger + + +def deploy_nuclio_function( + function: mlrun.runtimes.function.RemoteRuntime, + auth_info: mlrun.common.schemas.AuthInfo = None, + client_version: str = None, + builder_env: dict = None, + client_python_version: str = None, +): + """Deploys a nuclio function. + + :param function: nuclio function object + :param auth_info: service AuthInfo + :param client_version: mlrun client version + :param builder_env: mlrun builder environment (for config/credentials) + :param client_python_version: mlrun client python version + """ + function_name, project_name, function_config = _compile_function_config( + function, + client_version=client_version, + client_python_version=client_python_version, + builder_env=builder_env or {}, + auth_info=auth_info, + ) + + # if mode allows it, enrich function http trigger with an ingress + mlrun.api.crud.runtimes.nuclio.helpers.enrich_function_with_ingress( + function_config, + function.spec.add_templated_ingress_host_mode + or mlrun.mlconf.httpdb.nuclio.add_templated_ingress_host_mode, + function.spec.service_type or mlrun.mlconf.httpdb.nuclio.default_service_type, + ) + + try: + logger.info( + "Starting Nuclio function deployment", + function_name=function_name, + project_name=project_name, + ) + return nuclio.deploy.deploy_config( + function_config, + dashboard_url=mlrun.mlconf.nuclio_dashboard_url, + name=function_name, + project=project_name, + tag=function.metadata.tag, + verbose=function.verbose, + create_new=True, + watch=False, + return_address_mode=nuclio.deploy.ReturnAddressModes.all, + 
auth_info=auth_info.to_nuclio_auth_info() if auth_info else None, + ) + except nuclio.utils.DeployError as exc: + if exc.err: + err_message = ( + f"Failed to deploy nuclio function {project_name}/{function_name}" + ) + + try: + + # the error might not be jsonable, so we'll try to parse it + # and extract the error message + json_err = exc.err.response.json() + if "error" in json_err: + err_message += f" {json_err['error']}" + if "errorStackTrace" in json_err: + logger.warning( + "Failed to deploy nuclio function", + nuclio_stacktrace=json_err["errorStackTrace"], + ) + except Exception as parse_exc: + logger.warning( + "Failed to parse nuclio deploy error", + parse_exc=mlrun.errors.err_to_str(parse_exc), + ) + + mlrun.errors.raise_for_status( + exc.err.response, + err_message, + ) + raise + + +def get_nuclio_deploy_status( + name, + project, + tag, + last_log_timestamp=0, + verbose=False, + resolve_address=True, + auth_info: mlrun.common.schemas.AuthInfo = None, +): + """ + Get nuclio function deploy status + + :param name: function name + :param project: project name + :param tag: function tag + :param last_log_timestamp: last log timestamp + :param verbose: print logs + :param resolve_address: whether to resolve function address + :param auth_info: authentication information + """ + api_address = nuclio.deploy.find_dashboard_url(mlrun.mlconf.nuclio_dashboard_url) + name = mlrun.runtimes.function.get_fullname(name, project, tag) + get_err_message = f"Failed to get function {name} deploy status" + + try: + ( + state, + address, + last_log_timestamp, + outputs, + function_status, + ) = nuclio.deploy.get_deploy_status( + api_address, + name, + last_log_timestamp, + verbose, + resolve_address, + return_function_status=True, + auth_info=auth_info.to_nuclio_auth_info() if auth_info else None, + ) + except requests.exceptions.ConnectionError as exc: + mlrun.errors.raise_for_status( + exc.response, + get_err_message, + ) + + except nuclio.utils.DeployError as exc: + if 
exc.err: + mlrun.errors.raise_for_status( + exc.err.response, + get_err_message, + ) + raise exc + else: + text = "\n".join(outputs) if outputs else "" + return state, address, name, last_log_timestamp, text, function_status + + +def _compile_function_config( + function: mlrun.runtimes.function.RemoteRuntime, + client_version: str = None, + client_python_version: str = None, + builder_env=None, + auth_info=None, +): + _set_function_labels(function) + + # resolve env vars before compiling the nuclio spec, as we need to set them in the spec + env_dict, external_source_env_dict = _resolve_env_vars(function) + + nuclio_spec = nuclio.ConfigSpec( + env=env_dict, + external_source_env=external_source_env_dict, + config=function.spec.config, + ) + nuclio_spec.cmd = function.spec.build.commands or [] + + _resolve_and_set_build_requirements(function, nuclio_spec) + _resolve_and_set_nuclio_runtime( + function, nuclio_spec, client_version, client_python_version + ) + + project = function.metadata.project or "default" + tag = function.metadata.tag + handler = function.spec.function_handler + + _set_build_params(function, nuclio_spec, builder_env, project, auth_info) + _set_function_scheduling_params(function, nuclio_spec) + _set_function_replicas(function, nuclio_spec) + _set_misc_specs(function, nuclio_spec) + + # if the user code is given explicitly or from a source, we need to set the handler and relevant attributes + if ( + function.spec.base_spec + or function.spec.build.functionSourceCode + or function.spec.build.source + or function.kind == mlrun.runtimes.RuntimeKinds.serving # serving can be empty + ): + config = function.spec.base_spec + if not config: + # if base_spec was not set (when not using code_to_function) and we have base64 code + # we create the base spec with essential attributes + config = nuclio.config.new_config() + mlrun.utils.update_in(config, "spec.handler", handler or "main:handler") + + config = nuclio.config.extend_config( + config, nuclio_spec, 
tag, function.spec.build.code_origin + ) + + if ( + function.kind == mlrun.runtimes.RuntimeKinds.serving + and not mlrun.utils.get_in(config, "spec.build.functionSourceCode") + ): + _set_source_code_and_handler(function, config) + else: + # this may also be called in case of using single file code_to_function(embed_code=False) + # this option need to be removed or be limited to using remote files (this code runs in server) + function_name, config, code = nuclio.build_file( + function.spec.source, + name=function.metadata.name, + project=project, + handler=handler, + tag=tag, + spec=nuclio_spec, + kind=function.spec.function_kind, + verbose=function.verbose, + ) + + mlrun.utils.update_in( + config, "spec.volumes", function.spec.generate_nuclio_volumes() + ) + + _resolve_and_set_base_image(function, config, client_version, client_python_version) + function_name = _set_function_name(function, config, project, tag) + + return function_name, project, config + + +def _set_function_labels(function): + labels = function.metadata.labels or {} + labels.update({"mlrun/class": function.kind}) + for key, value in labels.items(): + # Adding escaping to the key to prevent it from being split by dots if it contains any + function.set_config(f"metadata.labels.\\{key}\\", value) + + +def _resolve_env_vars(function): + # Add secret configurations to function's pod spec, if secret sources were added. + # Needs to be here, since it adds env params, which are handled in the next lines. + # This only needs to run if we're running within k8s context. If running in Docker, for example, skip. 
+ if mlrun.api.utils.singletons.k8s.get_k8s_helper( + silent=True + ).is_running_inside_kubernetes_cluster(): + function.add_secrets_config_to_spec() + + env_dict, external_source_env_dict = function._get_nuclio_config_spec_env() + + # In nuclio 1.6.0<=v<1.8.0, python runtimes default behavior was to not decode event strings + # Our code is counting on the strings to be decoded, so add the needed env var for those versions + if ( + mlrun.api.crud.runtimes.nuclio.helpers.is_nuclio_version_in_range( + "1.6.0", "1.8.0" + ) + and "NUCLIO_PYTHON_DECODE_EVENT_STRINGS" not in env_dict + ): + env_dict["NUCLIO_PYTHON_DECODE_EVENT_STRINGS"] = "true" + + return env_dict, external_source_env_dict + + +def _resolve_and_set_nuclio_runtime( + function, nuclio_spec, client_version, client_python_version +): + nuclio_runtime = ( + function.spec.nuclio_runtime + or mlrun.api.crud.runtimes.nuclio.helpers.resolve_nuclio_runtime_python_image( + mlrun_client_version=client_version, python_version=client_python_version + ) + ) + + # For backwards compatibility, we need to adjust the runtime for old Nuclio versions + if mlrun.api.crud.runtimes.nuclio.helpers.is_nuclio_version_in_range( + "0.0.0", "1.6.0" + ) and nuclio_runtime in [ + "python:3.7", + "python:3.8", + ]: + nuclio_runtime_set_from_spec = nuclio_runtime == function.spec.nuclio_runtime + if nuclio_runtime_set_from_spec: + raise mlrun.errors.MLRunInvalidArgumentError( + f"Nuclio version does not support the configured runtime: {nuclio_runtime}" + ) + else: + # our default is python:3.9, simply set it to python:3.6 to keep supporting envs with old Nuclio + nuclio_runtime = "python:3.6" + + nuclio_spec.set_config("spec.runtime", nuclio_runtime) + + +def _resolve_and_set_build_requirements(function, nuclio_spec): + if function.spec.build.requirements: + resolved_requirements = [] + # wrap in single quote to ensure that the requirement is treated as a single string + # quote the requirement to avoid issues with special characters, 
double quotes, etc. + for requirement in function.spec.build.requirements: + # -r / --requirement are flags and should not be escaped + # we allow such flags (could be passed within the requirements.txt file) and do not + # try to open the file and include its content since it might be a remote file + # given on the base image. + for req_flag in ["-r", "--requirement"]: + if requirement.startswith(req_flag): + requirement = requirement[len(req_flag) :].strip() + resolved_requirements.append(req_flag) + break + + resolved_requirements.append(shlex.quote(requirement)) + + encoded_requirements = " ".join(resolved_requirements) + nuclio_spec.cmd.append(f"python -m pip install {encoded_requirements}") + + +def _set_build_params(function, nuclio_spec, builder_env, project, auth_info=None): + # handle archive build params + if function.spec.build.source: + mlrun.api.crud.runtimes.nuclio.helpers.compile_nuclio_archive_config( + nuclio_spec, function, builder_env, project, auth_info=auth_info + ) + + if function.spec.no_cache: + nuclio_spec.set_config("spec.build.noCache", True) + if function.spec.build.functionSourceCode: + nuclio_spec.set_config( + "spec.build.functionSourceCode", function.spec.build.functionSourceCode + ) + + image_pull_secret = ( + mlrun.api.crud.runtimes.nuclio.helpers.resolve_function_image_pull_secret( + function + ) + ) + if image_pull_secret: + nuclio_spec.set_config("spec.imagePullSecrets", image_pull_secret) + + if function.spec.base_image_pull: + nuclio_spec.set_config("spec.build.noBaseImagesPull", False) + + +def _set_function_scheduling_params(function, nuclio_spec): + # don't send node selections if nuclio is not compatible + if mlrun.runtimes.function.validate_nuclio_version_compatibility( + "1.5.20", "1.6.10" + ): + if function.spec.node_selector: + nuclio_spec.set_config("spec.nodeSelector", function.spec.node_selector) + if function.spec.node_name: + nuclio_spec.set_config("spec.nodeName", function.spec.node_name) + if 
function.spec.affinity: + nuclio_spec.set_config( + "spec.affinity", + mlrun.runtimes.pod.get_sanitized_attribute(function.spec, "affinity"), + ) + + # don't send tolerations if nuclio is not compatible + if mlrun.runtimes.function.validate_nuclio_version_compatibility("1.7.5"): + if function.spec.tolerations: + nuclio_spec.set_config( + "spec.tolerations", + mlrun.runtimes.pod.get_sanitized_attribute( + function.spec, "tolerations" + ), + ) + # don't send preemption_mode if nuclio is not compatible + if mlrun.runtimes.function.validate_nuclio_version_compatibility("1.8.6"): + if function.spec.preemption_mode: + nuclio_spec.set_config( + "spec.PreemptionMode", + function.spec.preemption_mode, + ) + + +def _set_function_replicas(function, nuclio_spec): + if function.spec.replicas: + nuclio_spec.set_config( + "spec.minReplicas", + mlrun.utils.as_number("spec.Replicas", function.spec.replicas), + ) + nuclio_spec.set_config( + "spec.maxReplicas", + mlrun.utils.as_number("spec.Replicas", function.spec.replicas), + ) + else: + nuclio_spec.set_config( + "spec.minReplicas", + mlrun.utils.as_number("spec.minReplicas", function.spec.min_replicas), + ) + nuclio_spec.set_config( + "spec.maxReplicas", + mlrun.utils.as_number("spec.maxReplicas", function.spec.max_replicas), + ) + + +def _set_misc_specs(function, nuclio_spec): + # in Nuclio >= 1.6.x default serviceType has changed to "ClusterIP". 
+ nuclio_spec.set_config( + "spec.serviceType", + function.spec.service_type or mlrun.mlconf.httpdb.nuclio.default_service_type, + ) + if function.spec.readiness_timeout: + nuclio_spec.set_config( + "spec.readinessTimeoutSeconds", function.spec.readiness_timeout + ) + if function.spec.resources: + nuclio_spec.set_config("spec.resources", function.spec.resources) + + # don't send default or any priority class name if nuclio is not compatible + if ( + function.spec.priority_class_name + and mlrun.runtimes.function.validate_nuclio_version_compatibility("1.6.18") + and len(mlrun.mlconf.get_valid_function_priority_class_names()) + ): + nuclio_spec.set_config( + "spec.priorityClassName", function.spec.priority_class_name + ) + + if function.spec.service_account: + nuclio_spec.set_config("spec.serviceAccount", function.spec.service_account) + + if function.spec.security_context: + nuclio_spec.set_config( + "spec.securityContext", + mlrun.runtimes.pod.get_sanitized_attribute( + function.spec, "security_context" + ), + ) + + +def _set_source_code_and_handler(function, config): + if not function.spec.build.source: + # set the source to the mlrun serving wrapper + body = nuclio.build.mlrun_footer.format(mlrun.runtimes.serving.serving_subkind) + mlrun.utils.update_in( + config, + "spec.build.functionSourceCode", + base64.b64encode(body.encode("utf-8")).decode("utf-8"), + ) + elif not function.spec.function_handler: + # point the nuclio function handler to mlrun serving wrapper handlers + mlrun.utils.update_in( + config, + "spec.handler", + "mlrun.serving.serving_wrapper:handler", + ) + + +def _resolve_and_set_base_image( + function, config, client_version, client_python_version +): + base_image = ( + mlrun.utils.get_in(config, "spec.build.baseImage") + or function.spec.image + or function.spec.build.base_image + ) + if base_image: + # we ignore the returned registry secret as nuclio uses the image pull secret, which is resolved in the + # build params + ( + base_image, + _, + 
) = mlrun.api.utils.builder.resolve_image_target_and_registry_secret( + base_image, secret_name=function.spec.build.secret + ) + mlrun.utils.update_in( + config, + "spec.build.baseImage", + mlrun.utils.enrich_image_url( + base_image, client_version, client_python_version + ), + ) + + +def _set_function_name(function, config, project, tag): + name = mlrun.runtimes.function.get_fullname(function.metadata.name, project, tag) + function.status.nuclio_name = name + mlrun.utils.update_in(config, "metadata.name", name) + return name diff --git a/mlrun/api/crud/runtimes/nuclio/helpers.py b/mlrun/api/crud/runtimes/nuclio/helpers.py new file mode 100644 index 000000000000..5fc746843444 --- /dev/null +++ b/mlrun/api/crud/runtimes/nuclio/helpers.py @@ -0,0 +1,310 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import urllib.parse + +import semver + +import mlrun +import mlrun.api.utils.singletons.k8s +import mlrun.runtimes +from mlrun.utils import logger + + +def resolve_function_http_trigger(function_spec): + for trigger_name, trigger_config in function_spec.get("triggers", {}).items(): + if trigger_config.get("kind") != "http": + continue + return trigger_config + + +def resolve_nuclio_runtime_python_image( + mlrun_client_version: str = None, python_version: str = None +): + # if no python version or mlrun version is passed it means we use mlrun client older than 1.3.0 therefore need + # to use the previoud default runtime which is python 3.7 + if not python_version or not mlrun_client_version: + return "python:3.7" + + # If the mlrun version is 0.0.0-, it is a dev version, + # so we can't check if it is higher than 1.3.0, but if the python version was passed, + # it means it is 1.3.0-rc or higher, so use the image according to the python version + if mlrun_client_version.startswith("0.0.0-") or "unstable" in mlrun_client_version: + if python_version.startswith("3.7"): + return "python:3.7" + + return mlrun.mlconf.default_nuclio_runtime + + # if mlrun version is older than 1.3.0 we need to use the previous default runtime which is python 3.7 + if semver.VersionInfo.parse(mlrun_client_version) < semver.VersionInfo.parse( + "1.3.0-X" + ): + return "python:3.7" + + # if mlrun version is 1.3.0 or newer and python version is 3.7 we need to use python 3.7 image + if semver.VersionInfo.parse(mlrun_client_version) >= semver.VersionInfo.parse( + "1.3.0-X" + ) and python_version.startswith("3.7"): + return "python:3.7" + + # if none of the above conditions are met we use the default runtime which is python 3.9 + return mlrun.mlconf.default_nuclio_runtime + + +def resolve_function_ingresses(function_spec): + http_trigger = resolve_function_http_trigger(function_spec) + if not http_trigger: + return [] + + ingresses = [] + for _, ingress_config in ( + 
http_trigger.get("attributes", {}).get("ingresses", {}).items() + ): + ingresses.append(ingress_config) + return ingresses + + +def enrich_function_with_ingress(config, mode, service_type): + # do not enrich with an ingress + if mode == mlrun.runtimes.constants.NuclioIngressAddTemplatedIngressModes.never: + return + + ingresses = resolve_function_ingresses(config["spec"]) + + # function has ingresses already, nothing to add / enrich + if ingresses: + return + + # if exists, get the http trigger the function has + # we would enrich it with an ingress + http_trigger = resolve_function_http_trigger(config["spec"]) + if not http_trigger: + # function has an HTTP trigger without an ingress + # TODO: read from nuclio-api frontend-spec + http_trigger = { + "kind": "http", + "name": "http", + "maxWorkers": 1, + "workerAvailabilityTimeoutMilliseconds": 10000, # 10 seconds + "attributes": {}, + } + + def enrich(): + http_trigger.setdefault("attributes", {}).setdefault("ingresses", {})["0"] = { + "paths": ["/"], + # this would tell Nuclio to use its default ingress host template + # and would auto assign a host for the ingress + "hostTemplate": "@nuclio.fromDefault", + } + http_trigger["attributes"]["serviceType"] = service_type + config["spec"].setdefault("triggers", {})[http_trigger["name"]] = http_trigger + + if mode == mlrun.runtimes.constants.NuclioIngressAddTemplatedIngressModes.always: + enrich() + elif ( + mode + == mlrun.runtimes.constants.NuclioIngressAddTemplatedIngressModes.on_cluster_ip + ): + + # service type is not cluster ip, bail out + if service_type and service_type.lower() != "clusterip": + return + + enrich() + + +def resolve_function_image_pull_secret(function): + """ + the corresponding attribute for 'build.secret' in nuclio is imagePullSecrets, attached link for reference + https://github.com/nuclio/nuclio/blob/e4af2a000dc52ee17337e75181ecb2652b9bf4e5/pkg/processor/build/builder.go#L1073 + if only one of the secrets is set, use it. 
+ if both are set, use the non default one and give precedence to image_pull_secret + """ + # enrich only on server side + if not mlrun.config.is_running_as_api(): + return function.spec.image_pull_secret or function.spec.build.secret + + if function.spec.image_pull_secret is None: + function.spec.image_pull_secret = ( + mlrun.mlconf.function.spec.image_pull_secret.default + ) + elif ( + function.spec.image_pull_secret + != mlrun.mlconf.function.spec.image_pull_secret.default + ): + return function.spec.image_pull_secret + + if function.spec.build.secret is None: + function.spec.build.secret = mlrun.mlconf.httpdb.builder.docker_registry_secret + elif ( + function.spec.build.secret != mlrun.mlconf.httpdb.builder.docker_registry_secret + ): + return function.spec.build.secret + + return function.spec.image_pull_secret or function.spec.build.secret + + +def resolve_work_dir_and_handler(handler): + """ + Resolves a nuclio function working dir and handler inside an archive/git repo + :param handler: a path describing working dir and handler of a nuclio function + :return: (working_dir, handler) tuple, as nuclio expects to get it + + Example: ("a/b/c#main:Handler") -> ("a/b/c", "main:Handler") + """ + + def extend_handler(base_handler): + # return default handler and module if not specified + if not base_handler: + return "main:handler" + if ":" not in base_handler: + base_handler = f"{base_handler}:handler" + return base_handler + + if not handler: + return "", "main:handler" + + split_handler = handler.split("#") + if len(split_handler) == 1: + return "", extend_handler(handler) + + return split_handler[0], extend_handler(split_handler[1]) + + +def is_nuclio_version_in_range(min_version: str, max_version: str) -> bool: + """ + Return whether the Nuclio version is in the range, inclusive for min, exclusive for max - [min, max) + """ + resolved_nuclio_version = None + try: + parsed_min_version = semver.VersionInfo.parse(min_version) + parsed_max_version = 
semver.VersionInfo.parse(max_version) + resolved_nuclio_version = mlrun.runtimes.utils.resolve_nuclio_version() + parsed_current_version = semver.VersionInfo.parse(resolved_nuclio_version) + except ValueError: + logger.warning( + "Unable to parse nuclio version, assuming in range", + nuclio_version=resolved_nuclio_version, + min_version=min_version, + max_version=max_version, + ) + return True + return parsed_min_version <= parsed_current_version < parsed_max_version + + +def compile_nuclio_archive_config( + nuclio_spec, + function: mlrun.runtimes.function.RemoteRuntime, + builder_env, + project=None, + auth_info=None, +): + secrets = {} + if ( + project + and mlrun.api.utils.singletons.k8s.get_k8s_helper().is_running_inside_kubernetes_cluster() + ): + secrets = ( + mlrun.api.utils.singletons.k8s.get_k8s_helper().get_project_secret_data( + project + ) + ) + + def get_secret(key): + return builder_env.get(key) or secrets.get(key, "") + + source = function.spec.build.source + parsed_url = urllib.parse.urlparse(source) + code_entry_type = "" + if source.startswith("s3://"): + code_entry_type = "s3" + if source.startswith("git://"): + code_entry_type = "git" + for archive_prefix in ["http://", "https://", "v3io://", "v3ios://"]: + if source.startswith(archive_prefix): + code_entry_type = "archive" + + if code_entry_type == "": + raise mlrun.errors.MLRunInvalidArgumentError( + "Couldn't resolve code entry type from source" + ) + + code_entry_attributes = {} + + # resolve work_dir and handler + work_dir, handler = resolve_work_dir_and_handler(function.spec.function_handler) + work_dir = function.spec.workdir or work_dir + if work_dir != "": + code_entry_attributes["workDir"] = work_dir + + # archive + if code_entry_type == "archive": + v3io_access_key = builder_env.get("V3IO_ACCESS_KEY", "") + if source.startswith("v3io"): + if not parsed_url.netloc: + source = mlrun.mlconf.v3io_api + parsed_url.path + else: + source = f"http{source[len('v3io'):]}" + if auth_info and not 
v3io_access_key: + v3io_access_key = auth_info.data_session or auth_info.access_key + + if v3io_access_key: + code_entry_attributes["headers"] = {"X-V3io-Session-Key": v3io_access_key} + + # s3 + if code_entry_type == "s3": + bucket, item_key = mlrun.datastore.parse_s3_bucket_and_key(source) + + code_entry_attributes["s3Bucket"] = bucket + code_entry_attributes["s3ItemKey"] = item_key + + code_entry_attributes["s3AccessKeyId"] = get_secret("AWS_ACCESS_KEY_ID") + code_entry_attributes["s3SecretAccessKey"] = get_secret("AWS_SECRET_ACCESS_KEY") + code_entry_attributes["s3SessionToken"] = get_secret("AWS_SESSION_TOKEN") + + # git + if code_entry_type == "git": + + # change git:// to https:// as nuclio expects it to be + if source.startswith("git://"): + source = source.replace("git://", "https://") + + source, reference, branch = mlrun.utils.resolve_git_reference_from_source( + source + ) + if not branch and not reference: + raise mlrun.errors.MLRunInvalidArgumentError( + "git branch or refs must be specified in the source e.g.: " + "'git:///org/repo.git#'" + ) + if reference: + code_entry_attributes["reference"] = reference + if branch: + code_entry_attributes["branch"] = branch + + password = get_secret("GIT_PASSWORD") + username = get_secret("GIT_USERNAME") + + token = get_secret("GIT_TOKEN") + if token: + username, password = mlrun.utils.get_git_username_password_from_token(token) + + code_entry_attributes["username"] = username + code_entry_attributes["password"] = password + + # populate spec with relevant fields + nuclio_spec.set_config("spec.handler", handler) + nuclio_spec.set_config("spec.build.path", source) + nuclio_spec.set_config("spec.build.codeEntryType", code_entry_type) + nuclio_spec.set_config("spec.build.codeEntryAttributes", code_entry_attributes) diff --git a/mlrun/api/crud/secrets.py b/mlrun/api/crud/secrets.py index ddcea5ef9025..1f91372cce8a 100644 --- a/mlrun/api/crud/secrets.py +++ b/mlrun/api/crud/secrets.py @@ -17,8 +17,10 @@ import typing 
import uuid -import mlrun.api.schemas +import mlrun.api.utils.clients.iguazio +import mlrun.api.utils.events.events_factory as events_factory import mlrun.api.utils.singletons.k8s +import mlrun.common.schemas import mlrun.errors import mlrun.utils.helpers import mlrun.utils.regex @@ -30,7 +32,7 @@ class SecretsClientType(str, enum.Enum): schedules = "schedules" model_monitoring = "model-monitoring" service_accounts = "service-accounts" - marketplace = "marketplace" + hub = "hub" notifications = "notifications" @@ -74,7 +76,7 @@ def validate_internal_project_secret_key_allowed( def store_project_secrets( self, project: str, - secrets: mlrun.api.schemas.SecretsData, + secrets: mlrun.common.schemas.SecretsData, allow_internal_secrets: bool = False, key_map_secret_key: typing.Optional[str] = None, allow_storing_key_maps: bool = False, @@ -94,18 +96,33 @@ def store_project_secrets( allow_storing_key_maps, ) - if secrets.provider == mlrun.api.schemas.SecretProviderName.vault: + if secrets.provider == mlrun.common.schemas.SecretProviderName.vault: # Init is idempotent and will do nothing if infra is already in place mlrun.utils.vault.init_project_vault_configuration(project) # If no secrets were passed, no need to touch the actual secrets. 
if secrets_to_store: mlrun.utils.vault.store_vault_project_secrets(project, secrets_to_store) - elif secrets.provider == mlrun.api.schemas.SecretProviderName.kubernetes: - if mlrun.api.utils.singletons.k8s.get_k8s(): - mlrun.api.utils.singletons.k8s.get_k8s().store_project_secrets( + elif secrets.provider == mlrun.common.schemas.SecretProviderName.kubernetes: + if mlrun.api.utils.singletons.k8s.get_k8s_helper(): + ( + secret_name, + created, + ) = mlrun.api.utils.singletons.k8s.get_k8s_helper().store_project_secrets( project, secrets_to_store ) + secret_keys = [secret_name for secret_name in secrets_to_store.keys()] + + events_client = events_factory.EventsFactory().get_events_client() + event = events_client.generate_project_secret_event( + project=project, + secret_name=secret_name, + secret_keys=secret_keys, + action=mlrun.common.schemas.SecretEventActions.created + if created + else mlrun.common.schemas.SecretEventActions.updated, + ) + events_client.emit(event) else: raise mlrun.errors.MLRunInternalServerError( "K8s provider cannot be initialized" @@ -117,54 +134,69 @@ def store_project_secrets( def read_auth_secret( self, secret_name, raise_on_not_found=False - ) -> mlrun.api.schemas.AuthSecretData: + ) -> mlrun.common.schemas.AuthSecretData: ( username, access_key, - ) = mlrun.api.utils.singletons.k8s.get_k8s().read_auth_secret( + ) = mlrun.api.utils.singletons.k8s.get_k8s_helper().read_auth_secret( secret_name, raise_on_not_found=raise_on_not_found ) - return mlrun.api.schemas.AuthSecretData( - provider=mlrun.api.schemas.SecretProviderName.kubernetes, + return mlrun.common.schemas.AuthSecretData( + provider=mlrun.common.schemas.SecretProviderName.kubernetes, username=username, access_key=access_key, ) def store_auth_secret( self, - secret: mlrun.api.schemas.AuthSecretData, + secret: mlrun.common.schemas.AuthSecretData, ) -> str: - if secret.provider != mlrun.api.schemas.SecretProviderName.kubernetes: + if secret.provider != 
mlrun.common.schemas.SecretProviderName.kubernetes: raise mlrun.errors.MLRunInvalidArgumentError( f"Storing auth secret is not implemented for provider {secret.provider}" ) - if not mlrun.api.utils.singletons.k8s.get_k8s(): + if not mlrun.api.utils.singletons.k8s.get_k8s_helper(): raise mlrun.errors.MLRunInternalServerError( "K8s provider cannot be initialized" ) - return mlrun.api.utils.singletons.k8s.get_k8s().store_auth_secret( + ( + auth_secret_name, + created, + ) = mlrun.api.utils.singletons.k8s.get_k8s_helper().store_auth_secret( secret.username, secret.access_key ) + events_client = events_factory.EventsFactory().get_events_client() + event = events_client.generate_project_auth_secret_event( + username=secret.username, + secret_name=auth_secret_name, + action=mlrun.common.schemas.SecretEventActions.created + if created + else mlrun.common.schemas.SecretEventActions.updated, + ) + events_client.emit(event) + + return auth_secret_name + def delete_auth_secret( self, - provider: mlrun.api.schemas.SecretProviderName, + provider: mlrun.common.schemas.SecretProviderName, secret_name: str, ): - if provider != mlrun.api.schemas.SecretProviderName.kubernetes: + if provider != mlrun.common.schemas.SecretProviderName.kubernetes: raise mlrun.errors.MLRunInvalidArgumentError( f"Storing auth secret is not implemented for provider {provider}" ) - if not mlrun.api.utils.singletons.k8s.get_k8s(): + if not mlrun.api.utils.singletons.k8s.get_k8s_helper(): raise mlrun.errors.MLRunInternalServerError( "K8s provider cannot be initialized" ) - mlrun.api.utils.singletons.k8s.get_k8s().delete_auth_secret(secret_name) + mlrun.api.utils.singletons.k8s.get_k8s_helper().delete_auth_secret(secret_name) def delete_project_secrets( self, project: str, - provider: mlrun.api.schemas.SecretProviderName, + provider: mlrun.common.schemas.SecretProviderName, secrets: typing.Optional[typing.List[str]] = None, allow_internal_secrets: bool = False, ): @@ -186,15 +218,30 @@ def 
delete_project_secrets( # nothing to remove - return return - if provider == mlrun.api.schemas.SecretProviderName.vault: + if provider == mlrun.common.schemas.SecretProviderName.vault: raise mlrun.errors.MLRunInvalidArgumentError( f"Delete secret is not implemented for provider {provider}" ) - elif provider == mlrun.api.schemas.SecretProviderName.kubernetes: - if mlrun.api.utils.singletons.k8s.get_k8s(): - mlrun.api.utils.singletons.k8s.get_k8s().delete_project_secrets( + elif provider == mlrun.common.schemas.SecretProviderName.kubernetes: + if mlrun.api.utils.singletons.k8s.get_k8s_helper(): + ( + secret_name, + deleted, + ) = mlrun.api.utils.singletons.k8s.get_k8s_helper().delete_project_secrets( project, secrets ) + + events_client = events_factory.EventsFactory().get_events_client() + event = events_client.generate_project_secret_event( + project=project, + secret_name=secret_name, + secret_keys=secrets, + action=mlrun.common.schemas.SecretEventActions.deleted + if deleted + else mlrun.common.schemas.SecretEventActions.updated, + ) + events_client.emit(event) + else: raise mlrun.errors.MLRunInternalServerError( "K8s provider cannot be initialized" @@ -207,11 +254,11 @@ def delete_project_secrets( def list_project_secret_keys( self, project: str, - provider: mlrun.api.schemas.SecretProviderName, + provider: mlrun.common.schemas.SecretProviderName, token: typing.Optional[str] = None, allow_internal_secrets: bool = False, - ) -> mlrun.api.schemas.SecretKeysData: - if provider == mlrun.api.schemas.SecretProviderName.vault: + ) -> mlrun.common.schemas.SecretKeysData: + if provider == mlrun.common.schemas.SecretProviderName.vault: if not token: raise mlrun.errors.MLRunInvalidArgumentError( "Vault list project secret keys request without providing token" @@ -220,15 +267,15 @@ def list_project_secret_keys( vault = mlrun.utils.vault.VaultStore(token) secret_values = vault.get_secrets(None, project=project) secret_keys = list(secret_values.keys()) - elif provider == 
mlrun.api.schemas.SecretProviderName.kubernetes: + elif provider == mlrun.common.schemas.SecretProviderName.kubernetes: if token: raise mlrun.errors.MLRunInvalidArgumentError( "Cannot specify token when requesting k8s secret keys" ) - if mlrun.api.utils.singletons.k8s.get_k8s(): + if mlrun.api.utils.singletons.k8s.get_k8s_helper(): secret_keys = ( - mlrun.api.utils.singletons.k8s.get_k8s().get_project_secret_keys( + mlrun.api.utils.singletons.k8s.get_k8s_helper().get_project_secret_keys( project ) or [] @@ -249,20 +296,20 @@ def list_project_secret_keys( ) ) - return mlrun.api.schemas.SecretKeysData( + return mlrun.common.schemas.SecretKeysData( provider=provider, secret_keys=secret_keys ) def list_project_secrets( self, project: str, - provider: mlrun.api.schemas.SecretProviderName, + provider: mlrun.common.schemas.SecretProviderName, secrets: typing.Optional[typing.List[str]] = None, token: typing.Optional[str] = None, allow_secrets_from_k8s: bool = False, allow_internal_secrets: bool = False, - ) -> mlrun.api.schemas.SecretsData: - if provider == mlrun.api.schemas.SecretProviderName.vault: + ) -> mlrun.common.schemas.SecretsData: + if provider == mlrun.common.schemas.SecretProviderName.vault: if not token: raise mlrun.errors.MLRunInvalidArgumentError( "Vault list project secrets request without providing token" @@ -270,13 +317,13 @@ def list_project_secrets( vault = mlrun.utils.vault.VaultStore(token) secrets_data = vault.get_secrets(secrets, project=project) - elif provider == mlrun.api.schemas.SecretProviderName.kubernetes: + elif provider == mlrun.common.schemas.SecretProviderName.kubernetes: if not allow_secrets_from_k8s: raise mlrun.errors.MLRunAccessDeniedError( "Not allowed to list secrets data from kubernetes provider" ) secrets_data = ( - mlrun.api.utils.singletons.k8s.get_k8s().get_project_secret_data( + mlrun.api.utils.singletons.k8s.get_k8s_helper().get_project_secret_data( project, secrets ) ) @@ -291,12 +338,12 @@ def list_project_secrets( for key, 
value in secrets_data.items() if not self._is_internal_project_secret_key(key) } - return mlrun.api.schemas.SecretsData(provider=provider, secrets=secrets_data) + return mlrun.common.schemas.SecretsData(provider=provider, secrets=secrets_data) def delete_project_secret( self, project: str, - provider: mlrun.api.schemas.SecretProviderName, + provider: mlrun.common.schemas.SecretProviderName, secret_key: str, token: typing.Optional[str] = None, allow_secrets_from_k8s: bool = False, @@ -322,7 +369,7 @@ def delete_project_secret( if key_map: self.store_project_secrets( project, - mlrun.api.schemas.SecretsData( + mlrun.common.schemas.SecretsData( provider=provider, secrets={key_map_secret_key: json.dumps(key_map)}, ), @@ -337,7 +384,7 @@ def delete_project_secret( def get_project_secret( self, project: str, - provider: mlrun.api.schemas.SecretProviderName, + provider: mlrun.common.schemas.SecretProviderName, secret_key: str, token: typing.Optional[str] = None, allow_secrets_from_k8s: bool = False, @@ -366,7 +413,7 @@ def get_project_secret( def _resolve_project_secret_key( self, project: str, - provider: mlrun.api.schemas.SecretProviderName, + provider: mlrun.common.schemas.SecretProviderName, secret_key: str, token: typing.Optional[str] = None, allow_secrets_from_k8s: bool = False, @@ -374,7 +421,7 @@ def _resolve_project_secret_key( key_map_secret_key: typing.Optional[str] = None, ) -> typing.Tuple[bool, str]: if key_map_secret_key: - if provider != mlrun.api.schemas.SecretProviderName.kubernetes: + if provider != mlrun.common.schemas.SecretProviderName.kubernetes: raise mlrun.errors.MLRunInvalidArgumentError( f"Secret using key map is not implemented for provider {provider}" ) @@ -396,7 +443,7 @@ def _resolve_project_secret_key( def _validate_and_enrich_project_secrets_to_store( self, project: str, - secrets: mlrun.api.schemas.SecretsData, + secrets: mlrun.common.schemas.SecretsData, allow_internal_secrets: bool = False, key_map_secret_key: typing.Optional[str] = 
None, allow_storing_key_maps: bool = False, @@ -419,7 +466,10 @@ def _validate_and_enrich_project_secrets_to_store( f"{self.key_map_secrets_key_prefix})" ) if key_map_secret_key: - if secrets.provider != mlrun.api.schemas.SecretProviderName.kubernetes: + if ( + secrets.provider + != mlrun.common.schemas.SecretProviderName.kubernetes + ): raise mlrun.errors.MLRunInvalidArgumentError( f"Storing secret using key map is not implemented for provider {secrets.provider}" ) @@ -467,7 +517,7 @@ def _get_project_secret_key_map( ) -> typing.Optional[dict]: secrets_data = self.list_project_secrets( project, - mlrun.api.schemas.SecretProviderName.kubernetes, + mlrun.common.schemas.SecretProviderName.kubernetes, [key_map_secret_key], allow_secrets_from_k8s=True, allow_internal_secrets=True, diff --git a/mlrun/api/crud/tags.py b/mlrun/api/crud/tags.py index 78d5ec270150..e10477d62b9f 100644 --- a/mlrun/api/crud/tags.py +++ b/mlrun/api/crud/tags.py @@ -15,10 +15,10 @@ import sqlalchemy.orm import mlrun.api.db.sqldb.db -import mlrun.api.schemas import mlrun.api.utils.projects.remotes.follower import mlrun.api.utils.singletons.db import mlrun.api.utils.singletons.project_member +import mlrun.common.schemas import mlrun.config import mlrun.errors import mlrun.utils.singleton @@ -40,7 +40,7 @@ def overwrite_object_tags_with_tag( db_session: sqlalchemy.orm.Session, project: str, tag: str, - tag_objects: mlrun.api.schemas.TagObjects, + tag_objects: mlrun.common.schemas.TagObjects, ): overwrite_func = kind_to_function_names.get(tag_objects.kind, {}).get( "overwrite" @@ -61,7 +61,7 @@ def append_tag_to_objects( db_session: sqlalchemy.orm.Session, project: str, tag: str, - tag_objects: mlrun.api.schemas.TagObjects, + tag_objects: mlrun.common.schemas.TagObjects, ): append_func = kind_to_function_names.get(tag_objects.kind, {}).get("append") if not append_func: @@ -80,7 +80,7 @@ def delete_tag_from_objects( db_session: sqlalchemy.orm.Session, project: str, tag: str, - tag_objects: 
mlrun.api.schemas.TagObjects, + tag_objects: mlrun.common.schemas.TagObjects, ): delete_func = kind_to_function_names.get(tag_objects.kind, {}).get("delete") if not delete_func: diff --git a/mlrun/api/db/base.py b/mlrun/api/db/base.py index 3d8d726098d1..a7e0bf3ae895 100644 --- a/mlrun/api/db/base.py +++ b/mlrun/api/db/base.py @@ -17,8 +17,8 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple, Union +import mlrun.common.schemas import mlrun.model -from mlrun.api import schemas class DBError(Exception): @@ -98,10 +98,10 @@ def list_runs( start_time_to=None, last_update_time_from=None, last_update_time_to=None, - partition_by: schemas.RunPartitionByField = None, + partition_by: mlrun.common.schemas.RunPartitionByField = None, rows_per_partition: int = 1, - partition_sort_by: schemas.SortField = None, - partition_order: schemas.OrderType = schemas.OrderType.desc, + partition_sort_by: mlrun.common.schemas.SortField = None, + partition_order: mlrun.common.schemas.OrderType = mlrun.common.schemas.OrderType.desc, max_partitions: int = 0, requested_logs: bool = None, return_as_run_structs: bool = True, @@ -122,7 +122,7 @@ def overwrite_artifacts_with_tag( session, project: str, tag: str, - identifiers: List[schemas.ArtifactIdentifier], + identifiers: List[mlrun.common.schemas.ArtifactIdentifier], ): pass @@ -131,7 +131,7 @@ def append_tag_to_artifacts( session, project: str, tag: str, - identifiers: List[schemas.ArtifactIdentifier], + identifiers: List[mlrun.common.schemas.ArtifactIdentifier], ): pass @@ -140,7 +140,7 @@ def delete_tag_from_artifacts( session, project: str, tag: str, - identifiers: List[schemas.ArtifactIdentifier], + identifiers: List[mlrun.common.schemas.ArtifactIdentifier], ): pass @@ -172,7 +172,7 @@ def list_artifacts( since=None, until=None, kind=None, - category: schemas.ArtifactCategories = None, + category: mlrun.common.schemas.ArtifactCategories = None, iter: int = None, best_iteration: bool = False, 
as_records: bool = False, @@ -235,9 +235,9 @@ def create_schedule( session, project: str, name: str, - kind: schemas.ScheduleKinds, + kind: mlrun.common.schemas.ScheduleKinds, scheduled_object: Any, - cron_trigger: schemas.ScheduleCronTrigger, + cron_trigger: mlrun.common.schemas.ScheduleCronTrigger, concurrency_limit: int, labels: Dict = None, next_run_time: datetime.datetime = None, @@ -251,7 +251,7 @@ def update_schedule( project: str, name: str, scheduled_object: Any = None, - cron_trigger: schemas.ScheduleCronTrigger = None, + cron_trigger: mlrun.common.schemas.ScheduleCronTrigger = None, labels: Dict = None, last_run_uri: str = None, concurrency_limit: int = None, @@ -266,12 +266,14 @@ def list_schedules( project: str = None, name: str = None, labels: str = None, - kind: schemas.ScheduleKinds = None, - ) -> List[schemas.ScheduleRecord]: + kind: mlrun.common.schemas.ScheduleKinds = None, + ) -> List[mlrun.common.schemas.ScheduleRecord]: pass @abstractmethod - def get_schedule(self, session, project: str, name: str) -> schemas.ScheduleRecord: + def get_schedule( + self, session, project: str, name: str + ) -> mlrun.common.schemas.ScheduleRecord: pass @abstractmethod @@ -285,7 +287,7 @@ def delete_schedules(self, session, project: str): @abstractmethod def generate_projects_summaries( self, session, projects: List[str] - ) -> List[schemas.ProjectSummary]: + ) -> List[mlrun.common.schemas.ProjectSummary]: pass @abstractmethod @@ -305,17 +307,17 @@ def list_projects( self, session, owner: str = None, - format_: schemas.ProjectsFormat = schemas.ProjectsFormat.full, + format_: mlrun.common.schemas.ProjectsFormat = mlrun.common.schemas.ProjectsFormat.full, labels: List[str] = None, - state: schemas.ProjectState = None, + state: mlrun.common.schemas.ProjectState = None, names: Optional[List[str]] = None, - ) -> schemas.ProjectsOutput: + ) -> mlrun.common.schemas.ProjectsOutput: pass @abstractmethod def get_project( self, session, name: str = None, project_id: int = 
None - ) -> schemas.Project: + ) -> mlrun.common.schemas.Project: pass @abstractmethod @@ -332,11 +334,11 @@ async def get_project_resources_counters( pass @abstractmethod - def create_project(self, session, project: schemas.Project): + def create_project(self, session, project: mlrun.common.schemas.Project): pass @abstractmethod - def store_project(self, session, name: str, project: schemas.Project): + def store_project(self, session, name: str, project: mlrun.common.schemas.Project): pass @abstractmethod @@ -345,7 +347,7 @@ def patch_project( session, name: str, project: dict, - patch_mode: schemas.PatchMode = schemas.PatchMode.replace, + patch_mode: mlrun.common.schemas.PatchMode = mlrun.common.schemas.PatchMode.replace, ): pass @@ -354,7 +356,7 @@ def delete_project( self, session, name: str, - deletion_strategy: schemas.DeletionStrategy = schemas.DeletionStrategy.default(), + deletion_strategy: mlrun.common.schemas.DeletionStrategy = mlrun.common.schemas.DeletionStrategy.default(), ): pass @@ -363,7 +365,7 @@ def create_feature_set( self, session, project, - feature_set: schemas.FeatureSet, + feature_set: mlrun.common.schemas.FeatureSet, versioned=True, ) -> str: pass @@ -374,7 +376,7 @@ def store_feature_set( session, project, name, - feature_set: schemas.FeatureSet, + feature_set: mlrun.common.schemas.FeatureSet, tag=None, uid=None, versioned=True, @@ -385,7 +387,7 @@ def store_feature_set( @abstractmethod def get_feature_set( self, session, project: str, name: str, tag: str = None, uid: str = None - ) -> schemas.FeatureSet: + ) -> mlrun.common.schemas.FeatureSet: pass @abstractmethod @@ -397,7 +399,7 @@ def list_features( tag: str = None, entities: List[str] = None, labels: List[str] = None, - ) -> schemas.FeaturesOutput: + ) -> mlrun.common.schemas.FeaturesOutput: pass @abstractmethod @@ -408,7 +410,7 @@ def list_entities( name: str = None, tag: str = None, labels: List[str] = None, - ) -> schemas.EntitiesOutput: + ) -> mlrun.common.schemas.EntitiesOutput: 
pass @abstractmethod @@ -422,11 +424,11 @@ def list_feature_sets( entities: List[str] = None, features: List[str] = None, labels: List[str] = None, - partition_by: schemas.FeatureStorePartitionByField = None, + partition_by: mlrun.common.schemas.FeatureStorePartitionByField = None, rows_per_partition: int = 1, - partition_sort_by: schemas.SortField = None, - partition_order: schemas.OrderType = schemas.OrderType.desc, - ) -> schemas.FeatureSetsOutput: + partition_sort_by: mlrun.common.schemas.SortField = None, + partition_order: mlrun.common.schemas.OrderType = mlrun.common.schemas.OrderType.desc, + ) -> mlrun.common.schemas.FeatureSetsOutput: pass @abstractmethod @@ -449,7 +451,7 @@ def patch_feature_set( feature_set_patch: dict, tag=None, uid=None, - patch_mode: schemas.PatchMode = schemas.PatchMode.replace, + patch_mode: mlrun.common.schemas.PatchMode = mlrun.common.schemas.PatchMode.replace, ) -> str: pass @@ -462,7 +464,7 @@ def create_feature_vector( self, session, project, - feature_vector: schemas.FeatureVector, + feature_vector: mlrun.common.schemas.FeatureVector, versioned=True, ) -> str: pass @@ -470,7 +472,7 @@ def create_feature_vector( @abstractmethod def get_feature_vector( self, session, project: str, name: str, tag: str = None, uid: str = None - ) -> schemas.FeatureVector: + ) -> mlrun.common.schemas.FeatureVector: pass @abstractmethod @@ -482,11 +484,11 @@ def list_feature_vectors( tag: str = None, state: str = None, labels: List[str] = None, - partition_by: schemas.FeatureStorePartitionByField = None, + partition_by: mlrun.common.schemas.FeatureStorePartitionByField = None, rows_per_partition: int = 1, - partition_sort_by: schemas.SortField = None, - partition_order: schemas.OrderType = schemas.OrderType.desc, - ) -> schemas.FeatureVectorsOutput: + partition_sort_by: mlrun.common.schemas.SortField = None, + partition_order: mlrun.common.schemas.OrderType = mlrun.common.schemas.OrderType.desc, + ) -> mlrun.common.schemas.FeatureVectorsOutput: pass 
@abstractmethod @@ -506,7 +508,7 @@ def store_feature_vector( session, project, name, - feature_vector: schemas.FeatureVector, + feature_vector: mlrun.common.schemas.FeatureVector, tag=None, uid=None, versioned=True, @@ -523,7 +525,7 @@ def patch_feature_vector( feature_vector_update: dict, tag=None, uid=None, - patch_mode: schemas.PatchMode = schemas.PatchMode.replace, + patch_mode: mlrun.common.schemas.PatchMode = mlrun.common.schemas.PatchMode.replace, ) -> str: pass @@ -539,29 +541,33 @@ def delete_feature_vector( pass def list_artifact_tags( - self, session, project, category: Union[str, schemas.ArtifactCategories] = None + self, + session, + project, + category: Union[str, mlrun.common.schemas.ArtifactCategories] = None, ): return [] - def create_marketplace_source( - self, session, ordered_source: schemas.IndexedMarketplaceSource + def create_hub_source( + self, session, ordered_source: mlrun.common.schemas.IndexedHubSource ): pass - def store_marketplace_source( - self, session, name, ordered_source: schemas.IndexedMarketplaceSource + def store_hub_source( + self, + session, + name, + ordered_source: mlrun.common.schemas.IndexedHubSource, ): pass - def list_marketplace_sources( - self, session - ) -> List[schemas.IndexedMarketplaceSource]: + def list_hub_sources(self, session) -> List[mlrun.common.schemas.IndexedHubSource]: pass - def delete_marketplace_source(self, session, name): + def delete_hub_source(self, session, name): pass - def get_marketplace_source(self, session, name) -> schemas.IndexedMarketplaceSource: + def get_hub_source(self, session, name) -> mlrun.common.schemas.IndexedHubSource: pass def store_background_task( @@ -569,14 +575,14 @@ def store_background_task( session, name: str, project: str, - state: str = schemas.BackgroundTaskState.running, + state: str = mlrun.common.schemas.BackgroundTaskState.running, timeout: int = None, ): pass def get_background_task( self, session, name: str, project: str - ) -> schemas.BackgroundTask: + ) -> 
mlrun.common.schemas.BackgroundTask: pass @abstractmethod @@ -607,3 +613,13 @@ def delete_run_notifications( commit: bool = True, ): pass + + def set_run_notifications( + self, + session, + project: str, + notifications: typing.List[mlrun.model.Notification], + identifiers: typing.List[mlrun.common.schemas.RunIdentifier], + **kwargs, + ): + pass diff --git a/mlrun/api/db/filedb/db.py b/mlrun/api/db/filedb/db.py deleted file mode 100644 index 5fb3ac254d60..000000000000 --- a/mlrun/api/db/filedb/db.py +++ /dev/null @@ -1,543 +0,0 @@ -# Copyright 2018 Iguazio -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import datetime -from typing import Any, Dict, List, Optional, Tuple, Union - -from mlrun.api import schemas -from mlrun.api.db.base import DBError, DBInterface -from mlrun.db.base import RunDBError -from mlrun.db.filedb import FileRunDB - - -class FileDB(DBInterface): - def __init__(self, dirpath="", format=".yaml"): - self.db = FileRunDB(dirpath, format) - - def initialize(self, session): - self.db.connect() - - def store_log( - self, - session, - uid, - project="", - body=None, - append=False, - ): - return self._transform_run_db_error( - self.db.store_log, uid, project, body, append - ) - - def get_log(self, session, uid, project="", offset=0, size=0): - return self._transform_run_db_error(self.db.get_log, uid, project, offset, size) - - def store_run( - self, - session, - struct, - uid, - project="", - iter=0, - ): - return self._transform_run_db_error( - self.db.store_run, struct, uid, project, iter - ) - - def update_run(self, session, updates: dict, uid, project="", iter=0): - return self._transform_run_db_error( - self.db.update_run, updates, uid, project, iter - ) - - def list_distinct_runs_uids( - self, - session, - project: str = None, - requested_logs_modes: List[bool] = None, - only_uids: bool = False, - last_update_time_from: datetime.datetime = None, - states: List[str] = None, - ): - raise NotImplementedError() - - def update_runs_requested_logs( - self, session, uids: List[str], requested_logs: bool = True - ): - raise NotImplementedError() - - def read_run(self, session, uid, project="", iter=0): - return self._transform_run_db_error(self.db.read_run, uid, project, iter) - - def list_runs( - self, - session, - name="", - uid: Optional[Union[str, List[str]]] = None, - project="", - labels=None, - states=None, - sort=True, - last=0, - iter=False, - start_time_from=None, - start_time_to=None, - last_update_time_from=None, - last_update_time_to=None, - partition_by: schemas.RunPartitionByField = None, - rows_per_partition: int = 1, - 
partition_sort_by: schemas.SortField = None, - partition_order: schemas.OrderType = schemas.OrderType.desc, - max_partitions: int = 0, - requested_logs: bool = None, - return_as_run_structs: bool = True, - with_notifications: bool = False, - ): - return self._transform_run_db_error( - self.db.list_runs, - name, - uid, - project, - labels, - states[0] if states else "", - sort, - last, - iter, - start_time_from, - start_time_to, - last_update_time_from, - last_update_time_to, - partition_by, - rows_per_partition, - partition_sort_by, - partition_order, - max_partitions, - requested_logs, - return_as_run_structs, - with_notifications, - ) - - def del_run(self, session, uid, project="", iter=0): - return self._transform_run_db_error(self.db.del_run, uid, project, iter) - - def del_runs(self, session, name="", project="", labels=None, state="", days_ago=0): - return self._transform_run_db_error( - self.db.del_runs, name, project, labels, state, days_ago - ) - - def overwrite_artifacts_with_tag( - self, - session, - project: str, - tag: str, - identifiers: List[schemas.ArtifactIdentifier], - ): - raise NotImplementedError() - - def append_tag_to_artifacts( - self, - session, - project: str, - tag: str, - identifiers: List[schemas.ArtifactIdentifier], - ): - raise NotImplementedError() - - def delete_tag_from_artifacts( - self, - session, - project: str, - tag: str, - identifiers: List[schemas.ArtifactIdentifier], - ): - raise NotImplementedError() - - def store_artifact( - self, - session, - key, - artifact, - uid, - iter=None, - tag="", - project="", - ): - return self._transform_run_db_error( - self.db.store_artifact, key, artifact, uid, iter, tag, project - ) - - def read_artifact(self, session, key, tag="", iter=None, project=""): - return self._transform_run_db_error( - self.db.read_artifact, key, tag, iter, project - ) - - def list_artifacts( - self, - session, - name="", - project="", - tag="", - labels=None, - since=None, - until=None, - kind=None, - category: 
schemas.ArtifactCategories = None, - iter: int = None, - best_iteration: bool = False, - as_records: bool = False, - use_tag_as_uid: bool = None, - ): - return self._transform_run_db_error( - self.db.list_artifacts, name, project, tag, labels, since, until - ) - - def del_artifact(self, session, key, tag="", project=""): - return self._transform_run_db_error(self.db.del_artifact, key, tag, project) - - def del_artifacts(self, session, name="", project="", tag="", labels=None): - return self._transform_run_db_error( - self.db.del_artifacts, name, project, tag, labels - ) - - def store_function( - self, - session, - function, - name, - project="", - tag="", - versioned=False, - ) -> str: - return self._transform_run_db_error( - self.db.store_function, function, name, project, tag, versioned - ) - - def get_function(self, session, name, project="", tag="", hash_key=""): - return self._transform_run_db_error( - self.db.get_function, name, project, tag, hash_key - ) - - def delete_function(self, session, project: str, name: str): - raise NotImplementedError() - - def list_functions( - self, session, name=None, project="", tag="", labels=None, hash_key=None - ): - return self._transform_run_db_error( - self.db.list_functions, name, project, tag, labels - ) - - def store_schedule(self, session, data): - return self._transform_run_db_error(self.db.store_schedule, data) - - def generate_projects_summaries( - self, session, projects: List[str] - ) -> List[schemas.ProjectSummary]: - raise NotImplementedError() - - def delete_project_related_resources(self, session, name: str): - raise NotImplementedError() - - def verify_project_has_no_related_resources(self, session, name: str): - raise NotImplementedError() - - def is_project_exists(self, session, name: str): - raise NotImplementedError() - - def list_projects( - self, - session, - owner: str = None, - format_: schemas.ProjectsFormat = schemas.ProjectsFormat.full, - labels: List[str] = None, - state: schemas.ProjectState = 
None, - names: Optional[List[str]] = None, - ) -> schemas.ProjectsOutput: - return self._transform_run_db_error( - self.db.list_projects, owner, format_, labels, state - ) - - async def get_project_resources_counters( - self, - ) -> Tuple[ - Dict[str, int], - Dict[str, int], - Dict[str, int], - Dict[str, int], - Dict[str, int], - Dict[str, int], - ]: - raise NotImplementedError() - - def store_project(self, session, name: str, project: schemas.Project): - raise NotImplementedError() - - def patch_project( - self, - session, - name: str, - project: dict, - patch_mode: schemas.PatchMode = schemas.PatchMode.replace, - ): - raise NotImplementedError() - - def create_project(self, session, project: schemas.Project): - raise NotImplementedError() - - def get_project( - self, session, name: str = None, project_id: int = None - ) -> schemas.Project: - raise NotImplementedError() - - def delete_project( - self, - session, - name: str, - deletion_strategy: schemas.DeletionStrategy = schemas.DeletionStrategy.default(), - ): - raise NotImplementedError() - - def create_feature_set( - self, - session, - project, - feature_set: schemas.FeatureSet, - versioned=True, - ) -> str: - raise NotImplementedError() - - def store_feature_set( - self, - session, - project, - name, - feature_set: schemas.FeatureSet, - tag=None, - uid=None, - versioned=True, - always_overwrite=False, - ) -> str: - raise NotImplementedError() - - def get_feature_set( - self, session, project: str, name: str, tag: str = None, uid: str = None - ) -> schemas.FeatureSet: - raise NotImplementedError() - - def list_features( - self, - session, - project: str, - name: str = None, - tag: str = None, - entities: List[str] = None, - labels: List[str] = None, - ) -> schemas.FeaturesOutput: - raise NotImplementedError() - - def list_entities( - self, - session, - project: str, - name: str = None, - tag: str = None, - labels: List[str] = None, - ) -> schemas.EntitiesOutput: - pass - - def list_feature_sets( - self, - 
session, - project: str, - name: str = None, - tag: str = None, - state: str = None, - entities: List[str] = None, - features: List[str] = None, - labels: List[str] = None, - partition_by: schemas.FeatureStorePartitionByField = None, - rows_per_partition: int = 1, - partition_sort_by: schemas.SortField = None, - partition_order: schemas.OrderType = schemas.OrderType.desc, - ) -> schemas.FeatureSetsOutput: - raise NotImplementedError() - - def list_feature_sets_tags( - self, - session, - project: str, - ): - raise NotImplementedError() - - def patch_feature_set( - self, - session, - project, - name, - feature_set_patch: dict, - tag=None, - uid=None, - patch_mode: schemas.PatchMode = schemas.PatchMode.replace, - ) -> str: - raise NotImplementedError() - - def delete_feature_set(self, session, project, name, tag=None, uid=None): - raise NotImplementedError() - - def create_feature_vector( - self, - session, - project, - feature_vector: schemas.FeatureVector, - versioned=True, - ) -> str: - raise NotImplementedError() - - def get_feature_vector( - self, session, project: str, name: str, tag: str = None, uid: str = None - ) -> schemas.FeatureVector: - raise NotImplementedError() - - def list_feature_vectors( - self, - session, - project: str, - name: str = None, - tag: str = None, - state: str = None, - labels: List[str] = None, - partition_by: schemas.FeatureStorePartitionByField = None, - rows_per_partition: int = 1, - partition_sort_by: schemas.SortField = None, - partition_order: schemas.OrderType = schemas.OrderType.desc, - ) -> schemas.FeatureVectorsOutput: - raise NotImplementedError() - - def list_feature_vectors_tags( - self, - session, - project: str, - ): - raise NotImplementedError() - - def store_feature_vector( - self, - session, - project, - name, - feature_vector: schemas.FeatureVector, - tag=None, - uid=None, - versioned=True, - always_overwrite=False, - ) -> str: - raise NotImplementedError() - - def patch_feature_vector( - self, - session, - project, 
- name, - feature_vector_update: dict, - tag=None, - uid=None, - patch_mode: schemas.PatchMode = schemas.PatchMode.replace, - ) -> str: - raise NotImplementedError() - - def delete_feature_vector(self, session, project, name, tag=None, uid=None): - raise NotImplementedError() - - def list_artifact_tags( - self, session, project, category: Union[str, schemas.ArtifactCategories] = None - ): - return self._transform_run_db_error( - self.db.list_artifact_tags, project, category - ) - - def create_schedule( - self, - session, - project: str, - name: str, - kind: schemas.ScheduleKinds, - scheduled_object: Any, - cron_trigger: schemas.ScheduleCronTrigger, - concurrency_limit: int, - labels: Dict = None, - next_run_time: datetime.datetime = None, - ): - raise NotImplementedError() - - def update_schedule( - self, - session, - project: str, - name: str, - scheduled_object: Any = None, - cron_trigger: schemas.ScheduleCronTrigger = None, - labels: Dict = None, - last_run_uri: str = None, - concurrency_limit: int = None, - next_run_time: datetime.datetime = None, - ): - raise NotImplementedError() - - def list_schedules( - self, - session, - project: str = None, - name: str = None, - labels: str = None, - kind: schemas.ScheduleKinds = None, - ) -> List[schemas.ScheduleRecord]: - raise NotImplementedError() - - def get_schedule(self, session, project: str, name: str) -> schemas.ScheduleRecord: - raise NotImplementedError() - - def delete_schedule(self, session, project: str, name: str): - raise NotImplementedError() - - def delete_schedules(self, session, project: str): - raise NotImplementedError() - - @staticmethod - def _transform_run_db_error(func, *args, **kwargs): - try: - return func(*args, **kwargs) - except RunDBError as exc: - raise DBError(exc.args) - - def store_run_notifications( - self, session, notification_objects, run_uid: str, project: str - ): - raise NotImplementedError() - - def list_run_notifications( - self, - session, - run_uid: str, - project: str = "", 
- ): - raise NotImplementedError() - - def delete_run_notifications( - self, - session, - name: str = None, - run_uid: str = None, - project: str = None, - commit: bool = True, - ): - raise NotImplementedError() diff --git a/mlrun/api/db/init_db.py b/mlrun/api/db/init_db.py index 21b6966cb757..fe58fbdfad72 100644 --- a/mlrun/api/db/init_db.py +++ b/mlrun/api/db/init_db.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from sqlalchemy.orm import Session from mlrun.api.db.sqldb.models import Base from mlrun.api.db.sqldb.session import get_engine from mlrun.config import config -def init_db(db_session: Session) -> None: +def init_db() -> None: if config.httpdb.db_type != "filedb": Base.metadata.create_all(bind=get_engine()) diff --git a/mlrun/api/db/session.py b/mlrun/api/db/session.py index 3db0e9b7fcf3..ef62d84d849f 100644 --- a/mlrun/api/db/session.py +++ b/mlrun/api/db/session.py @@ -15,22 +15,14 @@ from sqlalchemy.orm import Session from mlrun.api.db.sqldb.session import create_session as sqldb_create_session -from mlrun.config import config -def create_session(db_type=None) -> Session: - db_type = db_type or config.httpdb.db_type - if db_type == "filedb": - return None - else: - return sqldb_create_session() +def create_session() -> Session: + return sqldb_create_session() def close_session(db_session): - - # will be None when it's filedb session - if db_session is not None: - db_session.close() + db_session.close() def run_function_with_new_db_session(func): diff --git a/mlrun/api/db/sqldb/db.py b/mlrun/api/db/sqldb/db.py index e8a916c1bee3..6a0d249d9bac 100644 --- a/mlrun/api/db/sqldb/db.py +++ b/mlrun/api/db/sqldb/db.py @@ -30,11 +30,12 @@ import mlrun import mlrun.api.db.session +import mlrun.api.utils.helpers import mlrun.api.utils.projects.remotes.follower import mlrun.api.utils.singletons.k8s +import mlrun.common.schemas import mlrun.errors import mlrun.model -from mlrun.api 
import schemas from mlrun.api.db.base import DBInterface from mlrun.api.db.sqldb.helpers import ( generate_query_predicate_for_name, @@ -53,15 +54,15 @@ FeatureSet, FeatureVector, Function, + HubSource, Log, - MarketplaceSource, - Notification, Project, Run, Schedule, User, _labeled, _tagged, + _with_notifications, ) from mlrun.config import config from mlrun.errors import err_to_str @@ -184,7 +185,11 @@ def store_run( iter=0, ): logger.debug( - "Storing run to db", project=project, uid=uid, iter=iter, run=run_data + "Storing run to db", + project=project, + uid=uid, + iter=iter, + run_name=run_data["metadata"]["name"], ) run = self._get_run(session, uid, project, iter) now = datetime.now(timezone.utc) @@ -321,10 +326,10 @@ def list_runs( start_time_to=None, last_update_time_from=None, last_update_time_to=None, - partition_by: schemas.RunPartitionByField = None, + partition_by: mlrun.common.schemas.RunPartitionByField = None, rows_per_partition: int = 1, - partition_sort_by: schemas.SortField = None, - partition_order: schemas.OrderType = schemas.OrderType.desc, + partition_sort_by: mlrun.common.schemas.SortField = None, + partition_order: mlrun.common.schemas.OrderType = mlrun.common.schemas.OrderType.desc, max_partitions: int = 0, requested_logs: bool = None, return_as_run_structs: bool = True, @@ -358,7 +363,9 @@ def list_runs( query = query.filter(Run.requested_logs == requested_logs) if partition_by: self._assert_partition_by_parameters( - schemas.RunPartitionByField, partition_by, partition_sort_by + mlrun.common.schemas.RunPartitionByField, + partition_by, + partition_sort_by, ) query = self._create_partitioned_query( session, @@ -375,7 +382,7 @@ def list_runs( # Purposefully not using outer join to avoid returning runs without notifications if with_notifications: - query = query.join(Notification, Run.id == Notification.run) + query = query.join(Run.Notification) runs = RunList() for run in query: @@ -456,7 +463,7 @@ def overwrite_artifacts_with_tag( 
session: Session, project: str, tag: str, - identifiers: typing.List[mlrun.api.schemas.ArtifactIdentifier], + identifiers: typing.List[mlrun.common.schemas.ArtifactIdentifier], ): # query all artifacts which match the identifiers artifacts = [] @@ -479,7 +486,7 @@ def append_tag_to_artifacts( session: Session, project: str, tag: str, - identifiers: typing.List[mlrun.api.schemas.ArtifactIdentifier], + identifiers: typing.List[mlrun.common.schemas.ArtifactIdentifier], ): # query all artifacts which match the identifiers artifacts = [] @@ -496,7 +503,7 @@ def delete_tag_from_artifacts( session: Session, project: str, tag: str, - identifiers: typing.List[mlrun.api.schemas.ArtifactIdentifier], + identifiers: typing.List[mlrun.common.schemas.ArtifactIdentifier], ): # query all artifacts which match the identifiers artifacts = [] @@ -512,7 +519,7 @@ def _list_artifacts_for_tagging( self, session: Session, project_name: str, - identifier: mlrun.api.schemas.ArtifactIdentifier, + identifier: mlrun.common.schemas.ArtifactIdentifier, ): return self.list_artifacts( session, @@ -716,7 +723,7 @@ def list_artifacts( since=None, until=None, kind=None, - category: schemas.ArtifactCategories = None, + category: mlrun.common.schemas.ArtifactCategories = None, iter: int = None, best_iteration: bool = False, as_records: bool = False, @@ -981,7 +988,31 @@ def store_function( self.tag_objects_v2(session, [fn], project, tag) return hash_key - def get_function(self, session, name, project="", tag="", hash_key=""): + def get_function(self, session, name, project="", tag="", hash_key="") -> dict: + """ + In version 1.4.0 we added a normalization to the function name before storing. + To be backwards compatible and allow users to query old non-normalized functions, + we're providing a fallback to get_function: + normalize the requested name and try to retrieve it from the database. 
+ If no answer is received, we will check to see if the original name contained underscores, + if so, the retrieval will be repeated and the result (if it exists) returned. + """ + normalized_function_name = mlrun.utils.normalize_name(name) + try: + return self._get_function( + session, normalized_function_name, project, tag, hash_key + ) + except mlrun.errors.MLRunNotFoundError as exc: + if "_" in name: + logger.warning( + "Failed to get underscore-named function, trying without normalization", + function_name=name, + ) + return self._get_function(session, name, project, tag, hash_key) + else: + raise exc + + def _get_function(self, session, name, project="", tag="", hash_key=""): project = project or config.default_project query = self._query(session, Function, name=name, project=project) computed_tag = tag or "latest" @@ -1118,7 +1149,7 @@ def _list_function_tags(self, session, project, function_id): return [row[0] for row in query] def list_artifact_tags( - self, session, project, category: schemas.ArtifactCategories = None + self, session, project, category: mlrun.common.schemas.ArtifactCategories = None ) -> typing.List[typing.Tuple[str, str, str]]: """ :return: a list of Tuple of (project, artifact.key, tag) @@ -1149,9 +1180,9 @@ def create_schedule( session: Session, project: str, name: str, - kind: schemas.ScheduleKinds, + kind: mlrun.common.schemas.ScheduleKinds, scheduled_object: Any, - cron_trigger: schemas.ScheduleCronTrigger, + cron_trigger: mlrun.common.schemas.ScheduleCronTrigger, concurrency_limit: int, labels: Dict = None, next_run_time: datetime = None, @@ -1194,7 +1225,7 @@ def update_schedule( project: str, name: str, scheduled_object: Any = None, - cron_trigger: schemas.ScheduleCronTrigger = None, + cron_trigger: mlrun.common.schemas.ScheduleCronTrigger = None, labels: Dict = None, last_run_uri: str = None, concurrency_limit: int = None, @@ -1240,8 +1271,8 @@ def list_schedules( project: str = None, name: str = None, labels: str = None, - 
kind: schemas.ScheduleKinds = None, - ) -> List[schemas.ScheduleRecord]: + kind: mlrun.common.schemas.ScheduleKinds = None, + ) -> List[mlrun.common.schemas.ScheduleRecord]: logger.debug("Getting schedules from db", project=project, name=name, kind=kind) query = self._query(session, Schedule, project=project, kind=kind) if name is not None: @@ -1257,7 +1288,7 @@ def list_schedules( def get_schedule( self, session: Session, project: str, name: str - ) -> schemas.ScheduleRecord: + ) -> mlrun.common.schemas.ScheduleRecord: logger.debug("Getting schedule from db", project=project, name=name) schedule_record = self._get_schedule_record(session, project, name) schedule = self._transform_schedule_record_to_scheme(schedule_record) @@ -1265,7 +1296,7 @@ def get_schedule( def _get_schedule_record( self, session: Session, project: str, name: str - ) -> schemas.ScheduleRecord: + ) -> mlrun.common.schemas.ScheduleRecord: query = self._query(session, Schedule, project=project, name=name) schedule_record = query.one_or_none() if not schedule_record: @@ -1358,7 +1389,7 @@ def tag_objects_v2(self, session, objs, project: str, name: str): tags.append(tag) self._upsert(session, tags) - def create_project(self, session: Session, project: schemas.Project): + def create_project(self, session: Session, project: mlrun.common.schemas.Project): logger.debug("Creating project in DB", project=project) created = datetime.utcnow() project.metadata.created = created @@ -1377,8 +1408,17 @@ def create_project(self, session: Session, project: schemas.Project): self._upsert(session, [project_record]) @retry_on_conflict - def store_project(self, session: Session, name: str, project: schemas.Project): - logger.debug("Storing project in DB", name=name, project=project) + def store_project( + self, session: Session, name: str, project: mlrun.common.schemas.Project + ): + logger.debug( + "Storing project in DB", + name=name, + project_metadata=project.metadata, + project_owner=project.spec.owner, + 
project_desired_state=project.spec.desired_state, + project_status=project.status, + ) project_record = self._get_project_record( session, name, raise_on_not_found=False ) @@ -1392,11 +1432,9 @@ def patch_project( session: Session, name: str, project: dict, - patch_mode: schemas.PatchMode = schemas.PatchMode.replace, + patch_mode: mlrun.common.schemas.PatchMode = mlrun.common.schemas.PatchMode.replace, ): - logger.debug( - "Patching project in DB", name=name, project=project, patch_mode=patch_mode - ) + logger.debug("Patching project in DB", name=name, patch_mode=patch_mode) project_record = self._get_project_record(session, name) self._patch_project_record_from_project( session, name, project_record, project, patch_mode @@ -1404,7 +1442,7 @@ def patch_project( def get_project( self, session: Session, name: str = None, project_id: int = None - ) -> schemas.Project: + ) -> mlrun.common.schemas.Project: project_record = self._get_project_record(session, name, project_id) return self._transform_project_record_to_schema(session, project_record) @@ -1413,7 +1451,7 @@ def delete_project( self, session: Session, name: str, - deletion_strategy: schemas.DeletionStrategy = schemas.DeletionStrategy.default(), + deletion_strategy: mlrun.common.schemas.DeletionStrategy = mlrun.common.schemas.DeletionStrategy.default(), ): logger.debug( "Deleting project from DB", name=name, deletion_strategy=deletion_strategy @@ -1424,25 +1462,45 @@ def list_projects( self, session: Session, owner: str = None, - format_: mlrun.api.schemas.ProjectsFormat = mlrun.api.schemas.ProjectsFormat.full, + format_: mlrun.common.schemas.ProjectsFormat = mlrun.common.schemas.ProjectsFormat.full, labels: List[str] = None, - state: mlrun.api.schemas.ProjectState = None, + state: mlrun.common.schemas.ProjectState = None, names: typing.Optional[typing.List[str]] = None, - ) -> schemas.ProjectsOutput: + ) -> mlrun.common.schemas.ProjectsOutput: query = self._query(session, Project, owner=owner, state=state) + + 
# if format is name_only, we don't need to query the full project object, we can just query the name + # and return it as a list of strings + if format_ == mlrun.common.schemas.ProjectsFormat.name_only: + query = self._query(session, Project.name, owner=owner, state=state) + + # attach filters to the query if labels: query = self._add_labels_filter(session, query, Project, labels) if names is not None: query = query.filter(Project.name.in_(names)) + project_records = query.all() + + # format the projects according to the requested format projects = [] for project_record in project_records: - if format_ == mlrun.api.schemas.ProjectsFormat.name_only: - projects = [project_record.name for project_record in project_records] + if format_ == mlrun.common.schemas.ProjectsFormat.name_only: + projects.append(project_record.name) + + elif format_ == mlrun.common.schemas.ProjectsFormat.minimal: + projects.append( + mlrun.api.utils.helpers.minimize_project_schema( + self._transform_project_record_to_schema( + session, project_record + ) + ) + ) + # leader format is only for follower mode which will format the projects returned from here elif format_ in [ - mlrun.api.schemas.ProjectsFormat.full, - mlrun.api.schemas.ProjectsFormat.leader, + mlrun.common.schemas.ProjectsFormat.full, + mlrun.common.schemas.ProjectsFormat.leader, ]: projects.append( self._transform_project_record_to_schema(session, project_record) @@ -1451,7 +1509,7 @@ def list_projects( raise NotImplementedError( f"Provided format is not supported. 
format={format_}" ) - return schemas.ProjectsOutput(projects=projects) + return mlrun.common.schemas.ProjectsOutput(projects=projects) async def get_project_resources_counters( self, @@ -1560,7 +1618,10 @@ def _calculate_files_counters(self, session) -> Dict[str, int]: # We're using the "latest" which gives us only one version of each artifact key, which is what we want to # count (artifact count, not artifact versions count) file_artifacts = self._find_artifacts( - session, None, "latest", category=mlrun.api.schemas.ArtifactCategories.other + session, + None, + "latest", + category=mlrun.common.schemas.ArtifactCategories.other, ) project_to_files_count = collections.defaultdict(int) for file_artifact in file_artifacts: @@ -1604,7 +1665,7 @@ def _calculate_runs_counters( async def generate_projects_summaries( self, session: Session, projects: List[str] - ) -> List[mlrun.api.schemas.ProjectSummary]: + ) -> List[mlrun.common.schemas.ProjectSummary]: ( project_to_function_count, project_to_schedule_count, @@ -1616,7 +1677,7 @@ async def generate_projects_summaries( project_summaries = [] for project in projects: project_summaries.append( - mlrun.api.schemas.ProjectSummary( + mlrun.common.schemas.ProjectSummary( name=project, functions_count=project_to_function_count.get(project, 0), schedules_count=project_to_schedule_count.get(project, 0), @@ -1634,7 +1695,10 @@ async def generate_projects_summaries( return project_summaries def _update_project_record_from_project( - self, session: Session, project_record: Project, project: schemas.Project + self, + session: Session, + project_record: Project, + project: mlrun.common.schemas.Project, ): project.metadata.created = project_record.created project_dict = project.dict() @@ -1654,7 +1718,7 @@ def _patch_project_record_from_project( name: str, project_record: Project, project: dict, - patch_mode: schemas.PatchMode, + patch_mode: mlrun.common.schemas.PatchMode, ): project.setdefault("metadata", {})["created"] = 
project_record.created strategy = patch_mode.to_mergedeep_strategy() @@ -1662,7 +1726,7 @@ def _patch_project_record_from_project( mergedeep.merge(project_record_full_object, project, strategy=strategy) # If a bad kind value was passed, it will fail here (return 422 to caller) - project = schemas.Project(**project_record_full_object) + project = mlrun.common.schemas.Project(**project_record_full_object) self.store_project( session, name, @@ -1686,7 +1750,7 @@ def _get_project_record( name: str = None, project_id: int = None, raise_on_not_found: bool = True, - ) -> Project: + ) -> typing.Optional[Project]: if not any([project_id, name]): raise mlrun.errors.MLRunInvalidArgumentError( "One of 'name' or 'project_id' must be provided" @@ -1712,7 +1776,9 @@ def verify_project_has_no_related_resources(self, session: Session, name: str): self._verify_empty_list_of_project_related_resources(name, logs, "logs") runs = self._find_runs(session, None, name, []).all() self._verify_empty_list_of_project_related_resources(name, runs, "runs") - notifications = self._get_db_notifications(session, project=name) + notifications = [] + for cls in _with_notifications: + notifications.extend(self._get_db_notifications(session, cls, project=name)) self._verify_empty_list_of_project_related_resources( name, notifications, "notifications" ) @@ -1813,7 +1879,7 @@ def get_feature_set( name: str, tag: str = None, uid: str = None, - ) -> schemas.FeatureSet: + ) -> mlrun.common.schemas.FeatureSet: feature_set = self._get_feature_set(session, project, name, tag, uid) if not feature_set: feature_set_uri = generate_object_uri(project, name, tag) @@ -1861,10 +1927,10 @@ def _generate_records_with_tags_assigned( return results @staticmethod - def _generate_feature_set_digest(feature_set: schemas.FeatureSet): - return schemas.FeatureSetDigestOutput( + def _generate_feature_set_digest(feature_set: mlrun.common.schemas.FeatureSet): + return mlrun.common.schemas.FeatureSetDigestOutput( 
metadata=feature_set.metadata, - spec=schemas.FeatureSetDigestSpec( + spec=mlrun.common.schemas.FeatureSetDigestSpec( entities=feature_set.spec.entities, features=feature_set.spec.features, ), @@ -1906,7 +1972,7 @@ def list_features( tag: str = None, entities: List[str] = None, labels: List[str] = None, - ) -> schemas.FeaturesOutput: + ) -> mlrun.common.schemas.FeaturesOutput: # We don't filter by feature-set name here, as the name parameter refers to features feature_set_id_tags = self._get_records_to_tags_map( session, FeatureSet, project, tag, name=None @@ -1921,7 +1987,7 @@ def list_features( features_results = [] for row in query: - feature_record = schemas.FeatureRecord.from_orm(row.Feature) + feature_record = mlrun.common.schemas.FeatureRecord.from_orm(row.Feature) feature_name = feature_record.name feature_sets = self._generate_records_with_tags_assigned( @@ -1948,14 +2014,14 @@ def list_features( ) features_results.append( - schemas.FeatureListOutput( + mlrun.common.schemas.FeatureListOutput( feature=feature, feature_set_digest=self._generate_feature_set_digest( feature_set ), ) ) - return schemas.FeaturesOutput(features=features_results) + return mlrun.common.schemas.FeaturesOutput(features=features_results) def list_entities( self, @@ -1964,7 +2030,7 @@ def list_entities( name: str = None, tag: str = None, labels: List[str] = None, - ) -> schemas.EntitiesOutput: + ) -> mlrun.common.schemas.EntitiesOutput: feature_set_id_tags = self._get_records_to_tags_map( session, FeatureSet, project, tag, name=None ) @@ -1975,7 +2041,7 @@ def list_entities( entities_results = [] for row in query: - entity_record = schemas.FeatureRecord.from_orm(row.Entity) + entity_record = mlrun.common.schemas.FeatureRecord.from_orm(row.Entity) entity_name = entity_record.name feature_sets = self._generate_records_with_tags_assigned( @@ -2002,14 +2068,14 @@ def list_entities( ) entities_results.append( - schemas.EntityListOutput( + mlrun.common.schemas.EntityListOutput( 
entity=entity, feature_set_digest=self._generate_feature_set_digest( feature_set ), ) ) - return schemas.EntitiesOutput(entities=entities_results) + return mlrun.common.schemas.EntitiesOutput(entities=entities_results) @staticmethod def _assert_partition_by_parameters(partition_by_enum_cls, partition_by, sort): @@ -2032,11 +2098,12 @@ def _create_partitioned_query( query, cls, partition_by: typing.Union[ - schemas.FeatureStorePartitionByField, schemas.RunPartitionByField + mlrun.common.schemas.FeatureStorePartitionByField, + mlrun.common.schemas.RunPartitionByField, ], rows_per_partition: int, - partition_sort_by: schemas.SortField, - partition_order: schemas.OrderType, + partition_sort_by: mlrun.common.schemas.SortField, + partition_order: mlrun.common.schemas.OrderType, max_partitions: int = 0, ): @@ -2102,11 +2169,11 @@ def list_feature_sets( entities: List[str] = None, features: List[str] = None, labels: List[str] = None, - partition_by: schemas.FeatureStorePartitionByField = None, + partition_by: mlrun.common.schemas.FeatureStorePartitionByField = None, rows_per_partition: int = 1, - partition_sort_by: schemas.SortField = None, - partition_order: schemas.OrderType = schemas.OrderType.desc, - ) -> schemas.FeatureSetsOutput: + partition_sort_by: mlrun.common.schemas.SortField = None, + partition_order: mlrun.common.schemas.OrderType = mlrun.common.schemas.OrderType.desc, + ) -> mlrun.common.schemas.FeatureSetsOutput: obj_id_tags = self._get_records_to_tags_map( session, FeatureSet, project, tag, name ) @@ -2129,7 +2196,9 @@ def list_feature_sets( if partition_by: self._assert_partition_by_parameters( - schemas.FeatureStorePartitionByField, partition_by, partition_sort_by + mlrun.common.schemas.FeatureStorePartitionByField, + partition_by, + partition_sort_by, ) query = self._create_partitioned_query( session, @@ -2151,7 +2220,7 @@ def list_feature_sets( tag, ) ) - return schemas.FeatureSetsOutput(feature_sets=feature_sets) + return 
mlrun.common.schemas.FeatureSetsOutput(feature_sets=feature_sets) def list_feature_sets_tags( self, @@ -2282,7 +2351,7 @@ def store_feature_set( session, project, name, - feature_set: schemas.FeatureSet, + feature_set: mlrun.common.schemas.FeatureSet, tag=None, uid=None, versioned=True, @@ -2381,7 +2450,7 @@ def create_feature_set( self, session, project, - feature_set: schemas.FeatureSet, + feature_set: mlrun.common.schemas.FeatureSet, versioned=True, ) -> str: (uid, tag, feature_set_dict,) = self._validate_and_enrich_record_for_creation( @@ -2406,7 +2475,7 @@ def patch_feature_set( feature_set_patch: dict, tag=None, uid=None, - patch_mode: schemas.PatchMode = schemas.PatchMode.replace, + patch_mode: mlrun.common.schemas.PatchMode = mlrun.common.schemas.PatchMode.replace, ) -> str: feature_set_record = self._get_feature_set(session, project, name, tag, uid) if not feature_set_record: @@ -2423,7 +2492,7 @@ def patch_feature_set( versioned = feature_set_record.metadata.uid is not None # If a bad kind value was passed, it will fail here (return 422 to caller) - feature_set = schemas.FeatureSet(**feature_set_struct) + feature_set = mlrun.common.schemas.FeatureSet(**feature_set_struct) return self.store_feature_set( session, project, @@ -2474,7 +2543,7 @@ def create_feature_vector( self, session, project, - feature_vector: schemas.FeatureVector, + feature_vector: mlrun.common.schemas.FeatureVector, versioned=True, ) -> str: ( @@ -2525,7 +2594,7 @@ def _get_feature_vector( def get_feature_vector( self, session, project: str, name: str, tag: str = None, uid: str = None - ) -> schemas.FeatureVector: + ) -> mlrun.common.schemas.FeatureVector: feature_vector = self._get_feature_vector(session, project, name, tag, uid) if not feature_vector: feature_vector_uri = generate_object_uri(project, name, tag) @@ -2543,11 +2612,11 @@ def list_feature_vectors( tag: str = None, state: str = None, labels: List[str] = None, - partition_by: schemas.FeatureStorePartitionByField = None, + 
partition_by: mlrun.common.schemas.FeatureStorePartitionByField = None, rows_per_partition: int = 1, - partition_sort_by: schemas.SortField = None, - partition_order: schemas.OrderType = schemas.OrderType.desc, - ) -> schemas.FeatureVectorsOutput: + partition_sort_by: mlrun.common.schemas.SortField = None, + partition_order: mlrun.common.schemas.OrderType = mlrun.common.schemas.OrderType.desc, + ) -> mlrun.common.schemas.FeatureVectorsOutput: obj_id_tags = self._get_records_to_tags_map( session, FeatureVector, project, tag, name ) @@ -2566,7 +2635,9 @@ def list_feature_vectors( if partition_by: self._assert_partition_by_parameters( - schemas.FeatureStorePartitionByField, partition_by, partition_sort_by + mlrun.common.schemas.FeatureStorePartitionByField, + partition_by, + partition_sort_by, ) query = self._create_partitioned_query( session, @@ -2588,7 +2659,9 @@ def list_feature_vectors( tag, ) ) - return schemas.FeatureVectorsOutput(feature_vectors=feature_vectors) + return mlrun.common.schemas.FeatureVectorsOutput( + feature_vectors=feature_vectors + ) def list_feature_vectors_tags( self, @@ -2609,7 +2682,7 @@ def store_feature_vector( session, project, name, - feature_vector: schemas.FeatureVector, + feature_vector: mlrun.common.schemas.FeatureVector, tag=None, uid=None, versioned=True, @@ -2672,7 +2745,7 @@ def patch_feature_vector( feature_vector_update: dict, tag=None, uid=None, - patch_mode: schemas.PatchMode = schemas.PatchMode.replace, + patch_mode: mlrun.common.schemas.PatchMode = mlrun.common.schemas.PatchMode.replace, ) -> str: feature_vector_record = self._get_feature_vector( session, project, name, tag, uid @@ -2690,7 +2763,7 @@ def patch_feature_vector( versioned = feature_vector_record.metadata.uid is not None - feature_vector = schemas.FeatureVector(**feature_vector_struct) + feature_vector = mlrun.common.schemas.FeatureVector(**feature_vector_struct) return self.store_feature_vector( session, project, @@ -2741,17 +2814,6 @@ def _query(self, 
session, cls, **kw): kw = {k: v for k, v in kw.items() if v is not None} return session.query(cls).filter_by(**kw) - def _function_latest_uid(self, session, project, name): - # FIXME - query = ( - self._query(session, Function.uid) - .filter(Function.project == project, Function.name == name) - .order_by(Function.updated.desc()) - ).limit(1) - out = query.one_or_none() - if out: - return out[0] - def _find_or_create_users(self, session, user_names): users = list(self._query(session, User).filter(User.name.in_(user_names))) new = set(user_names) - {user.name for user in users} @@ -2861,10 +2923,10 @@ def _find_runs(self, session, uid, project, labels): return self._add_labels_filter(session, query, Run, labels) def _get_db_notifications( - self, session, name: str = None, run_id: int = None, project: str = None + self, session, cls, name: str = None, parent_id: str = None, project: str = None ): return self._query( - session, Notification, name=name, run=run_id, project=project + session, cls.Notification, name=name, parent_id=parent_id, project=project ).all() def _latest_uid_filter(self, session, query): @@ -2942,7 +3004,7 @@ def _find_artifacts( until=None, name=None, kind=None, - category: schemas.ArtifactCategories = None, + category: mlrun.common.schemas.ArtifactCategories = None, iter=None, use_tag_as_uid: bool = None, ): @@ -2993,7 +3055,7 @@ def _find_artifacts( return query.all() def _filter_artifacts_by_category( - self, artifacts, category: schemas.ArtifactCategories + self, artifacts, category: mlrun.common.schemas.ArtifactCategories ): kinds, exclude = category.to_kinds_filter() return self._filter_artifacts_by_kinds(artifacts, kinds, exclude) @@ -3140,8 +3202,8 @@ def _delete_class_labels( def _transform_schedule_record_to_scheme( self, schedule_record: Schedule, - ) -> schemas.ScheduleRecord: - schedule = schemas.ScheduleRecord.from_orm(schedule_record) + ) -> mlrun.common.schemas.ScheduleRecord: + schedule = 
mlrun.common.schemas.ScheduleRecord.from_orm(schedule_record) schedule.creation_time = self._add_utc_timezone(schedule.creation_time) schedule.next_run_time = self._add_utc_timezone(schedule.next_run_time) return schedule @@ -3161,9 +3223,9 @@ def _add_utc_timezone(time_value: typing.Optional[datetime]): def _transform_feature_set_model_to_schema( feature_set_record: FeatureSet, tag=None, - ) -> schemas.FeatureSet: + ) -> mlrun.common.schemas.FeatureSet: feature_set_full_dict = feature_set_record.full_object - feature_set_resp = schemas.FeatureSet(**feature_set_full_dict) + feature_set_resp = mlrun.common.schemas.FeatureSet(**feature_set_full_dict) feature_set_resp.metadata.tag = tag return feature_set_resp @@ -3172,9 +3234,11 @@ def _transform_feature_set_model_to_schema( def _transform_feature_vector_model_to_schema( feature_vector_record: FeatureVector, tag=None, - ) -> schemas.FeatureVector: + ) -> mlrun.common.schemas.FeatureVector: feature_vector_full_dict = feature_vector_record.full_object - feature_vector_resp = schemas.FeatureVector(**feature_vector_full_dict) + feature_vector_resp = mlrun.common.schemas.FeatureVector( + **feature_vector_full_dict + ) feature_vector_resp.metadata.tag = tag feature_vector_resp.metadata.created = feature_vector_record.created @@ -3182,30 +3246,30 @@ def _transform_feature_vector_model_to_schema( def _transform_project_record_to_schema( self, session: Session, project_record: Project - ) -> schemas.Project: + ) -> mlrun.common.schemas.Project: # in projects that was created before 0.6.0 the full object wasn't created properly - fix that, and return if not project_record.full_object: - project = schemas.Project( - metadata=schemas.ProjectMetadata( + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata( name=project_record.name, created=project_record.created, ), - spec=schemas.ProjectSpec( + spec=mlrun.common.schemas.ProjectSpec( description=project_record.description, 
source=project_record.source, ), - status=schemas.ObjectStatus( + status=mlrun.common.schemas.ObjectStatus( state=project_record.state, ), ) self.store_project(session, project_record.name, project) return project # TODO: handle transforming the functions/workflows/artifacts references to real objects - return schemas.Project(**project_record.full_object) + return mlrun.common.schemas.Project(**project_record.full_object) def _transform_notification_record_to_spec_and_status( self, - notification_record: Notification, + notification_record, ) -> typing.Tuple[dict, dict]: notification_spec = self._transform_notification_record_to_schema( notification_record @@ -3218,7 +3282,7 @@ def _transform_notification_record_to_spec_and_status( @staticmethod def _transform_notification_record_to_schema( - notification_record: Notification, + notification_record, ) -> mlrun.model.Notification: return mlrun.model.Notification( kind=notification_record.kind, @@ -3258,11 +3322,9 @@ def _move_and_reorder_table_items( else: start, end = move_to, move_from - 1 - query = session.query(MarketplaceSource).filter( - MarketplaceSource.index >= start - ) + query = session.query(HubSource).filter(HubSource.index >= start) if end: - query = query.filter(MarketplaceSource.index <= end) + query = query.filter(HubSource.index <= end) for source_record in query: source_record.index = source_record.index + modifier @@ -3277,54 +3339,54 @@ def _move_and_reorder_table_items( session.commit() @staticmethod - def _transform_marketplace_source_record_to_schema( - marketplace_source_record: MarketplaceSource, - ) -> schemas.IndexedMarketplaceSource: - source_full_dict = marketplace_source_record.full_object - marketplace_source = schemas.MarketplaceSource(**source_full_dict) - return schemas.IndexedMarketplaceSource( - index=marketplace_source_record.index, source=marketplace_source + def _transform_hub_source_record_to_schema( + hub_source_record: HubSource, + ) -> 
mlrun.common.schemas.IndexedHubSource: + source_full_dict = hub_source_record.full_object + hub_source = mlrun.common.schemas.HubSource(**source_full_dict) + return mlrun.common.schemas.IndexedHubSource( + index=hub_source_record.index, source=hub_source ) @staticmethod - def _transform_marketplace_source_schema_to_record( - marketplace_source_schema: schemas.IndexedMarketplaceSource, - current_object: MarketplaceSource = None, + def _transform_hub_source_schema_to_record( + hub_source_schema: mlrun.common.schemas.IndexedHubSource, + current_object: HubSource = None, ): now = datetime.now(timezone.utc) if current_object: - if current_object.name != marketplace_source_schema.source.metadata.name: + if current_object.name != hub_source_schema.source.metadata.name: raise mlrun.errors.MLRunInternalServerError( "Attempt to update object while replacing its name" ) created_timestamp = current_object.created else: - created_timestamp = marketplace_source_schema.source.metadata.created or now - updated_timestamp = marketplace_source_schema.source.metadata.updated or now + created_timestamp = hub_source_schema.source.metadata.created or now + updated_timestamp = hub_source_schema.source.metadata.updated or now - marketplace_source_record = MarketplaceSource( + hub_source_record = HubSource( id=current_object.id if current_object else None, - name=marketplace_source_schema.source.metadata.name, - index=marketplace_source_schema.index, + name=hub_source_schema.source.metadata.name, + index=hub_source_schema.index, created=created_timestamp, updated=updated_timestamp, ) - full_object = marketplace_source_schema.source.dict() + full_object = hub_source_schema.source.dict() full_object["metadata"]["created"] = str(created_timestamp) full_object["metadata"]["updated"] = str(updated_timestamp) - # Make sure we don't keep any credentials in the DB. These are handled in the marketplace crud object. + # Make sure we don't keep any credentials in the DB. 
These are handled in the hub crud object. full_object["spec"].pop("credentials", None) - marketplace_source_record.full_object = full_object - return marketplace_source_record + hub_source_record.full_object = full_object + return hub_source_record @staticmethod - def _validate_and_adjust_marketplace_order(session, order): - max_order = session.query(func.max(MarketplaceSource.index)).scalar() + def _validate_and_adjust_hub_order(session, order): + max_order = session.query(func.max(HubSource.index)).scalar() if not max_order or max_order < 0: max_order = 0 - if order == schemas.marketplace.last_source_index: + if order == mlrun.common.schemas.hub.last_source_index: order = max_order + 1 if order > max_order + 1: @@ -3334,62 +3396,54 @@ def _validate_and_adjust_marketplace_order(session, order): if order < 1: raise mlrun.errors.MLRunInvalidArgumentError( "Order of inserted source must be greater than 0 or " - + f"{schemas.marketplace.last_source_index} (for last). order = {order}" + + f"{mlrun.common.schemas.hub.last_source_index} (for last). order = {order}" ) return order - def create_marketplace_source( - self, session, ordered_source: schemas.IndexedMarketplaceSource + def create_hub_source( + self, session, ordered_source: mlrun.common.schemas.IndexedHubSource ): logger.debug( - "Creating marketplace source in DB", + "Creating hub source in DB", index=ordered_source.index, name=ordered_source.source.metadata.name, ) - order = self._validate_and_adjust_marketplace_order( - session, ordered_source.index - ) + order = self._validate_and_adjust_hub_order(session, ordered_source.index) name = ordered_source.source.metadata.name - source_record = self._query(session, MarketplaceSource, name=name).one_or_none() + source_record = self._query(session, HubSource, name=name).one_or_none() if source_record: raise mlrun.errors.MLRunConflictError( - f"Marketplace source name already exists. name = {name}" + f"Hub source name already exists. 
name = {name}" ) - source_record = self._transform_marketplace_source_schema_to_record( - ordered_source - ) + source_record = self._transform_hub_source_schema_to_record(ordered_source) self._move_and_reorder_table_items( session, source_record, move_to=order, move_from=None ) @retry_on_conflict - def store_marketplace_source( + def store_hub_source( self, session, name, - ordered_source: schemas.IndexedMarketplaceSource, + ordered_source: mlrun.common.schemas.IndexedHubSource, ): - logger.debug( - "Storing marketplace source in DB", index=ordered_source.index, name=name - ) + logger.debug("Storing hub source in DB", index=ordered_source.index, name=name) if name != ordered_source.source.metadata.name: raise mlrun.errors.MLRunInvalidArgumentError( "Conflict between resource name and metadata.name in the stored object" ) - order = self._validate_and_adjust_marketplace_order( - session, ordered_source.index - ) + order = self._validate_and_adjust_hub_order(session, ordered_source.index) - source_record = self._query(session, MarketplaceSource, name=name).one_or_none() + source_record = self._query(session, HubSource, name=name).one_or_none() current_order = source_record.index if source_record else None - if current_order == schemas.marketplace.last_source_index: + if current_order == mlrun.common.schemas.hub.last_source_index: raise mlrun.errors.MLRunInvalidArgumentError( - "Attempting to modify the global marketplace source." + "Attempting to modify the global hub source." 
) - source_record = self._transform_marketplace_source_schema_to_record( + source_record = self._transform_hub_source_schema_to_record( ordered_source, source_record ) @@ -3397,47 +3451,46 @@ def store_marketplace_source( session, source_record, move_to=order, move_from=current_order ) - def list_marketplace_sources( - self, session - ) -> List[schemas.IndexedMarketplaceSource]: + def list_hub_sources(self, session) -> List[mlrun.common.schemas.IndexedHubSource]: results = [] - query = self._query(session, MarketplaceSource).order_by( - MarketplaceSource.index.desc() - ) + query = self._query(session, HubSource).order_by(HubSource.index.desc()) for record in query: - ordered_source = self._transform_marketplace_source_record_to_schema(record) + ordered_source = self._transform_hub_source_record_to_schema(record) # Need this to make the list return such that the default source is last in the response. - if ordered_source.index != schemas.last_source_index: + if ordered_source.index != mlrun.common.schemas.last_source_index: results.insert(0, ordered_source) else: results.append(ordered_source) return results - def delete_marketplace_source(self, session, name): - logger.debug("Deleting marketplace source from DB", name=name) + def _list_hub_sources_without_transform(self, session) -> List[HubSource]: + return self._query(session, HubSource).all() + + def delete_hub_source(self, session, name): + logger.debug("Deleting hub source from DB", name=name) - source_record = self._query(session, MarketplaceSource, name=name).one_or_none() + source_record = self._query(session, HubSource, name=name).one_or_none() if not source_record: return current_order = source_record.index - if current_order == schemas.marketplace.last_source_index: + if current_order == mlrun.common.schemas.hub.last_source_index: raise mlrun.errors.MLRunInvalidArgumentError( - "Attempting to delete the global marketplace source." + "Attempting to delete the global hub source." 
) self._move_and_reorder_table_items( session, source_record, move_to=None, move_from=current_order ) - def get_marketplace_source(self, session, name) -> schemas.IndexedMarketplaceSource: - source_record = self._query(session, MarketplaceSource, name=name).one_or_none() + def get_hub_source(self, session, name) -> mlrun.common.schemas.IndexedHubSource: + source_record = self._query(session, HubSource, name=name).one_or_none() if not source_record: raise mlrun.errors.MLRunNotFoundError( - f"Marketplace source not found. name = {name}" + f"Hub source not found. name = {name}" ) - return self._transform_marketplace_source_record_to_schema(source_record) + return self._transform_hub_source_record_to_schema(source_record) def get_current_data_version( self, session, raise_on_not_found=True @@ -3475,7 +3528,7 @@ def store_background_task( session, name: str, project: str, - state: str = mlrun.api.schemas.BackgroundTaskState.running, + state: str = mlrun.common.schemas.BackgroundTaskState.running, timeout: int = None, ): background_task_record = self._query( @@ -3489,7 +3542,7 @@ def store_background_task( # we don't want to be able to change state after it reached terminal state if ( background_task_record.state - in mlrun.api.schemas.BackgroundTaskState.terminal_states() + in mlrun.common.schemas.BackgroundTaskState.terminal_states() and state != background_task_record.state ): raise mlrun.errors.MLRunRuntimeError( @@ -3516,7 +3569,7 @@ def store_background_task( def get_background_task( self, session, name: str, project: str - ) -> schemas.BackgroundTask: + ) -> mlrun.common.schemas.BackgroundTask: background_task_record = self._get_background_task_record( session, name, project ) @@ -3527,7 +3580,7 @@ def get_background_task( session, name, project, - mlrun.api.schemas.background_task.BackgroundTaskState.failed, + mlrun.common.schemas.background_task.BackgroundTaskState.failed, ) background_task_record = self._get_background_task_record( session, name, project @@ 
-3538,17 +3591,17 @@ def get_background_task( @staticmethod def _transform_background_task_record_to_schema( background_task_record: BackgroundTask, - ) -> schemas.BackgroundTask: - return schemas.BackgroundTask( - metadata=schemas.BackgroundTaskMetadata( + ) -> mlrun.common.schemas.BackgroundTask: + return mlrun.common.schemas.BackgroundTask( + metadata=mlrun.common.schemas.BackgroundTaskMetadata( name=background_task_record.name, project=background_task_record.project, created=background_task_record.created, updated=background_task_record.updated, timeout=background_task_record.timeout, ), - spec=schemas.BackgroundTaskSpec(), - status=schemas.BackgroundTaskStatus( + spec=mlrun.common.schemas.BackgroundTaskSpec(), + status=mlrun.common.schemas.BackgroundTaskStatus( state=background_task_record.state, ), ) @@ -3601,7 +3654,7 @@ def _is_background_task_timeout_exceeded(background_task_record) -> bool: if ( timeout and background_task_record.state - not in mlrun.api.schemas.BackgroundTaskState.terminal_states() + not in mlrun.common.schemas.BackgroundTaskState.terminal_states() and datetime.utcnow() > timedelta(seconds=int(timeout)) + background_task_record.updated ): @@ -3622,18 +3675,36 @@ def store_run_notifications( f"Run not found: uid={run_uid}, project={project}" ) - run_notifications = { + self._store_notifications(session, Run, notification_objects, run.id, project) + + def _store_notifications( + self, + session, + cls, + notification_objects: typing.List[mlrun.model.Notification], + parent_id: str, + project: str, + ): + db_notifications = { notification.name: notification - for notification in self._get_db_notifications(session, run_id=run.id) + for notification in self._get_db_notifications( + session, cls, parent_id=parent_id + ) } notifications = [] + logger.debug( + "Storing notifications", + notifications_length=len(notification_objects), + parent_id=parent_id, + project=project, + ) for notification_model in notification_objects: new_notification = 
False - notification = run_notifications.get(notification_model.name, None) + notification = db_notifications.get(notification_model.name, None) if not notification: new_notification = True - notification = Notification( - name=notification_model.name, run=run.id, project=project + notification = cls.Notification( + name=notification_model.name, parent_id=parent_id, project=project ) notification.kind = notification_model.kind @@ -3644,14 +3715,15 @@ def store_run_notifications( notification.params = notification_model.params notification.status = ( notification_model.status - or mlrun.api.schemas.NotificationStatus.PENDING + or mlrun.common.schemas.NotificationStatus.PENDING ) notification.sent_time = notification_model.sent_time logger.debug( f"Storing {'new' if new_notification else 'existing'} notification", notification_name=notification.name, - run_uid=run_uid, + notification_status=notification.status, + parent_id=parent_id, project=project, ) notifications.append(notification) @@ -3672,7 +3744,9 @@ def list_run_notifications( return [ self._transform_notification_record_to_schema(notification) - for notification in self._query(session, Notification, run=run.id).all() + for notification in self._query( + session, Run.Notification, parent_id=run.id + ).all() ] def delete_run_notifications( @@ -3698,9 +3772,51 @@ def delete_run_notifications( if project == "*": project = None - query = self._get_db_notifications(session, name, run_id, project) + query = self._get_db_notifications(session, Run, name, run_id, project) for notification in query: session.delete(notification) if commit: session.commit() + + def set_run_notifications( + self, + session: Session, + project: str, + notifications: typing.List[mlrun.model.Notification], + identifier: mlrun.common.schemas.RunIdentifier, + **kwargs, + ): + """ + Set notifications for a run. This will replace any existing notifications. 
+ :param session: SQLAlchemy session + :param project: Project name + :param notifications: List of notifications to set + :param identifier: Run identifier + :param kwargs: Ignored additional arguments (for interfacing purposes) + """ + run = self._get_run(session, identifier.uid, project, None) + if not run: + raise mlrun.errors.MLRunNotFoundError( + f"Run not found: project={project}, uid={identifier.uid}" + ) + + run.struct.setdefault("spec", {})["notifications"] = [ + notification.to_dict() for notification in notifications + ] + + # update run, delete and store notifications all in one transaction. + # using session.add instead of upsert, so we don't commit the run. + # the commit will happen at the end (in store_run_notifications, or manually at the end). + session.add(run) + self.delete_run_notifications( + session, run_uid=run.uid, project=project, commit=False + ) + if notifications: + self.store_run_notifications( + session, + notification_objects=notifications, + run_uid=run.uid, + project=project, + ) + self._commit(session, [run], ignore=True) diff --git a/mlrun/api/db/sqldb/models/__init__.py b/mlrun/api/db/sqldb/models/__init__.py index bfb036e18368..a633e09f0640 100644 --- a/mlrun/api/db/sqldb/models/__init__.py +++ b/mlrun/api/db/sqldb/models/__init__.py @@ -20,10 +20,10 @@ from .models_mysql import * # noqa # importing private variables as well - from .models_mysql import _classes, _labeled, _table2cls, _tagged # noqa # isort:skip + from .models_mysql import _classes, _labeled, _table2cls, _tagged, _with_notifications # noqa # isort:skip else: from .models_sqlite import * # noqa # importing private variables as well - from .models_sqlite import _classes, _labeled, _table2cls, _tagged # noqa # isort:skip + from .models_sqlite import _classes, _labeled, _table2cls, _tagged, _with_notifications # noqa # isort:skip # fmt: on diff --git a/mlrun/api/db/sqldb/models/models_mysql.py b/mlrun/api/db/sqldb/models/models_mysql.py index 
f7fad72eea6b..e8c41895d6fd 100644 --- a/mlrun/api/db/sqldb/models/models_mysql.py +++ b/mlrun/api/db/sqldb/models/models_mysql.py @@ -30,9 +30,10 @@ UniqueConstraint, ) from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import class_mapper, relationship +from sqlalchemy.orm import relationship -from mlrun.api import schemas +import mlrun.common.schemas +import mlrun.utils.db from mlrun.api.utils.db.sql_collation import SQLCollationUtil Base = declarative_base() @@ -40,42 +41,8 @@ run_time_fmt = "%Y-%m-%dT%H:%M:%S.%fZ" -class BaseModel: - def to_dict(self, exclude=None): - """ - NOTE - this function (currently) does not handle serializing relationships - """ - exclude = exclude or [] - mapper = class_mapper(self.__class__) - columns = [column.key for column in mapper.columns if column.key not in exclude] - get_key_value = ( - lambda c: (c, getattr(self, c).isoformat()) - if isinstance(getattr(self, c), datetime) - else (c, getattr(self, c)) - ) - return dict(map(get_key_value, columns)) - - -class HasStruct(BaseModel): - @property - def struct(self): - return pickle.loads(self.body) - - @struct.setter - def struct(self, value): - self.body = pickle.dumps(value) - - def to_dict(self, exclude=None): - """ - NOTE - this function (currently) does not handle serializing relationships - """ - exclude = exclude or [] - exclude.append("body") - return super().to_dict(exclude) - - def make_label(table): - class Label(Base, BaseModel): + class Label(Base, mlrun.utils.db.BaseModel): __tablename__ = f"{table}_labels" __table_args__ = ( UniqueConstraint("name", "parent", name=f"_{table}_labels_uc"), @@ -90,7 +57,7 @@ class Label(Base, BaseModel): def make_tag(table): - class Tag(Base, BaseModel): + class Tag(Base, mlrun.utils.db.BaseModel): __tablename__ = f"{table}_tags" __table_args__ = ( UniqueConstraint("project", "name", "obj_id", name=f"_{table}_tags_uc"), @@ -107,7 +74,7 @@ class Tag(Base, BaseModel): # TODO: don't want to refactor everything in 
one PR so splitting this function to 2 versions - eventually only this one # should be used def make_tag_v2(table): - class Tag(Base, BaseModel): + class Tag(Base, mlrun.utils.db.BaseModel): __tablename__ = f"{table}_tags" __table_args__ = ( UniqueConstraint("project", "name", "obj_name", name=f"_{table}_tags_uc"), @@ -122,11 +89,57 @@ class Tag(Base, BaseModel): return Tag +def make_notification(table): + class Notification(Base, mlrun.utils.db.BaseModel): + __tablename__ = f"{table}_notifications" + __table_args__ = ( + UniqueConstraint("name", "parent_id", name=f"_{table}_notifications_uc"), + ) + + id = Column(Integer, primary_key=True) + project = Column(String(255, collation=SQLCollationUtil.collation())) + name = Column( + String(255, collation=SQLCollationUtil.collation()), nullable=False + ) + kind = Column( + String(255, collation=SQLCollationUtil.collation()), nullable=False + ) + message = Column( + String(255, collation=SQLCollationUtil.collation()), nullable=False + ) + severity = Column( + String(255, collation=SQLCollationUtil.collation()), nullable=False + ) + when = Column( + String(255, collation=SQLCollationUtil.collation()), nullable=False + ) + condition = Column( + String(255, collation=SQLCollationUtil.collation()), nullable=False + ) + params = Column("params", JSON) + parent_id = Column(Integer, ForeignKey(f"{table}.id")) + + # TODO: Separate table for notification state. + # Currently, we are only supporting one notification being sent per DB row (either on completion or on error). + # In the future, we might want to support multiple notifications per DB row, and we might want to support on + # start, therefore we need to separate the state from the notification itself (e.g. this table can be table + # with notification_id, state, when, last_sent, etc.). This will require some refactoring in the code. 
+ sent_time = Column( + sqlalchemy.dialects.mysql.TIMESTAMP(fsp=3), + nullable=True, + ) + status = Column( + String(255, collation=SQLCollationUtil.collation()), nullable=False + ) + + return Notification + + # quell SQLAlchemy warnings on duplicate class name (Label) with warnings.catch_warnings(): warnings.simplefilter("ignore") - class Artifact(Base, HasStruct): + class Artifact(Base, mlrun.utils.db.HasStruct): __tablename__ = "artifacts" __table_args__ = ( UniqueConstraint("uid", "project", "key", name="_artifacts_uc"), @@ -140,7 +153,7 @@ class Artifact(Base, HasStruct): project = Column(String(255, collation=SQLCollationUtil.collation())) uid = Column(String(255, collation=SQLCollationUtil.collation())) updated = Column(sqlalchemy.dialects.mysql.TIMESTAMP(fsp=3)) - # TODO: change to JSON, see mlrun/api/schemas/function.py::FunctionState for reasoning + # TODO: change to JSON, see mlrun/common/schemas/function.py::FunctionState for reasoning body = Column(sqlalchemy.dialects.mysql.MEDIUMBLOB) labels = relationship(Label, cascade="all, delete-orphan") @@ -149,7 +162,7 @@ class Artifact(Base, HasStruct): def get_identifier_string(self) -> str: return f"{self.project}/{self.key}/{self.uid}" - class Function(Base, HasStruct): + class Function(Base, mlrun.utils.db.HasStruct): __tablename__ = "functions" __table_args__ = ( UniqueConstraint("name", "project", "uid", name="_functions_uc"), @@ -162,7 +175,7 @@ class Function(Base, HasStruct): name = Column(String(255, collation=SQLCollationUtil.collation())) project = Column(String(255, collation=SQLCollationUtil.collation())) uid = Column(String(255, collation=SQLCollationUtil.collation())) - # TODO: change to JSON, see mlrun/api/schemas/function.py::FunctionState for reasoning + # TODO: change to JSON, see mlrun/common/schemas/function.py::FunctionState for reasoning body = Column(sqlalchemy.dialects.mysql.MEDIUMBLOB) updated = Column(sqlalchemy.dialects.mysql.TIMESTAMP(fsp=3)) @@ -172,59 +185,19 @@ class 
Function(Base, HasStruct): def get_identifier_string(self) -> str: return f"{self.project}/{self.name}/{self.uid}" - class Log(Base, BaseModel): + class Log(Base, mlrun.utils.db.BaseModel): __tablename__ = "logs" id = Column(Integer, primary_key=True) uid = Column(String(255, collation=SQLCollationUtil.collation())) project = Column(String(255, collation=SQLCollationUtil.collation())) - # TODO: change to JSON, see mlrun/api/schemas/function.py::FunctionState for reasoning + # TODO: change to JSON, see mlrun/common/schemas/function.py::FunctionState for reasoning body = Column(sqlalchemy.dialects.mysql.MEDIUMBLOB) def get_identifier_string(self) -> str: return f"{self.project}/{self.uid}" - class Notification(Base, BaseModel): - __tablename__ = "notifications" - __table_args__ = (UniqueConstraint("name", "run", name="_notifications_uc"),) - - id = Column(Integer, primary_key=True) - project = Column(String(255, collation=SQLCollationUtil.collation())) - name = Column( - String(255, collation=SQLCollationUtil.collation()), nullable=False - ) - kind = Column( - String(255, collation=SQLCollationUtil.collation()), nullable=False - ) - message = Column( - String(255, collation=SQLCollationUtil.collation()), nullable=False - ) - severity = Column( - String(255, collation=SQLCollationUtil.collation()), nullable=False - ) - when = Column( - String(255, collation=SQLCollationUtil.collation()), nullable=False - ) - condition = Column( - String(255, collation=SQLCollationUtil.collation()), nullable=False - ) - params = Column("params", JSON) - run = Column(Integer, ForeignKey("runs.id")) - - # TODO: Separate table for notification state. - # Currently, we are only supporting one notification being sent per DB row (either on completion or on error). - # In the future, we might want to support multiple notifications per DB row, and we might want to support on - # start, therefore we need to separate the state from the notification itself (e.g. 
this table can be table - # with notification_id, state, when, last_sent, etc.). This will require some refactoring in the code. - sent_time = Column( - sqlalchemy.dialects.mysql.TIMESTAMP(fsp=3), - nullable=True, - ) - status = Column( - String(255, collation=SQLCollationUtil.collation()), nullable=False - ) - - class Run(Base, HasStruct): + class Run(Base, mlrun.utils.db.HasStruct): __tablename__ = "runs" __table_args__ = ( UniqueConstraint("uid", "project", "iteration", name="_runs_uc"), @@ -232,6 +205,7 @@ class Run(Base, HasStruct): Label = make_label(__tablename__) Tag = make_tag(__tablename__) + Notification = make_notification(__tablename__) id = Column(Integer, primary_key=True) uid = Column(String(255, collation=SQLCollationUtil.collation())) @@ -241,7 +215,7 @@ class Run(Base, HasStruct): ) iteration = Column(Integer) state = Column(String(255, collation=SQLCollationUtil.collation())) - # TODO: change to JSON, see mlrun/api/schemas/function.py::FunctionState for reasoning + # TODO: change to JSON, see mlrun/common/schemas/function.py::FunctionState for reasoning body = Column(sqlalchemy.dialects.mysql.MEDIUMBLOB) start_time = Column(sqlalchemy.dialects.mysql.TIMESTAMP(fsp=3)) updated = Column( @@ -260,7 +234,7 @@ class Run(Base, HasStruct): def get_identifier_string(self) -> str: return f"{self.project}/{self.uid}/{self.iteration}" - class BackgroundTask(Base, BaseModel): + class BackgroundTask(Base, mlrun.utils.db.BaseModel): __tablename__ = "background_tasks" __table_args__ = ( UniqueConstraint("name", "project", name="_background_tasks_uc"), @@ -284,7 +258,7 @@ class BackgroundTask(Base, BaseModel): state = Column(String(255, collation=SQLCollationUtil.collation())) timeout = Column(Integer) - class Schedule(Base, BaseModel): + class Schedule(Base, mlrun.utils.db.BaseModel): __tablename__ = "schedules_v2" __table_args__ = (UniqueConstraint("project", "name", name="_schedules_v2_uc"),) @@ -303,7 +277,7 @@ class Schedule(Base, BaseModel): creation_time 
= Column(sqlalchemy.dialects.mysql.TIMESTAMP(fsp=3)) cron_trigger_str = Column(String(255, collation=SQLCollationUtil.collation())) last_run_uri = Column(String(255, collation=SQLCollationUtil.collation())) - # TODO: change to JSON, see mlrun/api/schemas/function.py::FunctionState for reasoning + # TODO: change to JSON, see mlrun/common/schemas/function.py::FunctionState for reasoning struct = Column(sqlalchemy.dialects.mysql.MEDIUMBLOB) labels = relationship(Label, cascade="all, delete-orphan") concurrency_limit = Column(Integer, nullable=False) @@ -321,11 +295,11 @@ def scheduled_object(self, value): self.struct = pickle.dumps(value) @property - def cron_trigger(self) -> schemas.ScheduleCronTrigger: + def cron_trigger(self) -> mlrun.common.schemas.ScheduleCronTrigger: return orjson.loads(self.cron_trigger_str) @cron_trigger.setter - def cron_trigger(self, trigger: schemas.ScheduleCronTrigger): + def cron_trigger(self, trigger: mlrun.common.schemas.ScheduleCronTrigger): self.cron_trigger_str = orjson.dumps(trigger.dict(exclude_unset=True)) # Define "many to many" users/projects @@ -336,14 +310,14 @@ def cron_trigger(self, trigger: schemas.ScheduleCronTrigger): Column("user_id", Integer, ForeignKey("users.id")), ) - class User(Base, BaseModel): + class User(Base, mlrun.utils.db.BaseModel): __tablename__ = "users" __table_args__ = (UniqueConstraint("name", name="_users_uc"),) id = Column(Integer, primary_key=True) name = Column(String(255, collation=SQLCollationUtil.collation())) - class Project(Base, BaseModel): + class Project(Base, mlrun.utils.db.BaseModel): __tablename__ = "projects" # For now since we use project name a lot __table_args__ = (UniqueConstraint("name", name="_projects_uc"),) @@ -355,7 +329,7 @@ class Project(Base, BaseModel): source = Column(String(255, collation=SQLCollationUtil.collation())) # the attribute name used to be _spec which is just a wrong naming, the attribute was renamed to _full_object # leaving the column as is to prevent 
redundant migration - # TODO: change to JSON, see mlrun/api/schemas/function.py::FunctionState for reasoning + # TODO: change to JSON, see mlrun/common/schemas/function.py::FunctionState for reasoning _full_object = Column("spec", sqlalchemy.dialects.mysql.MEDIUMBLOB) created = Column( sqlalchemy.dialects.mysql.TIMESTAMP(fsp=3), default=datetime.utcnow @@ -379,7 +353,7 @@ def full_object(self): def full_object(self, value): self._full_object = pickle.dumps(value) - class Feature(Base, BaseModel): + class Feature(Base, mlrun.utils.db.BaseModel): __tablename__ = "features" id = Column(Integer, primary_key=True) feature_set_id = Column(Integer, ForeignKey("feature_sets.id")) @@ -393,7 +367,7 @@ class Feature(Base, BaseModel): def get_identifier_string(self) -> str: return f"{self.project}/{self.name}" - class Entity(Base, BaseModel): + class Entity(Base, mlrun.utils.db.BaseModel): __tablename__ = "entities" id = Column(Integer, primary_key=True) feature_set_id = Column(Integer, ForeignKey("feature_sets.id")) @@ -407,7 +381,7 @@ class Entity(Base, BaseModel): def get_identifier_string(self) -> str: return f"{self.project}/{self.name}" - class FeatureSet(Base, BaseModel): + class FeatureSet(Base, mlrun.utils.db.BaseModel): __tablename__ = "feature_sets" __table_args__ = ( UniqueConstraint("name", "project", "uid", name="_feature_set_uc"), @@ -451,7 +425,7 @@ def full_object(self, value): # TODO - convert to pickle, to avoid issues with non-json serializable fields such as datetime self._full_object = json.dumps(value, default=str) - class FeatureVector(Base, BaseModel): + class FeatureVector(Base, mlrun.utils.db.BaseModel): __tablename__ = "feature_vectors" __table_args__ = ( UniqueConstraint("name", "project", "uid", name="_feature_vectors_uc"), @@ -492,9 +466,9 @@ def full_object(self, value): # TODO - convert to pickle, to avoid issues with non-json serializable fields such as datetime self._full_object = json.dumps(value, default=str) - class MarketplaceSource(Base, 
BaseModel): - __tablename__ = "marketplace_sources" - __table_args__ = (UniqueConstraint("name", name="_marketplace_sources_uc"),) + class HubSource(Base, mlrun.utils.db.BaseModel): + __tablename__ = "hub_sources" + __table_args__ = (UniqueConstraint("name", name="_hub_sources_uc"),) id = Column(Integer, primary_key=True) name = Column(String(255, collation=SQLCollationUtil.collation())) @@ -523,7 +497,7 @@ def full_object(self, value): # TODO - convert to pickle, to avoid issues with non-json serializable fields such as datetime self._full_object = json.dumps(value, default=str) - class DataVersion(Base, BaseModel): + class DataVersion(Base, mlrun.utils.db.BaseModel): __tablename__ = "data_versions" __table_args__ = (UniqueConstraint("version", name="_versions_uc"),) @@ -538,5 +512,8 @@ class DataVersion(Base, BaseModel): # Must be after all table definitions _tagged = [cls for cls in Base.__subclasses__() if hasattr(cls, "Tag")] _labeled = [cls for cls in Base.__subclasses__() if hasattr(cls, "Label")] +_with_notifications = [ + cls for cls in Base.__subclasses__() if hasattr(cls, "Notification") +] _classes = [cls for cls in Base.__subclasses__()] _table2cls = {cls.__table__.name: cls for cls in Base.__subclasses__()} diff --git a/mlrun/api/db/sqldb/models/models_sqlite.py b/mlrun/api/db/sqldb/models/models_sqlite.py index 597983fbe997..4ce29ddac439 100644 --- a/mlrun/api/db/sqldb/models/models_sqlite.py +++ b/mlrun/api/db/sqldb/models/models_sqlite.py @@ -31,9 +31,10 @@ UniqueConstraint, ) from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import class_mapper, relationship +from sqlalchemy.orm import relationship -from mlrun.api import schemas +import mlrun.common.schemas +import mlrun.utils.db from mlrun.api.utils.db.sql_collation import SQLCollationUtil Base = declarative_base() @@ -41,42 +42,8 @@ run_time_fmt = "%Y-%m-%dT%H:%M:%S.%fZ" -class BaseModel: - def to_dict(self, exclude=None): - """ - NOTE - this function (currently) does 
not handle serializing relationships - """ - exclude = exclude or [] - mapper = class_mapper(self.__class__) - columns = [column.key for column in mapper.columns if column.key not in exclude] - get_key_value = ( - lambda c: (c, getattr(self, c).isoformat()) - if isinstance(getattr(self, c), datetime) - else (c, getattr(self, c)) - ) - return dict(map(get_key_value, columns)) - - -class HasStruct(BaseModel): - @property - def struct(self): - return pickle.loads(self.body) - - @struct.setter - def struct(self, value): - self.body = pickle.dumps(value) - - def to_dict(self, exclude=None): - """ - NOTE - this function (currently) does not handle serializing relationships - """ - exclude = exclude or [] - exclude.append("body") - return super().to_dict(exclude) - - def make_label(table): - class Label(Base, BaseModel): + class Label(Base, mlrun.utils.db.BaseModel): __tablename__ = f"{table}_labels" __table_args__ = ( UniqueConstraint("name", "parent", name=f"_{table}_labels_uc"), @@ -91,7 +58,7 @@ class Label(Base, BaseModel): def make_tag(table): - class Tag(Base, BaseModel): + class Tag(Base, mlrun.utils.db.BaseModel): __tablename__ = f"{table}_tags" __table_args__ = ( UniqueConstraint("project", "name", "obj_id", name=f"_{table}_tags_uc"), @@ -108,7 +75,7 @@ class Tag(Base, BaseModel): # TODO: don't want to refactor everything in one PR so splitting this function to 2 versions - eventually only this one # should be used def make_tag_v2(table): - class Tag(Base, BaseModel): + class Tag(Base, mlrun.utils.db.BaseModel): __tablename__ = f"{table}_tags" __table_args__ = ( UniqueConstraint("project", "name", "obj_name", name=f"_{table}_tags_uc"), @@ -126,11 +93,51 @@ class Tag(Base, BaseModel): return Tag +def make_notification(table): + class Notification(Base, mlrun.utils.db.BaseModel): + __tablename__ = f"{table}_notifications" + __table_args__ = ( + UniqueConstraint("name", "parent_id", name=f"_{table}_notifications_uc"), + ) + + id = Column(Integer, primary_key=True) 
+ project = Column(String(255, collation=SQLCollationUtil.collation())) + name = Column( + String(255, collation=SQLCollationUtil.collation()), nullable=False + ) + kind = Column( + String(255, collation=SQLCollationUtil.collation()), nullable=False + ) + message = Column( + String(255, collation=SQLCollationUtil.collation()), nullable=False + ) + severity = Column( + String(255, collation=SQLCollationUtil.collation()), nullable=False + ) + when = Column( + String(255, collation=SQLCollationUtil.collation()), nullable=False + ) + condition = Column( + String(255, collation=SQLCollationUtil.collation()), nullable=False + ) + params = Column("params", JSON) + parent_id = Column(Integer, ForeignKey(f"{table}.id")) + sent_time = Column( + TIMESTAMP(), + nullable=True, + ) + status = Column( + String(255, collation=SQLCollationUtil.collation()), nullable=False + ) + + return Notification + + # quell SQLAlchemy warnings on duplicate class name (Label) with warnings.catch_warnings(): warnings.simplefilter("ignore") - class Artifact(Base, HasStruct): + class Artifact(Base, mlrun.utils.db.HasStruct): __tablename__ = "artifacts" __table_args__ = ( UniqueConstraint("uid", "project", "key", name="_artifacts_uc"), @@ -144,14 +151,14 @@ class Artifact(Base, HasStruct): project = Column(String(255, collation=SQLCollationUtil.collation())) uid = Column(String(255, collation=SQLCollationUtil.collation())) updated = Column(TIMESTAMP) - # TODO: change to JSON, see mlrun/api/schemas/function.py::FunctionState for reasoning + # TODO: change to JSON, see mlrun/common/schemas/function.py::FunctionState for reasoning body = Column(BLOB) labels = relationship(Label) def get_identifier_string(self) -> str: return f"{self.project}/{self.key}/{self.uid}" - class Function(Base, HasStruct): + class Function(Base, mlrun.utils.db.HasStruct): __tablename__ = "functions" __table_args__ = ( UniqueConstraint("name", "project", "uid", name="_functions_uc"), @@ -164,7 +171,7 @@ class Function(Base, 
HasStruct): name = Column(String(255, collation=SQLCollationUtil.collation())) project = Column(String(255, collation=SQLCollationUtil.collation())) uid = Column(String(255, collation=SQLCollationUtil.collation())) - # TODO: change to JSON, see mlrun/api/schemas/function.py::FunctionState for reasoning + # TODO: change to JSON, see mlrun/common/schemas/function.py::FunctionState for reasoning body = Column(BLOB) updated = Column(TIMESTAMP) labels = relationship(Label) @@ -172,53 +179,19 @@ class Function(Base, HasStruct): def get_identifier_string(self) -> str: return f"{self.project}/{self.name}/{self.uid}" - class Log(Base, BaseModel): + class Log(Base, mlrun.utils.db.BaseModel): __tablename__ = "logs" id = Column(Integer, primary_key=True) uid = Column(String(255, collation=SQLCollationUtil.collation())) project = Column(String(255, collation=SQLCollationUtil.collation())) - # TODO: change to JSON, see mlrun/api/schemas/function.py::FunctionState for reasoning + # TODO: change to JSON, see mlrun/common/schemas/function.py::FunctionState for reasoning body = Column(BLOB) def get_identifier_string(self) -> str: return f"{self.project}/{self.uid}" - class Notification(Base, BaseModel): - __tablename__ = "notifications" - __table_args__ = (UniqueConstraint("name", "run", name="_notifications_uc"),) - - id = Column(Integer, primary_key=True) - project = Column(String(255, collation=SQLCollationUtil.collation())) - name = Column( - String(255, collation=SQLCollationUtil.collation()), nullable=False - ) - kind = Column( - String(255, collation=SQLCollationUtil.collation()), nullable=False - ) - message = Column( - String(255, collation=SQLCollationUtil.collation()), nullable=False - ) - severity = Column( - String(255, collation=SQLCollationUtil.collation()), nullable=False - ) - when = Column( - String(255, collation=SQLCollationUtil.collation()), nullable=False - ) - condition = Column( - String(255, collation=SQLCollationUtil.collation()), nullable=False - ) - 
params = Column("params", JSON) - run = Column(Integer, ForeignKey("runs.id")) - sent_time = Column( - TIMESTAMP(), - nullable=True, - ) - status = Column( - String(255, collation=SQLCollationUtil.collation()), nullable=False - ) - - class Run(Base, HasStruct): + class Run(Base, mlrun.utils.db.HasStruct): __tablename__ = "runs" __table_args__ = ( UniqueConstraint("uid", "project", "iteration", name="_runs_uc"), @@ -226,6 +199,7 @@ class Run(Base, HasStruct): Label = make_label(__tablename__) Tag = make_tag(__tablename__) + Notification = make_notification(__tablename__) id = Column(Integer, primary_key=True) uid = Column(String(255, collation=SQLCollationUtil.collation())) @@ -235,7 +209,7 @@ class Run(Base, HasStruct): ) iteration = Column(Integer) state = Column(String(255, collation=SQLCollationUtil.collation())) - # TODO: change to JSON, see mlrun/api/schemas/function.py::FunctionState for reasoning + # TODO: change to JSON, see mlrun/common/schemas/function.py::FunctionState for reasoning body = Column(BLOB) start_time = Column(TIMESTAMP) # requested logs column indicates whether logs were requested for this run @@ -250,7 +224,7 @@ class Run(Base, HasStruct): def get_identifier_string(self) -> str: return f"{self.project}/{self.uid}/{self.iteration}" - class BackgroundTask(Base, BaseModel): + class BackgroundTask(Base, mlrun.utils.db.BaseModel): __tablename__ = "background_tasks" __table_args__ = ( UniqueConstraint("name", "project", name="_background_tasks_uc"), @@ -268,7 +242,7 @@ class BackgroundTask(Base, BaseModel): state = Column(String(255, collation=SQLCollationUtil.collation())) timeout = Column(Integer) - class Schedule(Base, BaseModel): + class Schedule(Base, mlrun.utils.db.BaseModel): __tablename__ = "schedules_v2" __table_args__ = (UniqueConstraint("project", "name", name="_schedules_v2_uc"),) @@ -287,7 +261,7 @@ class Schedule(Base, BaseModel): creation_time = Column(TIMESTAMP) cron_trigger_str = Column(String(255, 
collation=SQLCollationUtil.collation())) last_run_uri = Column(String(255, collation=SQLCollationUtil.collation())) - # TODO: change to JSON, see mlrun/api/schemas/function.py::FunctionState for reasoning + # TODO: change to JSON, see mlrun/common/schemas/function.py::FunctionState for reasoning struct = Column(BLOB) labels = relationship(Label, cascade="all, delete-orphan") concurrency_limit = Column(Integer, nullable=False) @@ -305,11 +279,11 @@ def scheduled_object(self, value): self.struct = pickle.dumps(value) @property - def cron_trigger(self) -> schemas.ScheduleCronTrigger: + def cron_trigger(self) -> mlrun.common.schemas.ScheduleCronTrigger: return orjson.loads(self.cron_trigger_str) @cron_trigger.setter - def cron_trigger(self, trigger: schemas.ScheduleCronTrigger): + def cron_trigger(self, trigger: mlrun.common.schemas.ScheduleCronTrigger): self.cron_trigger_str = orjson.dumps(trigger.dict(exclude_unset=True)) # Define "many to many" users/projects @@ -320,14 +294,14 @@ def cron_trigger(self, trigger: schemas.ScheduleCronTrigger): Column("user_id", Integer, ForeignKey("users.id")), ) - class User(Base, BaseModel): + class User(Base, mlrun.utils.db.BaseModel): __tablename__ = "users" __table_args__ = (UniqueConstraint("name", name="_users_uc"),) id = Column(Integer, primary_key=True) name = Column(String(255, collation=SQLCollationUtil.collation())) - class Project(Base, BaseModel): + class Project(Base, mlrun.utils.db.BaseModel): __tablename__ = "projects" # For now since we use project name a lot __table_args__ = (UniqueConstraint("name", name="_projects_uc"),) @@ -339,7 +313,7 @@ class Project(Base, BaseModel): source = Column(String(255, collation=SQLCollationUtil.collation())) # the attribute name used to be _spec which is just a wrong naming, the attribute was renamed to _full_object # leaving the column as is to prevent redundant migration - # TODO: change to JSON, see mlrun/api/schemas/function.py::FunctionState for reasoning + # TODO: change to 
JSON, see mlrun/common/schemas/function.py::FunctionState for reasoning _full_object = Column("spec", BLOB) created = Column(TIMESTAMP, default=datetime.utcnow) state = Column(String(255, collation=SQLCollationUtil.collation())) @@ -361,7 +335,7 @@ def full_object(self): def full_object(self, value): self._full_object = pickle.dumps(value) - class Feature(Base, BaseModel): + class Feature(Base, mlrun.utils.db.BaseModel): __tablename__ = "features" id = Column(Integer, primary_key=True) feature_set_id = Column(Integer, ForeignKey("feature_sets.id")) @@ -375,7 +349,7 @@ class Feature(Base, BaseModel): def get_identifier_string(self) -> str: return f"{self.project}/{self.name}" - class Entity(Base, BaseModel): + class Entity(Base, mlrun.utils.db.BaseModel): __tablename__ = "entities" id = Column(Integer, primary_key=True) feature_set_id = Column(Integer, ForeignKey("feature_sets.id")) @@ -389,7 +363,7 @@ class Entity(Base, BaseModel): def get_identifier_string(self) -> str: return f"{self.project}/{self.name}" - class FeatureSet(Base, BaseModel): + class FeatureSet(Base, mlrun.utils.db.BaseModel): __tablename__ = "feature_sets" __table_args__ = ( UniqueConstraint("name", "project", "uid", name="_feature_set_uc"), @@ -425,7 +399,7 @@ def full_object(self): def full_object(self, value): self._full_object = json.dumps(value, default=str) - class FeatureVector(Base, BaseModel): + class FeatureVector(Base, mlrun.utils.db.BaseModel): __tablename__ = "feature_vectors" __table_args__ = ( UniqueConstraint("name", "project", "uid", name="_feature_vectors_uc"), @@ -458,9 +432,9 @@ def full_object(self): def full_object(self, value): self._full_object = json.dumps(value, default=str) - class MarketplaceSource(Base, BaseModel): - __tablename__ = "marketplace_sources" - __table_args__ = (UniqueConstraint("name", name="_marketplace_sources_uc"),) + class HubSource(Base, mlrun.utils.db.BaseModel): + __tablename__ = "hub_sources" + __table_args__ = (UniqueConstraint("name", 
name="_hub_sources_uc"),) id = Column(Integer, primary_key=True) name = Column(String(255, collation=SQLCollationUtil.collation())) @@ -482,7 +456,7 @@ def full_object(self): def full_object(self, value): self._full_object = json.dumps(value, default=str) - class DataVersion(Base, BaseModel): + class DataVersion(Base, mlrun.utils.db.BaseModel): __tablename__ = "data_versions" __table_args__ = (UniqueConstraint("version", name="_versions_uc"),) @@ -494,5 +468,8 @@ class DataVersion(Base, BaseModel): # Must be after all table definitions _tagged = [cls for cls in Base.__subclasses__() if hasattr(cls, "Tag")] _labeled = [cls for cls in Base.__subclasses__() if hasattr(cls, "Label")] +_with_notifications = [ + cls for cls in Base.__subclasses__() if hasattr(cls, "Notification") +] _classes = [cls for cls in Base.__subclasses__()] _table2cls = {cls.__table__.name: cls for cls in Base.__subclasses__()} diff --git a/mlrun/api/db/sqldb/session.py b/mlrun/api/db/sqldb/session.py index 197d2af30429..34b3475e3e61 100644 --- a/mlrun/api/db/sqldb/session.py +++ b/mlrun/api/db/sqldb/session.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# + +import typing + from sqlalchemy import create_engine from sqlalchemy.engine import Engine from sqlalchemy.orm import Session @@ -19,35 +22,38 @@ from mlrun.config import config -engine: Engine = None -_session_maker: SessionMaker = None +# TODO: wrap the following functions in a singleton class +_engines: typing.Dict[str, Engine] = {} +_session_makers: typing.Dict[str, SessionMaker] = {} # doing lazy load to allow tests to initialize the engine -def get_engine() -> Engine: - global engine - if engine is None: - _init_engine() - return engine +def get_engine(dsn=None) -> Engine: + global _engines + dsn = dsn or config.httpdb.dsn + if dsn not in _engines: + _init_engine(dsn=dsn) + return _engines[dsn] -def create_session() -> Session: - session_maker = _get_session_maker() +def create_session(dsn=None) -> Session: + session_maker = _get_session_maker(dsn=dsn) return session_maker() # doing lazy load to allow tests to initialize the engine -def _get_session_maker() -> SessionMaker: - global _session_maker - if _session_maker is None: - _init_session_maker() - return _session_maker +def _get_session_maker(dsn) -> SessionMaker: + global _session_makers + dsn = dsn or config.httpdb.dsn + if dsn not in _session_makers: + _init_session_maker(dsn=dsn) + return _session_makers[dsn] # TODO: we accept the dsn here to enable tests to override it, the "right" thing will be that config will be easily # overridable by tests (today when you import the config it is already being initialized.. 
should be lazy load) def _init_engine(dsn=None): - global engine + global _engines dsn = dsn or config.httpdb.dsn kwargs = {} if "mysql" in dsn: @@ -62,9 +68,10 @@ def _init_engine(dsn=None): "max_overflow": max_overflow, } engine = create_engine(dsn, **kwargs) - _init_session_maker() + _engines[dsn] = engine + _init_session_maker(dsn=dsn) -def _init_session_maker(): - global _session_maker - _session_maker = SessionMaker(bind=get_engine()) +def _init_session_maker(dsn): + global _session_makers + _session_makers[dsn] = SessionMaker(bind=get_engine(dsn=dsn)) diff --git a/mlrun/api/initial_data.py b/mlrun/api/initial_data.py index 1282e6efc434..cc02b4d77c63 100644 --- a/mlrun/api/initial_data.py +++ b/mlrun/api/initial_data.py @@ -19,6 +19,7 @@ import typing import dateutil.parser +import pydantic.error_wrappers import pymysql.err import sqlalchemy.exc import sqlalchemy.orm @@ -26,12 +27,12 @@ import mlrun.api.db.sqldb.db import mlrun.api.db.sqldb.helpers import mlrun.api.db.sqldb.models -import mlrun.api.schemas import mlrun.api.utils.db.alembic import mlrun.api.utils.db.backup import mlrun.api.utils.db.mysql import mlrun.api.utils.db.sqlite_migration import mlrun.artifacts +import mlrun.common.schemas from mlrun.api.db.init_db import init_db from mlrun.api.db.session import close_session, create_session from mlrun.config import config @@ -43,7 +44,20 @@ def init_data( from_scratch: bool = False, perform_migrations_if_needed: bool = False ) -> None: logger.info("Initializing DB data") - mlrun.api.utils.db.mysql.MySQLUtil.wait_for_db_liveness(logger) + + # create mysql util, and if mlrun is configured to use mysql, wait for it to be live and set its db modes + mysql_util = mlrun.api.utils.db.mysql.MySQLUtil(logger) + if mysql_util.get_mysql_dsn_data(): + mysql_util.wait_for_db_liveness() + mysql_util.set_modes(mlrun.mlconf.httpdb.db.mysql.modes) + else: + dsn = mysql_util.get_dsn() + if "sqlite" in dsn: + logger.debug("SQLite DB is used, liveness check not needed") 
+ else: + logger.warn( + f"Invalid mysql dsn: {dsn}, assuming live and skipping liveness verification" + ) sqlite_migration_util = None if not from_scratch and config.httpdb.db.database_migration_mode == "enabled": @@ -62,7 +76,7 @@ def init_data( and not perform_migrations_if_needed and is_migration_needed ): - state = mlrun.api.schemas.APIStates.waiting_for_migrations + state = mlrun.common.schemas.APIStates.waiting_for_migrations logger.info("Migration is needed, changing API state", state=state) config.httpdb.state = state return @@ -73,7 +87,7 @@ def init_data( db_backup.backup_database() logger.info("Creating initial data") - config.httpdb.state = mlrun.api.schemas.APIStates.migrations_in_progress + config.httpdb.state = mlrun.common.schemas.APIStates.migrations_in_progress if is_migration_from_scratch or is_migration_needed: try: @@ -81,15 +95,15 @@ def init_data( _perform_database_migration(sqlite_migration_util) + init_db() db_session = create_session() try: - init_db(db_session) _add_initial_data(db_session) _perform_data_migrations(db_session) finally: close_session(db_session) except Exception: - state = mlrun.api.schemas.APIStates.migrations_failed + state = mlrun.common.schemas.APIStates.migrations_failed logger.warning("Migrations failed, changing API state", state=state) config.httpdb.state = state raise @@ -97,17 +111,19 @@ def init_data( # should happen - we can't do it here because it requires an asyncio loop which can't be accessible here # therefore moving to migration_completed state, and other component will take care of moving to online if not is_migration_from_scratch and is_migration_needed: - config.httpdb.state = mlrun.api.schemas.APIStates.migrations_completed + config.httpdb.state = mlrun.common.schemas.APIStates.migrations_completed else: - config.httpdb.state = mlrun.api.schemas.APIStates.online + config.httpdb.state = mlrun.common.schemas.APIStates.online logger.info("Initial data created") # If the data_table version doesn't exist, 
we can assume the data version is 1. -# This is because data version 1 points to to a data migration which was added back in 0.6.0, and +# This is because data version 1 points to a data migration which was added back in 0.6.0, and # upgrading from a version earlier than 0.6.0 to v>=0.8.0 is not supported. data_version_prior_to_table_addition = 1 -latest_data_version = 2 + +# NOTE: Bump this number when adding a new data migration +latest_data_version = 3 def _resolve_needed_operations( @@ -212,13 +228,16 @@ def _perform_data_migrations(db_session: sqlalchemy.orm.Session): _perform_version_1_data_migrations(db, db_session) if current_data_version < 2: _perform_version_2_data_migrations(db, db_session) + if current_data_version < 3: + _perform_version_3_data_migrations(db, db_session) + db.create_data_version(db_session, str(latest_data_version)) def _add_initial_data(db_session: sqlalchemy.orm.Session): # FileDB is not really a thing anymore, so using SQLDB directly db = mlrun.api.db.sqldb.db.SQLDB("") - _add_default_marketplace_source_if_needed(db, db_session) + _add_default_hub_source_if_needed(db, db_session) _add_data_version(db, db_session) @@ -465,6 +484,30 @@ def _align_runs_table( db._upsert(db_session, [run], ignore=True) +def _perform_version_3_data_migrations( + db: mlrun.api.db.sqldb.db.SQLDB, db_session: sqlalchemy.orm.Session +): + _rename_marketplace_kind_to_hub(db, db_session) + + +def _rename_marketplace_kind_to_hub( + db: mlrun.api.db.sqldb.db.SQLDB, db_session: sqlalchemy.orm.Session +): + logger.info("Renaming 'Marketplace' kinds to 'Hub'") + + hubs = db._list_hub_sources_without_transform(db_session) + for hub in hubs: + hub_dict = hub.full_object + + # rename kind from "MarketplaceSource" to "HubSource" + if "Marketplace" in hub_dict.get("kind", ""): + hub_dict["kind"] = hub_dict["kind"].replace("Marketplace", "Hub") + + # save the object back to the db + hub.full_object = hub_dict + db._upsert(db_session, [hub], ignore=True) + + def 
_perform_version_1_data_migrations( db: mlrun.api.db.sqldb.db.SQLDB, db_session: sqlalchemy.orm.Session ): @@ -482,7 +525,7 @@ def _enrich_project_state( changed = False if not project.spec.desired_state: changed = True - project.spec.desired_state = mlrun.api.schemas.ProjectState.online + project.spec.desired_state = mlrun.common.schemas.ProjectState.online if not project.status.state: changed = True project.status.state = project.spec.desired_state @@ -494,32 +537,47 @@ def _enrich_project_state( db.store_project(db_session, project.metadata.name, project) -def _add_default_marketplace_source_if_needed( +def _add_default_hub_source_if_needed( db: mlrun.api.db.sqldb.db.SQLDB, db_session: sqlalchemy.orm.Session ): try: - hub_marketplace_source = db.get_marketplace_source( - db_session, config.marketplace.default_source.name + hub_marketplace_source = db.get_hub_source( + db_session, config.hub.default_source.name ) except mlrun.errors.MLRunNotFoundError: hub_marketplace_source = None + except pydantic.error_wrappers.ValidationError as exc: + + # following the renaming of 'marketplace' to 'hub', validation errors can occur on the old 'marketplace'. 
+ # this will be handled later in the data migrations, but for now - if a validation error occurs, we assume + # that a default hub source exists + if all( + [ + "validation error for HubSource" in str(exc), + "value is not a valid enumeration member" in str(exc), + ] + ): + logger.info("Found existing default hub source, data migration needed") + hub_marketplace_source = True + else: + raise exc if not hub_marketplace_source: - hub_source = mlrun.api.schemas.MarketplaceSource.generate_default_source() - # hub_source will be None if the configuration has marketplace.default_source.create=False + hub_source = mlrun.common.schemas.HubSource.generate_default_source() + # hub_source will be None if the configuration has hub.default_source.create=False if hub_source: - logger.info("Adding default marketplace source") - # Not using db.store_marketplace_source() since it doesn't allow changing the default marketplace source. - hub_record = db._transform_marketplace_source_schema_to_record( - mlrun.api.schemas.IndexedMarketplaceSource( - index=mlrun.api.schemas.marketplace.last_source_index, + logger.info("Adding default hub source") + # Not using db.store_marketplace_source() since it doesn't allow changing the default hub source. + hub_record = db._transform_hub_source_schema_to_record( + mlrun.common.schemas.IndexedHubSource( + index=mlrun.common.schemas.hub.last_source_index, source=hub_source, ) ) db_session.add(hub_record) db_session.commit() else: - logger.info("Not adding default marketplace source, per configuration") + logger.info("Not adding default hub source, per configuration") return diff --git a/mlrun/api/launcher.py b/mlrun/api/launcher.py new file mode 100644 index 000000000000..cd316a0abbd1 --- /dev/null +++ b/mlrun/api/launcher.py @@ -0,0 +1,196 @@ +# Copyright 2023 MLRun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Dict, List, Optional, Union + +import mlrun.api.crud +import mlrun.api.db.sqldb.session +import mlrun.common.schemas.schedule +import mlrun.config +import mlrun.execution +import mlrun.launcher.base +import mlrun.runtimes +import mlrun.runtimes.generators +import mlrun.runtimes.utils +import mlrun.utils +import mlrun.utils.regex + + +class ServerSideLauncher(mlrun.launcher.base.BaseLauncher): + def launch( + self, + runtime: mlrun.runtimes.BaseRuntime, + task: Optional[ + Union["mlrun.run.RunTemplate", "mlrun.run.RunObject", dict] + ] = None, + handler: Optional[str] = None, + name: Optional[str] = "", + project: Optional[str] = "", + params: Optional[dict] = None, + inputs: Optional[Dict[str, str]] = None, + out_path: Optional[str] = "", + workdir: Optional[str] = "", + artifact_path: Optional[str] = "", + watch: Optional[bool] = True, + schedule: Optional[ + Union[str, mlrun.common.schemas.schedule.ScheduleCronTrigger] + ] = None, + hyperparams: Dict[str, list] = None, + hyper_param_options: Optional[mlrun.model.HyperParamOptions] = None, + verbose: Optional[bool] = None, + scrape_metrics: Optional[bool] = None, + local_code_path: Optional[str] = None, + auto_build: Optional[bool] = None, + param_file_secrets: Optional[Dict[str, str]] = None, + notifications: Optional[List[mlrun.model.Notification]] = None, + returns: Optional[List[Union[str, Dict[str, str]]]] = None, + ) -> mlrun.run.RunObject: + self.enrich_runtime(runtime, project) + + run = self._create_run_object(task) + + run = self._enrich_run( + runtime, + run=run, + 
handler=handler, + project_name=project, + name=name, + params=params, + inputs=inputs, + returns=returns, + hyperparams=hyperparams, + hyper_param_options=hyper_param_options, + verbose=verbose, + scrape_metrics=scrape_metrics, + out_path=out_path, + artifact_path=artifact_path, + workdir=workdir, + notifications=notifications, + ) + self._validate_runtime(runtime, run) + + if runtime.verbose: + mlrun.utils.logger.info(f"Run:\n{run.to_yaml()}") + + if not runtime.is_child: + mlrun.utils.logger.info( + "Storing function", + name=run.metadata.name, + uid=run.metadata.uid, + ) + self._store_function(runtime, run) + + execution = mlrun.execution.MLClientCtx.from_dict( + run.to_dict(), + runtime._get_db(), + autocommit=False, + is_api=True, + store_run=False, + ) + + # create task generator (for child runs) from spec + task_generator = mlrun.runtimes.generators.get_generator( + run.spec, execution, param_file_secrets=param_file_secrets + ) + if task_generator: + # verify valid task parameters + tasks = task_generator.generate(run) + for task in tasks: + self._validate_run_params(task.spec.parameters) + + # post verifications, store execution in db and run pre run hooks + execution.store_run() + runtime._pre_run(run, execution) # hook for runtime specific prep + + resp = None + last_err = None + # If the runtime is nested, it means the hyper-run will run within a single instance of the run. 
+ # So while in the API, we consider the hyper-run as a single run, and then in the runtime itself when the + # runtime is now a local runtime and therefore `self._is_nested == False`, we run each task as a separate run by + # using the task generator + if task_generator and not runtime._is_nested: + # multiple runs (based on hyper params or params file) + runner = runtime._run_many + if hasattr(runtime, "_parallel_run_many") and task_generator.use_parallel(): + runner = runtime._parallel_run_many + results = runner(task_generator, execution, run) + mlrun.runtimes.utils.results_to_iter(results, run, execution) + result = execution.to_dict() + result = runtime._update_run_state(result, task=run) + + else: + # single run + try: + resp = runtime._run(run, execution) + + except mlrun.runtimes.utils.RunError as err: + last_err = err + + finally: + result = runtime._update_run_state(resp=resp, task=run, err=last_err) + + self._save_notifications(run) + + runtime._post_run(result, execution) # hook for runtime specific cleanup + + return self._wrap_run_result(runtime, result, run, err=last_err) + + @staticmethod + def enrich_runtime( + runtime: "mlrun.runtimes.base.BaseRuntime", project_name: Optional[str] = "" + ): + """ + Enrich the runtime object with the project spec and metadata. + This is done only on the server side, since it's the source of truth for the project, and we want to keep the + client side enrichment as minimal as possible. 
+ """ + # ensure the runtime has a project before we enrich it with the project's spec + runtime.metadata.project = ( + project_name + or runtime.metadata.project + or mlrun.config.config.default_project + ) + project = runtime._get_db().get_project(runtime.metadata.project) + # this is mainly for tests with nop db + # in normal use cases if no project is found we will get an error + if project: + project = mlrun.projects.project.MlrunProject.from_dict(project.dict()) + mlrun.projects.pipelines.enrich_function_object( + project, runtime, copy_function=False + ) + + def _save_notifications(self, runobj): + if not self._run_has_valid_notifications(runobj): + return + + # If in the api server, we can assume that watch=False, so we save notification + # configs to the DB, for the run monitor to later pick up and push. + session = mlrun.api.db.sqldb.session.create_session() + mlrun.api.crud.Notifications().store_run_notifications( + session, + runobj.spec.notifications, + runobj.metadata.uid, + runobj.metadata.project, + ) + + def _store_function( + self, runtime: mlrun.runtimes.base.BaseRuntime, run: mlrun.run.RunObject + ): + run.metadata.labels["kind"] = runtime.kind + db = runtime._get_db() + if db and runtime.kind != "handler": + struct = runtime.to_dict() + hash_key = db.store_function( + struct, runtime.metadata.name, runtime.metadata.project, versioned=True + ) + run.spec.function = runtime._function_uri(hash_key=hash_key) diff --git a/mlrun/api/main.py b/mlrun/api/main.py index d9332cd2e492..75840a51777c 100644 --- a/mlrun/api/main.py +++ b/mlrun/api/main.py @@ -25,9 +25,9 @@ from fastapi.exception_handlers import http_exception_handler import mlrun.api.db.base -import mlrun.api.schemas import mlrun.api.utils.clients.chief import mlrun.api.utils.clients.log_collector +import mlrun.common.schemas import mlrun.errors import mlrun.lists import mlrun.utils @@ -43,6 +43,7 @@ run_function_periodically, ) from mlrun.api.utils.singletons.db import get_db, initialize_db 
+from mlrun.api.utils.singletons.k8s import get_k8s_helper from mlrun.api.utils.singletons.logs_dir import initialize_logs_dir from mlrun.api.utils.singletons.project_member import ( get_project_member, @@ -51,7 +52,6 @@ from mlrun.api.utils.singletons.scheduler import get_scheduler, initialize_scheduler from mlrun.config import config from mlrun.errors import err_to_str -from mlrun.k8s_utils import get_k8s_helper from mlrun.runtimes import RuntimeClassMode, RuntimeKinds, get_runtime_handler from mlrun.utils import logger @@ -80,8 +80,9 @@ ) app.include_router(api_router, prefix=BASE_VERSIONED_API_PREFIX) # This is for backward compatibility, that is why we still leave it here but not include it in the schema -# so new users won't use the old un-versioned api -# TODO: remove in 1.4.0 +# so new users won't use the old un-versioned api. +# /api points to /api/v1 since it is used externally, and we don't want to break it. +# TODO: make sure UI and all relevant Iguazio versions uses /api/v1 and deprecate this app.include_router(api_router, prefix=API_PREFIX, include_in_schema=False) init_middlewares(app) @@ -103,6 +104,11 @@ async def generic_error_handler(request: fastapi.Request, exc: Exception): async def http_status_error_handler( request: fastapi.Request, exc: mlrun.errors.MLRunHTTPStatusError ): + request_id = None + + # request might not have request id when the error is raised before the request id is set on middleware + if hasattr(request.state, "request_id"): + request_id = request.state.request_id status_code = exc.response.status_code error_message = repr(exc) logger.warning( @@ -110,6 +116,7 @@ async def http_status_error_handler( error_message=error_message, status_code=status_code, traceback=traceback.format_exc(), + request_id=request_id, ) return await http_exception_handler( request, @@ -120,8 +127,8 @@ async def http_status_error_handler( @app.on_event("startup") async def startup_event(): logger.info( - "configuration dump", - 
dumped_config=config.dump_yaml(), + "On startup event handler called", + config=config.dump_yaml(), version=mlrun.utils.version.Version().get(), ) loop = asyncio.get_running_loop() @@ -136,13 +143,13 @@ async def startup_event(): if ( config.httpdb.clusterization.worker.sync_with_chief.mode - == mlrun.api.schemas.WaitForChiefToReachOnlineStateFeatureFlag.enabled + == mlrun.common.schemas.WaitForChiefToReachOnlineStateFeatureFlag.enabled and config.httpdb.clusterization.role - == mlrun.api.schemas.ClusterizationRole.worker + == mlrun.common.schemas.ClusterizationRole.worker ): _start_chief_clusterization_spec_sync_loop() - if config.httpdb.state == mlrun.api.schemas.APIStates.online: + if config.httpdb.state == mlrun.common.schemas.APIStates.online: await move_api_to_online() @@ -165,7 +172,10 @@ async def move_api_to_online(): initialize_project_member() # maintenance periodic functions should only run on the chief instance - if config.httpdb.clusterization.role == mlrun.api.schemas.ClusterizationRole.chief: + if ( + config.httpdb.clusterization.role + == mlrun.common.schemas.ClusterizationRole.chief + ): # runs cleanup/monitoring is not needed if we're not inside kubernetes cluster if get_k8s_helper(silent=True).is_running_inside_kubernetes_cluster(): _start_periodic_cleanup() @@ -175,7 +185,7 @@ async def move_api_to_online(): async def _start_logs_collection(): - if config.log_collector.mode == mlrun.api.schemas.LogsCollectorMode.legacy: + if config.log_collector.mode == mlrun.common.schemas.LogsCollectorMode.legacy: logger.info( "Using legacy logs collection method, skipping logs collection periodic function", mode=config.log_collector.mode, @@ -268,12 +278,17 @@ async def _initiate_logs_collection(start_logs_limit: asyncio.Semaphore): """ db_session = await fastapi.concurrency.run_in_threadpool(create_session) try: + # we don't want initiate logs collection for aborted runs + run_states = mlrun.runtimes.constants.RunStates.all() + 
run_states.remove(mlrun.runtimes.constants.RunStates.aborted) + # list all the runs in the system which we didn't request logs collection for yet runs = await fastapi.concurrency.run_in_threadpool( get_db().list_distinct_runs_uids, db_session, requested_logs_modes=[False], only_uids=False, + states=run_states, ) if runs: logger.debug( @@ -410,7 +425,7 @@ def _start_periodic_runs_monitoring(): async def _start_periodic_stop_logs(): - if config.log_collector.mode == mlrun.api.schemas.LogsCollectorMode.legacy: + if config.log_collector.mode == mlrun.common.schemas.LogsCollectorMode.legacy: logger.info( "Using legacy logs collection method, skipping stop logs periodic function", mode=config.log_collector.mode, @@ -441,7 +456,12 @@ async def _verify_log_collection_stopped_on_startup(): db_session, requested_logs_modes=[True], only_uids=False, - states=mlrun.runtimes.constants.RunStates.terminal_states(), + states=mlrun.runtimes.constants.RunStates.terminal_states() + + [ + # add unknown state as well, as it's possible that the run reached such state + # usually it happens when run pods get preempted + mlrun.runtimes.constants.RunStates.unknown, + ], ) if len(runs) > 0: @@ -469,7 +489,7 @@ def _start_chief_clusterization_spec_sync_loop(): async def _synchronize_with_chief_clusterization_spec(): # sanity # if we are still in the periodic function and the worker has reached the terminal state, then cancel it - if config.httpdb.state in mlrun.api.schemas.APIStates.terminal_states(): + if config.httpdb.state in mlrun.common.schemas.APIStates.terminal_states(): cancel_periodic_function(_synchronize_with_chief_clusterization_spec.__name__) try: @@ -488,14 +508,14 @@ async def _synchronize_with_chief_clusterization_spec(): async def _align_worker_state_with_chief_state( - clusterization_spec: mlrun.api.schemas.ClusterizationSpec, + clusterization_spec: mlrun.common.schemas.ClusterizationSpec, ): chief_state = clusterization_spec.chief_api_state if not chief_state: 
logger.warning("Chief did not return any state") return - if chief_state not in mlrun.api.schemas.APIStates.terminal_states(): + if chief_state not in mlrun.common.schemas.APIStates.terminal_states(): logger.debug( "Chief did not reach online state yet, will retry after sync interval", interval=config.httpdb.clusterization.worker.sync_with_chief.interval, @@ -505,7 +525,7 @@ async def _align_worker_state_with_chief_state( config.httpdb.state = chief_state return - if chief_state == mlrun.api.schemas.APIStates.online: + if chief_state == mlrun.common.schemas.APIStates.online: logger.info("Chief reached online state! Switching worker state to online") await move_api_to_online() logger.info("Worker state reached online") @@ -562,7 +582,6 @@ def _push_terminal_run_notifications(db: mlrun.api.db.base.DBInterface, db_sessi Get all runs with notification configs which became terminal since the last call to the function and push their notifications if they haven't been pushed yet. """ - # Import here to avoid circular import import mlrun.api.api.utils @@ -572,6 +591,8 @@ def _push_terminal_run_notifications(db: mlrun.api.db.base.DBInterface, db_sessi # and their notifications haven't been sent yet. 
global _last_notification_push_time + now = datetime.datetime.now(datetime.timezone.utc) + runs = db.list_runs( db_session, project="*", @@ -580,6 +601,9 @@ def _push_terminal_run_notifications(db: mlrun.api.db.base.DBInterface, db_sessi with_notifications=True, ) + if not len(runs): + return + # Unmasking the run parameters from secrets before handing them over to the notification handler # as importing the `Secrets` crud in the notification handler will cause a circular import unmasked_runs = [ @@ -592,7 +616,7 @@ def _push_terminal_run_notifications(db: mlrun.api.db.base.DBInterface, db_sessi ) mlrun.utils.notifications.NotificationPusher(unmasked_runs).push(db) - _last_notification_push_time = datetime.datetime.now(datetime.timezone.utc) + _last_notification_push_time = now async def _stop_logs(): @@ -647,16 +671,19 @@ async def _stop_logs_for_runs(runs: list): def main(): - if config.httpdb.clusterization.role == mlrun.api.schemas.ClusterizationRole.chief: + if ( + config.httpdb.clusterization.role + == mlrun.common.schemas.ClusterizationRole.chief + ): init_data() elif ( config.httpdb.clusterization.worker.sync_with_chief.mode - == mlrun.api.schemas.WaitForChiefToReachOnlineStateFeatureFlag.enabled + == mlrun.common.schemas.WaitForChiefToReachOnlineStateFeatureFlag.enabled and config.httpdb.clusterization.role - == mlrun.api.schemas.ClusterizationRole.worker + == mlrun.common.schemas.ClusterizationRole.worker ): # we set this state to mark the phase between the startup of the instance until we able to pull the chief state - config.httpdb.state = mlrun.api.schemas.APIStates.waiting_for_chief + config.httpdb.state = mlrun.common.schemas.APIStates.waiting_for_chief logger.info( "Starting API server", diff --git a/mlrun/api/middlewares.py b/mlrun/api/middlewares.py index 2f6a49a1b3db..cf842cca40cc 100644 --- a/mlrun/api/middlewares.py +++ b/mlrun/api/middlewares.py @@ -21,7 +21,7 @@ import uvicorn.protocols.utils from starlette.middleware.base import 
BaseHTTPMiddleware -import mlrun.api.schemas.constants +import mlrun.common.schemas.constants from mlrun.config import config from mlrun.utils import logger @@ -46,6 +46,7 @@ async def log_request_response(request: fastapi.Request, call_next): path_with_query_string = uvicorn.protocols.utils.get_path_with_query_string( request.scope ) + request.state.request_id = request_id start_time = time.perf_counter_ns() if not any( silent_logging_path in path_with_query_string @@ -100,7 +101,7 @@ async def ui_clear_cache(request: fastapi.Request, call_next): This middleware tells ui when to clear its cache based on backend version changes. """ ui_version = request.headers.get( - mlrun.api.schemas.constants.HeaderNames.ui_version, "" + mlrun.common.schemas.constants.HeaderNames.ui_version, "" ) response: fastapi.Response = await call_next(request) development_version = config.version.startswith("0.0.0") @@ -117,7 +118,7 @@ async def ui_clear_cache(request: fastapi.Request, call_next): # tell ui to reload response.headers[ - mlrun.api.schemas.constants.HeaderNames.ui_clear_cache + mlrun.common.schemas.constants.HeaderNames.ui_clear_cache ] = "true" return response @@ -128,7 +129,7 @@ async def ensure_be_version(request: fastapi.Request, call_next): """ response: fastapi.Response = await call_next(request) response.headers[ - mlrun.api.schemas.constants.HeaderNames.backend_version + mlrun.common.schemas.constants.HeaderNames.backend_version ] = config.version return response diff --git a/mlrun/api/migrations_mysql/versions/28383af526f3_market_place_to_hub.py b/mlrun/api/migrations_mysql/versions/28383af526f3_market_place_to_hub.py new file mode 100644 index 000000000000..8edf1db6f7aa --- /dev/null +++ b/mlrun/api/migrations_mysql/versions/28383af526f3_market_place_to_hub.py @@ -0,0 +1,40 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""market_place_to_hub + +Revision ID: 28383af526f3 +Revises: c905d15bd91d +Create Date: 2023-04-24 11:06:36.177314 + +""" +from alembic import op + +# revision identifiers, used by Alembic. +revision = "28383af526f3" +down_revision = "c905d15bd91d" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.rename_table("marketplace_sources", "hub_sources") + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.rename_table("hub_sources", "marketplace_sources") + # ### end Alembic commands ### diff --git a/mlrun/api/migrations_mysql/versions/c905d15bd91d_notifications.py b/mlrun/api/migrations_mysql/versions/c905d15bd91d_notifications.py index f4798756aaf9..4da397aabf65 100644 --- a/mlrun/api/migrations_mysql/versions/c905d15bd91d_notifications.py +++ b/mlrun/api/migrations_mysql/versions/c905d15bd91d_notifications.py @@ -33,7 +33,7 @@ def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### op.create_table( - "notifications", + "runs_notifications", sa.Column("id", sa.Integer(), nullable=False), sa.Column("project", sa.String(length=255, collation="utf8_bin")), sa.Column("name", sa.String(length=255, collation="utf8_bin"), nullable=False), @@ -49,17 +49,19 @@ def upgrade(): "condition", sa.String(length=255, collation="utf8_bin"), nullable=False ), sa.Column("params", sa.JSON(), nullable=True), - sa.Column("run", sa.Integer(), nullable=True), + # A generic parent_id rather than run_id since notification table is standard across objects, see the + # make_notification function for its definition and usage. + sa.Column("parent_id", sa.Integer(), nullable=True), sa.Column("sent_time", mysql.TIMESTAMP(fsp=3), nullable=True), sa.Column( "status", sa.String(length=255, collation="utf8_bin"), nullable=False ), sa.ForeignKeyConstraint( - ["run"], + ["parent_id"], ["runs.id"], ), sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint("name", "run", name="_notifications_uc"), + sa.UniqueConstraint("name", "parent_id", name="_runs_notifications_uc"), ) # ### end Alembic commands ### diff --git a/mlrun/api/migrations_sqlite/versions/4acd9430b093_market_place_to_hub.py b/mlrun/api/migrations_sqlite/versions/4acd9430b093_market_place_to_hub.py new file mode 100644 index 000000000000..553dbad93b9a --- /dev/null +++ b/mlrun/api/migrations_sqlite/versions/4acd9430b093_market_place_to_hub.py @@ -0,0 +1,77 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +"""market_place_to_hub + +Revision ID: 4acd9430b093 +Revises: 959ae00528ad +Create Date: 2023-04-26 22:41:59.726305 + +""" +import sqlalchemy as sa +from alembic import op + +from mlrun.api.utils.db.sql_collation import SQLCollationUtil + +# revision identifiers, used by Alembic. +revision = "4acd9430b093" +down_revision = "959ae00528ad" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + rename_hub_marketplace_table("marketplace_sources", "hub_sources") + # ### end Alembic commands ### + + +def downgrade(): + pass + # ### commands auto generated by Alembic - please adjust! ### + rename_hub_marketplace_table("hub_sources", "marketplace_sources") + # ### end Alembic commands ### + + +def rename_hub_marketplace_table(current_name, new_name): + op.create_table( + new_name, + sa.Column("id", sa.Integer(), nullable=False), + sa.Column( + "name", + sa.String(255, collation=SQLCollationUtil.collation()), + nullable=True, + ), + sa.Column("index", sa.Integer(), nullable=True), + sa.Column("created", sa.TIMESTAMP(), nullable=True), + sa.Column("updated", sa.TIMESTAMP(), nullable=True), + sa.Column("object", sa.JSON(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("name", name=f"_{new_name}_uc"), + ) + hub_sources = sa.sql.table( + new_name, + sa.Column("name", sa.String(255), nullable=True), + sa.Column("object", sa.JSON, nullable=True), + sa.Column("index", sa.Integer, nullable=True), + sa.Column("created", sa.TIMESTAMP, nullable=True), + sa.Column("updated", sa.TIMESTAMP, nullable=True), + ) + connection = op.get_bind() + select_previous_table_data_query = connection.execute( + f"SELECT * FROM {current_name}" + ) + previous_table_data = select_previous_table_data_query.fetchall() + op.bulk_insert(hub_sources, previous_table_data) + op.drop_table(current_name) diff 
--git a/mlrun/api/migrations_sqlite/versions/959ae00528ad_notifications.py b/mlrun/api/migrations_sqlite/versions/959ae00528ad_notifications.py index 9dc46d6a2480..a66f76b07fd9 100644 --- a/mlrun/api/migrations_sqlite/versions/959ae00528ad_notifications.py +++ b/mlrun/api/migrations_sqlite/versions/959ae00528ad_notifications.py @@ -32,7 +32,7 @@ def upgrade(): # ### commands auto generated by Alembic - please adjust! ### op.create_table( - "notifications", + "runs_notifications", sa.Column("id", sa.Integer(), nullable=False), sa.Column("project", sa.String(length=255)), sa.Column("name", sa.String(length=255), nullable=False), @@ -42,15 +42,17 @@ def upgrade(): sa.Column("when", sa.String(length=255), nullable=False), sa.Column("condition", sa.String(length=255), nullable=False), sa.Column("params", sa.JSON(), nullable=True), - sa.Column("run", sa.Integer(), nullable=True), + # A generic parent_id rather than run_id since notification table is standard across objects, see the + # make_notification function for its definition and usage. 
+ sa.Column("parent_id", sa.Integer(), nullable=True), sa.Column("sent_time", sa.TIMESTAMP(), nullable=True), sa.Column("status", sa.String(length=255), nullable=False), sa.ForeignKeyConstraint( - ["run"], + ["parent_id"], ["runs.id"], ), sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint("name", "run", name="_notifications_uc"), + sa.UniqueConstraint("name", "parent_id", name="_runs_notifications_uc"), ) # ### end Alembic commands ### diff --git a/mlrun/api/schemas/__init__.py b/mlrun/api/schemas/__init__.py index a2f91cd26645..13d30e387208 100644 --- a/mlrun/api/schemas/__init__.py +++ b/mlrun/api/schemas/__init__.py @@ -14,141 +14,218 @@ # # flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx -from .artifact import ArtifactCategories, ArtifactIdentifier, ArtifactsFormat -from .auth import ( - AuthInfo, - AuthorizationAction, - AuthorizationResourceTypes, - AuthorizationVerificationInput, - Credentials, - ProjectsRole, -) -from .background_task import ( - BackgroundTask, - BackgroundTaskMetadata, - BackgroundTaskSpec, - BackgroundTaskState, - BackgroundTaskStatus, -) -from .client_spec import ClientSpec -from .clusterization_spec import ( - ClusterizationSpec, - WaitForChiefToReachOnlineStateFeatureFlag, -) -from .constants import ( - APIStates, - ClusterizationRole, - DeletionStrategy, - FeatureStorePartitionByField, - HeaderNames, - LogsCollectorMode, - OrderType, - PatchMode, - RunPartitionByField, - SortField, -) -from .feature_store import ( - EntitiesOutput, - Entity, - EntityListOutput, - EntityRecord, - Feature, - FeatureListOutput, - FeatureRecord, - FeatureSet, - FeatureSetDigestOutput, - FeatureSetDigestSpec, - FeatureSetIngestInput, - FeatureSetIngestOutput, - FeatureSetRecord, - FeatureSetsOutput, - FeatureSetSpec, - FeatureSetsTagsOutput, - FeaturesOutput, - FeatureVector, - FeatureVectorRecord, - FeatureVectorsOutput, - FeatureVectorsTagsOutput, -) -from .frontend_spec import ( - 
AuthenticationFeatureFlag, - FeatureFlags, - FrontendSpec, - NuclioStreamsFeatureFlag, - PreemptionNodesFeatureFlag, - ProjectMembershipFeatureFlag, -) -from .function import FunctionState, PreemptionModes, SecurityContextEnrichmentModes -from .http import HTTPSessionRetryMode -from .k8s import NodeSelectorOperator, Resources, ResourceSpec -from .marketplace import ( - IndexedMarketplaceSource, - MarketplaceCatalog, - MarketplaceItem, - MarketplaceObjectMetadata, - MarketplaceSource, - MarketplaceSourceSpec, - last_source_index, -) -from .memory_reports import MostCommonObjectTypesReport, ObjectTypeReport -from .model_endpoints import ( - Features, - FeatureValues, - GrafanaColumn, - GrafanaDataPoint, - GrafanaNumberColumn, - GrafanaStringColumn, - GrafanaTable, - GrafanaTimeSeriesTarget, - Metric, - ModelEndpoint, - ModelEndpointList, - ModelEndpointMetadata, - ModelEndpointSpec, - ModelEndpointStatus, - ModelMonitoringMode, - ModelMonitoringStoreKinds, -) -from .notification import NotificationSeverity, NotificationStatus -from .object import ObjectKind, ObjectMetadata, ObjectSpec, ObjectStatus -from .pipeline import PipelinesFormat, PipelinesOutput, PipelinesPagination -from .project import ( - IguazioProject, - Project, - ProjectDesiredState, - ProjectMetadata, - ProjectOwner, - ProjectsFormat, - ProjectsOutput, - ProjectSpec, - ProjectState, - ProjectStatus, - ProjectSummariesOutput, - ProjectSummary, -) -from .runtime_resource import ( - GroupedByJobRuntimeResourcesOutput, - GroupedByProjectRuntimeResourcesOutput, - KindRuntimeResources, - ListRuntimeResourcesGroupByField, - RuntimeResource, - RuntimeResources, - RuntimeResourcesOutput, -) -from .schedule import ( - ScheduleCronTrigger, - ScheduleInput, - ScheduleKinds, - ScheduleOutput, - ScheduleRecord, - SchedulesOutput, - ScheduleUpdate, -) -from .secret import ( - AuthSecretData, - SecretKeysData, - SecretProviderName, - SecretsData, - UserSecretCreationRequest, -) -from .tag import Tag, TagObjects +""" 
+Schemas were moved to mlrun.common.schemas. +For backwards compatibility with mlrun.api.schemas, we use this file to convert the old imports to the new ones. +The DeprecationHelper class is used to print a deprecation warning when the old import is used, and return the new +schema. +""" + +import sys + +import mlrun.common.schemas +import mlrun.common.schemas.artifact as old_artifact +import mlrun.common.schemas.auth as old_auth +import mlrun.common.schemas.background_task as old_background_task +import mlrun.common.schemas.client_spec as old_client_spec +import mlrun.common.schemas.clusterization_spec as old_clusterization_spec +import mlrun.common.schemas.constants as old_constants +import mlrun.common.schemas.feature_store as old_feature_store +import mlrun.common.schemas.frontend_spec as old_frontend_spec +import mlrun.common.schemas.function as old_function +import mlrun.common.schemas.http as old_http +import mlrun.common.schemas.k8s as old_k8s +import mlrun.common.schemas.memory_reports as old_memory_reports +import mlrun.common.schemas.object as old_object +import mlrun.common.schemas.pipeline as old_pipeline +import mlrun.common.schemas.project as old_project +import mlrun.common.schemas.runtime_resource as old_runtime_resource +import mlrun.common.schemas.schedule as old_schedule +import mlrun.common.schemas.secret as old_secret +import mlrun.common.schemas.tag as old_tag +from mlrun.utils.helpers import DeprecationHelper + +# for backwards compatibility, we need to inject the old import path to `sys.modules` +sys.modules["mlrun.api.schemas.artifact"] = old_artifact +sys.modules["mlrun.api.schemas.auth"] = old_auth +sys.modules["mlrun.api.schemas.background_task"] = old_background_task +sys.modules["mlrun.api.schemas.client_spec"] = old_client_spec +sys.modules["mlrun.api.schemas.clusterization_spec"] = old_clusterization_spec +sys.modules["mlrun.api.schemas.constants"] = old_constants +sys.modules["mlrun.api.schemas.feature_store"] = old_feature_store 
+sys.modules["mlrun.api.schemas.frontend_spec"] = old_frontend_spec +sys.modules["mlrun.api.schemas.function"] = old_function +sys.modules["mlrun.api.schemas.http"] = old_http +sys.modules["mlrun.api.schemas.k8s"] = old_k8s +sys.modules["mlrun.api.schemas.memory_reports"] = old_memory_reports +sys.modules["mlrun.api.schemas.object"] = old_object +sys.modules["mlrun.api.schemas.pipeline"] = old_pipeline +sys.modules["mlrun.api.schemas.project"] = old_project +sys.modules["mlrun.api.schemas.runtime_resource"] = old_runtime_resource +sys.modules["mlrun.api.schemas.schedule"] = old_schedule +sys.modules["mlrun.api.schemas.secret"] = old_secret +sys.modules["mlrun.api.schemas.tag"] = old_tag + +# The DeprecationHelper class is used to print a deprecation warning when the old import is used, +# and return the new schema. This is done for backwards compatibility with mlrun.api.schemas. +ArtifactCategories = DeprecationHelper(mlrun.common.schemas.ArtifactCategories) +ArtifactIdentifier = DeprecationHelper(mlrun.common.schemas.ArtifactIdentifier) +ArtifactsFormat = DeprecationHelper(mlrun.common.schemas.ArtifactsFormat) +AuthInfo = DeprecationHelper(mlrun.common.schemas.AuthInfo) +AuthorizationAction = DeprecationHelper(mlrun.common.schemas.AuthorizationAction) +AuthorizationResourceTypes = DeprecationHelper( + mlrun.common.schemas.AuthorizationResourceTypes +) +AuthorizationVerificationInput = DeprecationHelper( + mlrun.common.schemas.AuthorizationVerificationInput +) +Credentials = DeprecationHelper(mlrun.common.schemas.Credentials) +ProjectsRole = DeprecationHelper(mlrun.common.schemas.ProjectsRole) + +BackgroundTask = DeprecationHelper(mlrun.common.schemas.BackgroundTask) +BackgroundTaskMetadata = DeprecationHelper(mlrun.common.schemas.BackgroundTaskMetadata) +BackgroundTaskSpec = DeprecationHelper(mlrun.common.schemas.BackgroundTaskSpec) +BackgroundTaskState = DeprecationHelper(mlrun.common.schemas.BackgroundTaskState) +BackgroundTaskStatus = 
DeprecationHelper(mlrun.common.schemas.BackgroundTaskStatus) +ClientSpe = DeprecationHelper(mlrun.common.schemas.ClientSpec) +ClusterizationSpec = DeprecationHelper(mlrun.common.schemas.ClusterizationSpec) +WaitForChiefToReachOnlineStateFeatureFlag = DeprecationHelper( + mlrun.common.schemas.WaitForChiefToReachOnlineStateFeatureFlag +) +APIStates = DeprecationHelper(mlrun.common.schemas.APIStates) +ClusterizationRole = DeprecationHelper(mlrun.common.schemas.ClusterizationRole) +DeletionStrategy = DeprecationHelper(mlrun.common.schemas.DeletionStrategy) +FeatureStorePartitionByField = DeprecationHelper( + mlrun.common.schemas.FeatureStorePartitionByField +) +HeaderNames = DeprecationHelper(mlrun.common.schemas.HeaderNames) +LogsCollectorMode = DeprecationHelper(mlrun.common.schemas.LogsCollectorMode) +OrderType = DeprecationHelper(mlrun.common.schemas.OrderType) +PatchMode = DeprecationHelper(mlrun.common.schemas.PatchMode) +RunPartitionByField = DeprecationHelper(mlrun.common.schemas.RunPartitionByField) +SortField = DeprecationHelper(mlrun.common.schemas.SortField) +EntitiesOutput = DeprecationHelper(mlrun.common.schemas.EntitiesOutput) +Entity = DeprecationHelper(mlrun.common.schemas.Entity) +EntityListOutput = DeprecationHelper(mlrun.common.schemas.EntityListOutput) +EntityRecord = DeprecationHelper(mlrun.common.schemas.EntityRecord) +Feature = DeprecationHelper(mlrun.common.schemas.Feature) +FeatureListOutput = DeprecationHelper(mlrun.common.schemas.FeatureListOutput) +FeatureRecord = DeprecationHelper(mlrun.common.schemas.FeatureRecord) +FeatureSet = DeprecationHelper(mlrun.common.schemas.FeatureSet) +FeatureSetDigestOutput = DeprecationHelper(mlrun.common.schemas.FeatureSetDigestOutput) +FeatureSetDigestSpec = DeprecationHelper(mlrun.common.schemas.FeatureSetDigestSpec) +FeatureSetIngestInput = DeprecationHelper(mlrun.common.schemas.FeatureSetIngestInput) +FeatureSetIngestOutput = DeprecationHelper(mlrun.common.schemas.FeatureSetIngestOutput) 
+FeatureSetRecord = DeprecationHelper(mlrun.common.schemas.FeatureSetRecord) +FeatureSetsOutput = DeprecationHelper(mlrun.common.schemas.FeatureSetsOutput) +FeatureSetSpec = DeprecationHelper(mlrun.common.schemas.FeatureSetSpec) +FeatureSetsTagsOutput = DeprecationHelper(mlrun.common.schemas.FeatureSetsTagsOutput) +FeaturesOutput = DeprecationHelper(mlrun.common.schemas.FeaturesOutput) +FeatureVector = DeprecationHelper(mlrun.common.schemas.FeatureVector) +FeatureVectorRecord = DeprecationHelper(mlrun.common.schemas.FeatureVectorRecord) +FeatureVectorsOutput = DeprecationHelper(mlrun.common.schemas.FeatureVectorsOutput) +FeatureVectorsTagsOutput = DeprecationHelper( + mlrun.common.schemas.FeatureVectorsTagsOutput +) +AuthenticationFeatureFlag = DeprecationHelper( + mlrun.common.schemas.AuthenticationFeatureFlag +) +FeatureFlags = DeprecationHelper(mlrun.common.schemas.FeatureFlags) +FrontendSpec = DeprecationHelper(mlrun.common.schemas.FrontendSpec) +NuclioStreamsFeatureFlag = DeprecationHelper( + mlrun.common.schemas.NuclioStreamsFeatureFlag +) +PreemptionNodesFeatureFlag = DeprecationHelper( + mlrun.common.schemas.PreemptionNodesFeatureFlag +) +ProjectMembershipFeatureFlag = DeprecationHelper( + mlrun.common.schemas.ProjectMembershipFeatureFlag +) +FunctionState = DeprecationHelper(mlrun.common.schemas.FunctionState) +PreemptionModes = DeprecationHelper(mlrun.common.schemas.PreemptionModes) +SecurityContextEnrichmentModes = DeprecationHelper( + mlrun.common.schemas.SecurityContextEnrichmentModes +) +HTTPSessionRetryMode = DeprecationHelper(mlrun.common.schemas.HTTPSessionRetryMode) +NodeSelectorOperator = DeprecationHelper(mlrun.common.schemas.NodeSelectorOperator) +Resources = DeprecationHelper(mlrun.common.schemas.Resources) +ResourceSpec = DeprecationHelper(mlrun.common.schemas.ResourceSpec) +IndexedHubSource = DeprecationHelper(mlrun.common.schemas.IndexedHubSource) +HubCatalog = DeprecationHelper(mlrun.common.schemas.HubCatalog) +HubItem = 
DeprecationHelper(mlrun.common.schemas.HubItem) +HubObjectMetadata = DeprecationHelper(mlrun.common.schemas.HubObjectMetadata) +HubSource = DeprecationHelper(mlrun.common.schemas.HubSource) +HubSourceSpec = DeprecationHelper(mlrun.common.schemas.HubSourceSpec) +last_source_index = DeprecationHelper(mlrun.common.schemas.last_source_index) +MostCommonObjectTypesReport = DeprecationHelper( + mlrun.common.schemas.MostCommonObjectTypesReport +) +ObjectTypeReport = DeprecationHelper(mlrun.common.schemas.ObjectTypeReport) +Features = DeprecationHelper(mlrun.common.schemas.Features) +FeatureValues = DeprecationHelper(mlrun.common.schemas.FeatureValues) +GrafanaColumn = DeprecationHelper(mlrun.common.schemas.GrafanaColumn) +GrafanaDataPoint = DeprecationHelper(mlrun.common.schemas.GrafanaDataPoint) +GrafanaNumberColumn = DeprecationHelper(mlrun.common.schemas.GrafanaNumberColumn) +GrafanaStringColumn = DeprecationHelper(mlrun.common.schemas.GrafanaStringColumn) +GrafanaTable = DeprecationHelper(mlrun.common.schemas.GrafanaTable) +GrafanaTimeSeriesTarget = DeprecationHelper( + mlrun.common.schemas.GrafanaTimeSeriesTarget +) +ModelEndpoint = DeprecationHelper(mlrun.common.schemas.ModelEndpoint) +ModelEndpointList = DeprecationHelper(mlrun.common.schemas.ModelEndpointList) +ModelEndpointMetadata = DeprecationHelper(mlrun.common.schemas.ModelEndpointMetadata) +ModelEndpointSpec = DeprecationHelper(mlrun.common.schemas.ModelEndpointSpec) +ModelEndpointStatus = DeprecationHelper(mlrun.common.schemas.ModelEndpointStatus) +ModelMonitoringStoreKinds = DeprecationHelper( + mlrun.common.schemas.ModelMonitoringStoreKinds +) +NotificationSeverity = DeprecationHelper(mlrun.common.schemas.NotificationSeverity) +NotificationStatus = DeprecationHelper(mlrun.common.schemas.NotificationStatus) +ObjectKind = DeprecationHelper(mlrun.common.schemas.ObjectKind) +ObjectMetadata = DeprecationHelper(mlrun.common.schemas.ObjectMetadata) +ObjectSpec = DeprecationHelper(mlrun.common.schemas.ObjectSpec) 
+ObjectStatus = DeprecationHelper(mlrun.common.schemas.ObjectStatus) +PipelinesFormat = DeprecationHelper(mlrun.common.schemas.PipelinesFormat) +PipelinesOutput = DeprecationHelper(mlrun.common.schemas.PipelinesOutput) +PipelinesPagination = DeprecationHelper(mlrun.common.schemas.PipelinesPagination) +IguazioProject = DeprecationHelper(mlrun.common.schemas.IguazioProject) +Project = DeprecationHelper(mlrun.common.schemas.Project) +ProjectDesiredState = DeprecationHelper(mlrun.common.schemas.ProjectDesiredState) +ProjectMetadata = DeprecationHelper(mlrun.common.schemas.ProjectMetadata) +ProjectOwner = DeprecationHelper(mlrun.common.schemas.ProjectOwner) +ProjectsFormat = DeprecationHelper(mlrun.common.schemas.ProjectsFormat) +ProjectsOutput = DeprecationHelper(mlrun.common.schemas.ProjectsOutput) +ProjectSpec = DeprecationHelper(mlrun.common.schemas.ProjectSpec) +ProjectState = DeprecationHelper(mlrun.common.schemas.ProjectState) +ProjectStatus = DeprecationHelper(mlrun.common.schemas.ProjectStatus) +ProjectSummariesOutput = DeprecationHelper(mlrun.common.schemas.ProjectSummariesOutput) +ProjectSummary = DeprecationHelper(mlrun.common.schemas.ProjectSummary) +GroupedByJobRuntimeResourcesOutput = DeprecationHelper( + mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput +) +GroupedByProjectRuntimeResourcesOutput = DeprecationHelper( + mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput +) +KindRuntimeResources = DeprecationHelper(mlrun.common.schemas.KindRuntimeResources) +ListRuntimeResourcesGroupByField = DeprecationHelper( + mlrun.common.schemas.ListRuntimeResourcesGroupByField +) +RuntimeResource = DeprecationHelper(mlrun.common.schemas.RuntimeResource) +RuntimeResources = DeprecationHelper(mlrun.common.schemas.RuntimeResources) +RuntimeResourcesOutput = DeprecationHelper(mlrun.common.schemas.RuntimeResourcesOutput) +ScheduleCronTrigger = DeprecationHelper(mlrun.common.schemas.ScheduleCronTrigger) +ScheduleInput = 
DeprecationHelper(mlrun.common.schemas.ScheduleInput) +ScheduleKinds = DeprecationHelper(mlrun.common.schemas.ScheduleKinds) +ScheduleOutput = DeprecationHelper(mlrun.common.schemas.ScheduleOutput) +ScheduleRecord = DeprecationHelper(mlrun.common.schemas.ScheduleRecord) +SchedulesOutput = DeprecationHelper(mlrun.common.schemas.SchedulesOutput) +ScheduleUpdate = DeprecationHelper(mlrun.common.schemas.ScheduleUpdate) +AuthSecretData = DeprecationHelper(mlrun.common.schemas.AuthSecretData) +SecretKeysData = DeprecationHelper(mlrun.common.schemas.SecretKeysData) +SecretProviderName = DeprecationHelper(mlrun.common.schemas.SecretProviderName) +SecretsData = DeprecationHelper(mlrun.common.schemas.SecretsData) +UserSecretCreationRequest = DeprecationHelper( + mlrun.common.schemas.UserSecretCreationRequest +) +Tag = DeprecationHelper(mlrun.common.schemas.Tag) +TagObjects = DeprecationHelper(mlrun.common.schemas.TagObjects) diff --git a/mlrun/api/schemas/model_endpoints.py b/mlrun/api/schemas/model_endpoints.py deleted file mode 100644 index 0ae5aee397f0..000000000000 --- a/mlrun/api/schemas/model_endpoints.py +++ /dev/null @@ -1,185 +0,0 @@ -# Copyright 2018 Iguazio -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -from typing import Any, Dict, List, Optional, Tuple, Union - -from pydantic import BaseModel, Field -from pydantic.main import Extra - -import mlrun.api.utils.helpers -from mlrun.api.schemas.object import ObjectKind, ObjectSpec, ObjectStatus -from mlrun.utils.model_monitoring import EndpointType, create_model_endpoint_id - - -class ModelMonitoringStoreKinds: - ENDPOINTS = "endpoints" - EVENTS = "events" - - -class ModelEndpointMetadata(BaseModel): - project: Optional[str] - labels: Optional[dict] = {} - uid: Optional[str] - - class Config: - extra = Extra.allow - - -class ModelMonitoringMode(mlrun.api.utils.helpers.StrEnum): - enabled = "enabled" - disabled = "disabled" - - -class ModelEndpointSpec(ObjectSpec): - function_uri: Optional[str] # /: - model: Optional[str] # : - model_class: Optional[str] - model_uri: Optional[str] - feature_names: Optional[List[str]] - label_names: Optional[List[str]] - stream_path: Optional[str] - algorithm: Optional[str] - monitor_configuration: Optional[dict] = {} - active: Optional[bool] - monitoring_mode: Optional[str] = ModelMonitoringMode.disabled - - -class Metric(BaseModel): - name: str - values: List[Tuple[str, float]] - - -class Histogram(BaseModel): - buckets: List[float] - counts: List[int] - - -class FeatureValues(BaseModel): - min: float - mean: float - max: float - histogram: Histogram - - @classmethod - def from_dict(cls, stats: Optional[dict]): - if stats: - return FeatureValues( - min=stats["min"], - mean=stats["mean"], - max=stats["max"], - histogram=Histogram(buckets=stats["hist"][1], counts=stats["hist"][0]), - ) - else: - return None - - -class Features(BaseModel): - name: str - weight: float - expected: Optional[FeatureValues] - actual: Optional[FeatureValues] - - @classmethod - def new( - cls, - feature_name: str, - feature_stats: Optional[dict], - current_stats: Optional[dict], - ): - return cls( - name=feature_name, - weight=-1.0, - expected=FeatureValues.from_dict(feature_stats), - 
actual=FeatureValues.from_dict(current_stats), - ) - - -class ModelEndpointStatus(ObjectStatus): - feature_stats: Optional[dict] = {} - current_stats: Optional[dict] = {} - first_request: Optional[str] - last_request: Optional[str] - accuracy: Optional[float] - error_count: Optional[int] - drift_status: Optional[str] - drift_measures: Optional[dict] = {} - metrics: Optional[Dict[str, Metric]] - features: Optional[List[Features]] - children: Optional[List[str]] - children_uids: Optional[List[str]] - endpoint_type: Optional[EndpointType] - monitoring_feature_set_uri: Optional[str] - - class Config: - extra = Extra.allow - - -class ModelEndpoint(BaseModel): - kind: ObjectKind = Field(ObjectKind.model_endpoint, const=True) - metadata: ModelEndpointMetadata - spec: ModelEndpointSpec - status: ModelEndpointStatus - - class Config: - extra = Extra.allow - - def __init__(self, **data: Any): - super().__init__(**data) - if self.metadata.uid is None: - uid = create_model_endpoint_id( - function_uri=self.spec.function_uri, - versioned_model=self.spec.model, - ) - self.metadata.uid = str(uid) - - -class ModelEndpointList(BaseModel): - endpoints: List[ModelEndpoint] - - -class GrafanaColumn(BaseModel): - text: str - type: str - - -class GrafanaNumberColumn(GrafanaColumn): - text: str - type: str = "number" - - -class GrafanaStringColumn(GrafanaColumn): - text: str - type: str = "string" - - -class GrafanaTable(BaseModel): - columns: List[GrafanaColumn] - rows: List[List[Optional[Union[float, int, str]]]] = [] - type: str = "table" - - def add_row(self, *args): - self.rows.append(list(args)) - - -class GrafanaDataPoint(BaseModel): - value: float - timestamp: int # Unix timestamp in milliseconds - - -class GrafanaTimeSeriesTarget(BaseModel): - target: str - datapoints: List[Tuple[float, int]] = [] - - def add_data_point(self, data_point: GrafanaDataPoint): - self.datapoints.append((data_point.value, data_point.timestamp)) diff --git a/mlrun/api/utils/auth/providers/base.py 
b/mlrun/api/utils/auth/providers/base.py index c5c2258139bf..e00cc77a0975 100644 --- a/mlrun/api/utils/auth/providers/base.py +++ b/mlrun/api/utils/auth/providers/base.py @@ -15,7 +15,7 @@ import abc import typing -import mlrun.api.schemas +import mlrun.common.schemas class Provider(abc.ABC): @@ -23,8 +23,8 @@ class Provider(abc.ABC): async def query_permissions( self, resource: str, - action: mlrun.api.schemas.AuthorizationAction, - auth_info: mlrun.api.schemas.AuthInfo, + action: mlrun.common.schemas.AuthorizationAction, + auth_info: mlrun.common.schemas.AuthInfo, raise_on_forbidden: bool = True, ) -> bool: pass @@ -34,13 +34,13 @@ async def filter_by_permissions( self, resources: typing.List, opa_resource_extractor: typing.Callable, - action: mlrun.api.schemas.AuthorizationAction, - auth_info: mlrun.api.schemas.AuthInfo, + action: mlrun.common.schemas.AuthorizationAction, + auth_info: mlrun.common.schemas.AuthInfo, ) -> typing.List: pass @abc.abstractmethod def add_allowed_project_for_owner( - self, project_name: str, auth_info: mlrun.api.schemas.AuthInfo + self, project_name: str, auth_info: mlrun.common.schemas.AuthInfo ): pass diff --git a/mlrun/api/utils/auth/providers/nop.py b/mlrun/api/utils/auth/providers/nop.py index 4316585be2b1..987087081363 100644 --- a/mlrun/api/utils/auth/providers/nop.py +++ b/mlrun/api/utils/auth/providers/nop.py @@ -14,7 +14,6 @@ # import typing -import mlrun.api.schemas import mlrun.api.utils.auth.providers.base import mlrun.utils.singleton @@ -26,8 +25,8 @@ class Provider( async def query_permissions( self, resource: str, - action: mlrun.api.schemas.AuthorizationAction, - auth_info: mlrun.api.schemas.AuthInfo, + action: mlrun.common.schemas.AuthorizationAction, + auth_info: mlrun.common.schemas.AuthInfo, raise_on_forbidden: bool = True, ) -> bool: return True @@ -36,12 +35,12 @@ async def filter_by_permissions( self, resources: typing.List, opa_resource_extractor: typing.Callable, - action: 
mlrun.api.schemas.AuthorizationAction, - auth_info: mlrun.api.schemas.AuthInfo, + action: mlrun.common.schemas.AuthorizationAction, + auth_info: mlrun.common.schemas.AuthInfo, ) -> typing.List: return resources def add_allowed_project_for_owner( - self, project_name: str, auth_info: mlrun.api.schemas.AuthInfo + self, project_name: str, auth_info: mlrun.common.schemas.AuthInfo ): pass diff --git a/mlrun/api/utils/auth/providers/opa.py b/mlrun/api/utils/auth/providers/opa.py index d717c736e66e..d8da4b527239 100644 --- a/mlrun/api/utils/auth/providers/opa.py +++ b/mlrun/api/utils/auth/providers/opa.py @@ -21,9 +21,9 @@ import humanfriendly -import mlrun.api.schemas import mlrun.api.utils.auth.providers.base import mlrun.api.utils.projects.remotes.leader +import mlrun.common.schemas import mlrun.errors import mlrun.utils.helpers import mlrun.utils.singleton @@ -66,23 +66,23 @@ def __init__(self) -> None: async def query_permissions( self, resource: str, - action: mlrun.api.schemas.AuthorizationAction, - auth_info: mlrun.api.schemas.AuthInfo, + action: mlrun.common.schemas.AuthorizationAction, + auth_info: mlrun.common.schemas.AuthInfo, raise_on_forbidden: bool = True, ) -> bool: # store is not really a verb in our OPA manifest, we map it to 2 query permissions requests (create & update) - if action == mlrun.api.schemas.AuthorizationAction.store: + if action == mlrun.common.schemas.AuthorizationAction.store: results = await asyncio.gather( self.query_permissions( resource, - mlrun.api.schemas.AuthorizationAction.create, + mlrun.common.schemas.AuthorizationAction.create, auth_info, raise_on_forbidden, ), self.query_permissions( resource, - mlrun.api.schemas.AuthorizationAction.update, + mlrun.common.schemas.AuthorizationAction.update, auth_info, raise_on_forbidden, ), @@ -113,11 +113,11 @@ async def filter_by_permissions( self, resources: typing.List, opa_resource_extractor: typing.Callable, - action: mlrun.api.schemas.AuthorizationAction, - auth_info: 
mlrun.api.schemas.AuthInfo, + action: mlrun.common.schemas.AuthorizationAction, + auth_info: mlrun.common.schemas.AuthInfo, ) -> typing.List: # store is not really a verb in our OPA manifest, we map it to 2 query permissions requests (create & update) - if action == mlrun.api.schemas.AuthorizationAction.store: + if action == mlrun.common.schemas.AuthorizationAction.store: raise NotImplementedError("Store action is not supported in filtering") if self._is_request_from_leader(auth_info.projects_role): return resources @@ -149,7 +149,7 @@ async def filter_by_permissions( return allowed_resources def add_allowed_project_for_owner( - self, project_name: str, auth_info: mlrun.api.schemas.AuthInfo + self, project_name: str, auth_info: mlrun.common.schemas.AuthInfo ): if ( not auth_info.user_id @@ -168,7 +168,7 @@ def add_allowed_project_for_owner( self._allowed_project_owners_cache[auth_info.user_id] = allowed_projects def _check_allowed_project_owners_cache( - self, resource: str, auth_info: mlrun.api.schemas.AuthInfo + self, resource: str, auth_info: mlrun.common.schemas.AuthInfo ): # Cache shouldn't be big, simply clean it on get instead of scheduling it self._clean_expired_records_from_cache() @@ -199,7 +199,7 @@ def _clean_expired_records_from_cache(self): del self._allowed_project_owners_cache[user_id] def _is_request_from_leader( - self, projects_role: typing.Optional[mlrun.api.schemas.ProjectsRole] + self, projects_role: typing.Optional[mlrun.common.schemas.ProjectsRole] ): if projects_role and projects_role.value == self._leader_name: return True @@ -241,8 +241,8 @@ async def _on_request_api_failure(self, method, path, response, **kwargs): @staticmethod def _generate_permission_request_body( resource: str, - action: mlrun.api.schemas.AuthorizationAction, - auth_info: mlrun.api.schemas.AuthInfo, + action: mlrun.common.schemas.AuthorizationAction, + auth_info: mlrun.common.schemas.AuthInfo, ) -> dict: body = { "input": { @@ -256,8 +256,8 @@ def 
_generate_permission_request_body( @staticmethod def _generate_filter_request_body( resources: typing.List[str], - action: mlrun.api.schemas.AuthorizationAction, - auth_info: mlrun.api.schemas.AuthInfo, + action: mlrun.common.schemas.AuthorizationAction, + auth_info: mlrun.common.schemas.AuthInfo, ) -> dict: body = { "input": { diff --git a/mlrun/api/utils/auth/verifier.py b/mlrun/api/utils/auth/verifier.py index c70bc10a5041..4ec5479545cc 100644 --- a/mlrun/api/utils/auth/verifier.py +++ b/mlrun/api/utils/auth/verifier.py @@ -19,10 +19,10 @@ import fastapi import mlrun -import mlrun.api.schemas import mlrun.api.utils.auth.providers.nop import mlrun.api.utils.auth.providers.opa import mlrun.api.utils.clients.iguazio +import mlrun.common.schemas import mlrun.utils.singleton @@ -41,11 +41,11 @@ def __init__(self) -> None: async def filter_project_resources_by_permissions( self, - resource_type: mlrun.api.schemas.AuthorizationResourceTypes, + resource_type: mlrun.common.schemas.AuthorizationResourceTypes, resources: typing.List, project_and_resource_name_extractor: typing.Callable, - auth_info: mlrun.api.schemas.AuthInfo, - action: mlrun.api.schemas.AuthorizationAction = mlrun.api.schemas.AuthorizationAction.read, + auth_info: mlrun.common.schemas.AuthInfo, + action: mlrun.common.schemas.AuthorizationAction = mlrun.common.schemas.AuthorizationAction.read, ) -> typing.List: def _generate_opa_resource(resource): project_name, resource_name = project_and_resource_name_extractor(resource) @@ -60,8 +60,8 @@ def _generate_opa_resource(resource): async def filter_projects_by_permissions( self, project_names: typing.List[str], - auth_info: mlrun.api.schemas.AuthInfo, - action: mlrun.api.schemas.AuthorizationAction = mlrun.api.schemas.AuthorizationAction.read, + auth_info: mlrun.common.schemas.AuthInfo, + action: mlrun.common.schemas.AuthorizationAction = mlrun.common.schemas.AuthorizationAction.read, ) -> typing.List: return await self.filter_by_permissions( project_names, @@ 
-72,11 +72,11 @@ async def filter_projects_by_permissions( async def query_project_resources_permissions( self, - resource_type: mlrun.api.schemas.AuthorizationResourceTypes, + resource_type: mlrun.common.schemas.AuthorizationResourceTypes, resources: typing.List, project_and_resource_name_extractor: typing.Callable, - action: mlrun.api.schemas.AuthorizationAction, - auth_info: mlrun.api.schemas.AuthInfo, + action: mlrun.common.schemas.AuthorizationAction, + auth_info: mlrun.common.schemas.AuthInfo, raise_on_forbidden: bool = True, ) -> bool: project_resources = [ @@ -102,11 +102,11 @@ async def query_project_resources_permissions( async def query_project_resource_permissions( self, - resource_type: mlrun.api.schemas.AuthorizationResourceTypes, + resource_type: mlrun.common.schemas.AuthorizationResourceTypes, project_name: str, resource_name: str, - action: mlrun.api.schemas.AuthorizationAction, - auth_info: mlrun.api.schemas.AuthInfo, + action: mlrun.common.schemas.AuthorizationAction, + auth_info: mlrun.common.schemas.AuthInfo, raise_on_forbidden: bool = True, ) -> bool: return await self.query_permissions( @@ -121,8 +121,8 @@ async def query_project_resource_permissions( async def query_project_permissions( self, project_name: str, - action: mlrun.api.schemas.AuthorizationAction, - auth_info: mlrun.api.schemas.AuthInfo, + action: mlrun.common.schemas.AuthorizationAction, + auth_info: mlrun.common.schemas.AuthInfo, raise_on_forbidden: bool = True, ) -> bool: return await self.query_permissions( @@ -134,9 +134,9 @@ async def query_project_permissions( async def query_global_resource_permissions( self, - resource_type: mlrun.api.schemas.AuthorizationResourceTypes, - action: mlrun.api.schemas.AuthorizationAction, - auth_info: mlrun.api.schemas.AuthInfo, + resource_type: mlrun.common.schemas.AuthorizationResourceTypes, + action: mlrun.common.schemas.AuthorizationAction, + auth_info: mlrun.common.schemas.AuthInfo, raise_on_forbidden: bool = True, ) -> bool: return 
await self.query_resource_permissions( @@ -149,10 +149,10 @@ async def query_global_resource_permissions( async def query_resource_permissions( self, - resource_type: mlrun.api.schemas.AuthorizationResourceTypes, + resource_type: mlrun.common.schemas.AuthorizationResourceTypes, resource_name: str, - action: mlrun.api.schemas.AuthorizationAction, - auth_info: mlrun.api.schemas.AuthInfo, + action: mlrun.common.schemas.AuthorizationAction, + auth_info: mlrun.common.schemas.AuthInfo, raise_on_forbidden: bool = True, ) -> bool: return await self.query_permissions( @@ -165,8 +165,8 @@ async def query_resource_permissions( async def query_permissions( self, resource: str, - action: mlrun.api.schemas.AuthorizationAction, - auth_info: mlrun.api.schemas.AuthInfo, + action: mlrun.common.schemas.AuthorizationAction, + auth_info: mlrun.common.schemas.AuthInfo, raise_on_forbidden: bool = True, ) -> bool: return await self._auth_provider.query_permissions( @@ -177,8 +177,8 @@ async def filter_by_permissions( self, resources: typing.List, opa_resource_extractor: typing.Callable, - action: mlrun.api.schemas.AuthorizationAction, - auth_info: mlrun.api.schemas.AuthInfo, + action: mlrun.common.schemas.AuthorizationAction, + auth_info: mlrun.common.schemas.AuthInfo, ) -> typing.List: return await self._auth_provider.filter_by_permissions( resources, @@ -188,14 +188,14 @@ async def filter_by_permissions( ) def add_allowed_project_for_owner( - self, project_name: str, auth_info: mlrun.api.schemas.AuthInfo + self, project_name: str, auth_info: mlrun.common.schemas.AuthInfo ): self._auth_provider.add_allowed_project_for_owner(project_name, auth_info) async def authenticate_request( self, request: fastapi.Request - ) -> mlrun.api.schemas.AuthInfo: - auth_info = mlrun.api.schemas.AuthInfo() + ) -> mlrun.common.schemas.AuthInfo: + auth_info = mlrun.common.schemas.AuthInfo() header = request.headers.get("Authorization", "") if self._basic_auth_configured(): if not 
header.startswith(self._basic_prefix): @@ -228,10 +228,10 @@ async def authenticate_request( auth_info.username = request.headers["x-remote-user"] projects_role_header = request.headers.get( - mlrun.api.schemas.HeaderNames.projects_role + mlrun.common.schemas.HeaderNames.projects_role ) auth_info.projects_role = ( - mlrun.api.schemas.ProjectsRole(projects_role_header) + mlrun.common.schemas.ProjectsRole(projects_role_header) if projects_role_header else None ) @@ -248,7 +248,7 @@ async def authenticate_request( async def generate_auth_info_from_session( self, session: str - ) -> mlrun.api.schemas.AuthInfo: + ) -> mlrun.common.schemas.AuthInfo: if not self._iguazio_auth_configured(): raise NotImplementedError( "Session is currently supported only for iguazio authentication mode" @@ -273,13 +273,15 @@ def is_jobs_auth_required(self): @staticmethod def _generate_resource_string_from_project_name(project_name: str): - return mlrun.api.schemas.AuthorizationResourceTypes.project.to_resource_string( - project_name, "" + return ( + mlrun.common.schemas.AuthorizationResourceTypes.project.to_resource_string( + project_name, "" + ) ) @staticmethod def _generate_resource_string_from_project_resource( - resource_type: mlrun.api.schemas.AuthorizationResourceTypes, + resource_type: mlrun.common.schemas.AuthorizationResourceTypes, project_name: str, resource_name: str, ): diff --git a/mlrun/api/utils/background_tasks.py b/mlrun/api/utils/background_tasks.py index 6372c136267d..29d40279dc87 100644 --- a/mlrun/api/utils/background_tasks.py +++ b/mlrun/api/utils/background_tasks.py @@ -22,10 +22,10 @@ import fastapi.concurrency import sqlalchemy.orm -import mlrun.api.schemas import mlrun.api.utils.helpers import mlrun.api.utils.singletons.db import mlrun.api.utils.singletons.project_member +import mlrun.common.schemas import mlrun.errors import mlrun.utils.singleton from mlrun.utils import logger @@ -41,13 +41,13 @@ def create_background_task( timeout: int = None, # in seconds *args, 
**kwargs, - ) -> mlrun.api.schemas.BackgroundTask: + ) -> mlrun.common.schemas.BackgroundTask: name = str(uuid.uuid4()) mlrun.api.utils.singletons.db.get_db().store_background_task( db_session, name, project, - mlrun.api.schemas.BackgroundTaskState.running, + mlrun.common.schemas.BackgroundTaskState.running, timeout, ) background_tasks.add_task( @@ -66,7 +66,7 @@ def get_background_task( db_session: sqlalchemy.orm.Session, name: str, project: str, - ) -> mlrun.api.schemas.BackgroundTask: + ) -> mlrun.common.schemas.BackgroundTask: return mlrun.api.utils.singletons.db.get_db().get_background_task( db_session, name, project ) @@ -93,21 +93,21 @@ async def background_task_wrapper( db_session, name, project=project, - state=mlrun.api.schemas.BackgroundTaskState.failed, + state=mlrun.common.schemas.BackgroundTaskState.failed, ) else: mlrun.api.utils.singletons.db.get_db().store_background_task( db_session, name, project=project, - state=mlrun.api.schemas.BackgroundTaskState.succeeded, + state=mlrun.common.schemas.BackgroundTaskState.succeeded, ) class InternalBackgroundTasksHandler(metaclass=mlrun.utils.singleton.Singleton): def __init__(self): self._internal_background_tasks: typing.Dict[ - str, mlrun.api.schemas.BackgroundTask + str, mlrun.common.schemas.BackgroundTask ] = {} @mlrun.api.utils.helpers.ensure_running_on_chief @@ -117,7 +117,7 @@ def create_background_task( function, *args, **kwargs, - ) -> mlrun.api.schemas.BackgroundTask: + ) -> mlrun.common.schemas.BackgroundTask: name = str(uuid.uuid4()) # sanity if name in self._internal_background_tasks: @@ -138,7 +138,7 @@ def create_background_task( def get_background_task( self, name: str, - ) -> mlrun.api.schemas.BackgroundTask: + ) -> mlrun.common.schemas.BackgroundTask: """ :return: returns the background task object and bool whether exists """ @@ -160,17 +160,17 @@ async def background_task_wrapper(self, name: str, function, *args, **kwargs): f"Failed during background task execution: {function.__name__}, 
exc: {traceback.format_exc()}" ) self._update_background_task( - name, mlrun.api.schemas.BackgroundTaskState.failed + name, mlrun.common.schemas.BackgroundTaskState.failed ) else: self._update_background_task( - name, mlrun.api.schemas.BackgroundTaskState.succeeded + name, mlrun.common.schemas.BackgroundTaskState.succeeded ) def _update_background_task( self, name: str, - state: mlrun.api.schemas.BackgroundTaskState, + state: mlrun.common.schemas.BackgroundTaskState, ): background_task = self._internal_background_tasks[name] background_task.status.state = state @@ -183,31 +183,31 @@ def _generate_background_task_not_found_response( # in order to keep things simple we don't persist the internal background tasks to the DB # If for some reason get is called and the background task doesn't exist, it means that probably we got # restarted, therefore we want to return a failed background task so the client will retry (if needed) - return mlrun.api.schemas.BackgroundTask( - metadata=mlrun.api.schemas.BackgroundTaskMetadata( + return mlrun.common.schemas.BackgroundTask( + metadata=mlrun.common.schemas.BackgroundTaskMetadata( name=name, project=project ), - spec=mlrun.api.schemas.BackgroundTaskSpec(), - status=mlrun.api.schemas.BackgroundTaskStatus( - state=mlrun.api.schemas.BackgroundTaskState.failed + spec=mlrun.common.schemas.BackgroundTaskSpec(), + status=mlrun.common.schemas.BackgroundTaskStatus( + state=mlrun.common.schemas.BackgroundTaskState.failed ), ) @staticmethod def _generate_background_task( name: str, project: typing.Optional[str] = None - ) -> mlrun.api.schemas.BackgroundTask: + ) -> mlrun.common.schemas.BackgroundTask: now = datetime.datetime.utcnow() - metadata = mlrun.api.schemas.BackgroundTaskMetadata( + metadata = mlrun.common.schemas.BackgroundTaskMetadata( name=name, project=project, created=now, updated=now, ) - spec = mlrun.api.schemas.BackgroundTaskSpec() - status = mlrun.api.schemas.BackgroundTaskStatus( - 
state=mlrun.api.schemas.BackgroundTaskState.running + spec = mlrun.common.schemas.BackgroundTaskSpec() + status = mlrun.common.schemas.BackgroundTaskStatus( + state=mlrun.common.schemas.BackgroundTaskState.running ) - return mlrun.api.schemas.BackgroundTask( + return mlrun.common.schemas.BackgroundTask( metadata=metadata, spec=spec, status=status ) diff --git a/mlrun/builder.py b/mlrun/api/utils/builder.py similarity index 72% rename from mlrun/builder.py rename to mlrun/api/utils/builder.py index c11e41947f8f..c6cbac973a18 100644 --- a/mlrun/builder.py +++ b/mlrun/api/utils/builder.py @@ -14,35 +14,33 @@ import os.path import pathlib import re -import tarfile import tempfile +import typing from base64 import b64decode, b64encode from os import path from urllib.parse import urlparse from kubernetes import client -import mlrun.api.schemas +import mlrun.api.utils.singletons.k8s +import mlrun.common.constants +import mlrun.common.schemas import mlrun.errors import mlrun.runtimes.utils - -from .config import config -from .datastore import store_manager -from .k8s_utils import BasePod, get_k8s_helper -from .utils import enrich_image_url, get_parsed_docker_registry, logger, normalize_name - -IMAGE_NAME_ENRICH_REGISTRY_PREFIX = "." 
+import mlrun.utils +from mlrun.config import config +from mlrun.utils.helpers import remove_image_protocol_prefix def make_dockerfile( - base_image, - commands=None, - source=None, - requirements=None, - workdir="/mlrun", - extra="", - user_unix_id=None, - enriched_group_id=None, + base_image: str, + commands: list = None, + source: str = None, + requirements_path: str = None, + workdir: str = "/mlrun", + extra: str = "", + user_unix_id: int = None, + enriched_group_id: int = None, ): dock = f"FROM {base_image}\n" @@ -55,7 +53,6 @@ def make_dockerfile( dock += f"ARG {build_arg_key}={build_arg_value}\n" if source: - dock += f"RUN mkdir -p {workdir}\n" dock += f"WORKDIR {workdir}\n" # 'ADD' command does not extract zip files - add extraction stage to the dockerfile if source.endswith(".zip"): @@ -77,13 +74,16 @@ def make_dockerfile( dock += f"RUN chown -R {user_unix_id}:{enriched_group_id} {workdir}\n" dock += f"ENV PYTHONPATH {workdir}\n" - if requirements: - dock += f"RUN python -m pip install -r {requirements}\n" if commands: dock += "".join([f"RUN {command}\n" for command in commands]) + if requirements_path: + dock += ( + f"RUN echo 'Installing {requirements_path}...'; cat {requirements_path}\n" + ) + dock += f"RUN python -m pip install -r {requirements_path}\n" if extra: dock += extra - logger.debug("Resolved dockerfile", dockfile_contents=dock) + mlrun.utils.logger.debug("Resolved dockerfile", dockfile_contents=dock) return dock @@ -96,6 +96,7 @@ def make_kaniko_pod( inline_code=None, inline_path=None, requirements=None, + requirements_path=None, secret_name=None, name="", verbose=False, @@ -135,7 +136,16 @@ def make_kaniko_pod( if dockertext: dockerfile = "/empty/Dockerfile" - args = ["--dockerfile", dockerfile, "--context", context, "--destination", dest] + args = [ + "--dockerfile", + dockerfile, + "--context", + context, + "--destination", + dest, + "--image-fs-extract-retry", + config.httpdb.builder.kaniko_image_fs_extraction_retries, + ] for value, flag 
in [ (config.httpdb.builder.insecure_pull_registry_mode, "--insecure-pull"), (config.httpdb.builder.insecure_push_registry_mode, "--insecure"), @@ -159,7 +169,7 @@ def make_kaniko_pod( mem=default_requests.get("memory"), cpu=default_requests.get("cpu") ) } - kpod = BasePod( + kpod = mlrun.api.utils.singletons.k8s.BasePod( name or "mlrun-build", config.httpdb.builder.kaniko_image, args=args, @@ -194,19 +204,23 @@ def make_kaniko_pod( commands = [] env = {} if dockertext: - commands.append("echo ${DOCKERFILE} | base64 -d > /empty/Dockerfile") + # set and encode docker content to the DOCKERFILE environment variable in the kaniko pod env["DOCKERFILE"] = b64encode(dockertext.encode("utf-8")).decode("utf-8") + # dump dockerfile content and decode to Dockerfile destination + commands.append("echo ${DOCKERFILE} | base64 -d > /empty/Dockerfile") if inline_code: name = inline_path or "main.py" - commands.append("echo ${CODE} | base64 -d > /empty/" + name) env["CODE"] = b64encode(inline_code.encode("utf-8")).decode("utf-8") + commands.append("echo ${CODE} | base64 -d > /empty/" + name) if requirements: - commands.append( - "echo ${REQUIREMENTS} | base64 -d > /empty/requirements.txt" - ) + # set and encode requirements to the REQUIREMENTS environment variable in the kaniko pod env["REQUIREMENTS"] = b64encode( "\n".join(requirements).encode("utf-8") ).decode("utf-8") + # dump requirement content and decode to the requirement.txt destination + commands.append( + "echo ${REQUIREMENTS}" + " | " + f"base64 -d > {requirements_path}" + ) kpod.append_init_container( config.httpdb.builder.kaniko_init_container_image, @@ -222,7 +236,21 @@ def make_kaniko_pod( if end == -1: end = len(dest) repo = dest[dest.find("/") + 1 : end] - configure_kaniko_ecr_init_container(kpod, registry, repo) + + # if no secret is given, assume ec2 instance has attached role which provides read/write access to ECR + assume_instance_role = not config.httpdb.builder.docker_registry_secret + 
configure_kaniko_ecr_init_container(kpod, registry, repo, assume_instance_role) + + # project secret might conflict with the attached instance role + # ensure "AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY" have no values or else kaniko will fail + # due to credentials conflict / lack of permission on given credentials + if assume_instance_role: + kpod.pod.spec.containers[0].env.extend( + [ + client.V1EnvVar(name="AWS_ACCESS_KEY_ID", value=""), + client.V1EnvVar(name="AWS_SECRET_ACCESS_KEY", value=""), + ] + ) # mount regular docker config secret elif secret_name: @@ -232,7 +260,9 @@ def make_kaniko_pod( return kpod -def configure_kaniko_ecr_init_container(kpod, registry, repo): +def configure_kaniko_ecr_init_container( + kpod, registry, repo, assume_instance_role=True +): region = registry.split(".")[3] # fail silently in order to ignore "repository already exists" errors @@ -243,12 +273,13 @@ def configure_kaniko_ecr_init_container(kpod, registry, repo): ) init_container_env = {} - if not config.httpdb.builder.docker_registry_secret: + if assume_instance_role: # assume instance role has permissions to register and store a container image # https://github.com/GoogleContainerTools/kaniko#pushing-to-amazon-ecr # we only need this in the kaniko container kpod.env.append(client.V1EnvVar(name="AWS_SDK_LOAD_CONFIG", value="true")) + else: aws_credentials_file_env_key = "AWS_SHARED_CREDENTIALS_FILE" aws_credentials_file_env_value = "/tmp/credentials" @@ -279,24 +310,12 @@ def configure_kaniko_ecr_init_container(kpod, registry, repo): ) -def upload_tarball(source_dir, target, secrets=None): - - # will delete the temp file - with tempfile.NamedTemporaryFile(suffix=".tar.gz") as temp_fh: - with tarfile.open(mode="w:gz", fileobj=temp_fh) as tar: - tar.add(source_dir, arcname="") - stores = store_manager.set(secrets) - datastore, subpath = stores.get_or_create_store(target) - datastore.upload(subpath, temp_fh.name) - - def build_image( - auth_info: mlrun.api.schemas.AuthInfo, + 
auth_info: mlrun.common.schemas.AuthInfo, project: str, image_target, commands=None, source="", - mounter="v3io", base_image=None, requirements=None, inline_code=None, @@ -316,43 +335,24 @@ def build_image( ): runtime_spec = runtime.spec if runtime else None builder_env = builder_env or {} - image_target, secret_name = _resolve_image_target_and_registry_secret( + image_target, secret_name = resolve_image_target_and_registry_secret( image_target, registry, secret_name ) - if isinstance(requirements, list): - requirements_list = requirements - requirements_path = "requirements.txt" - if source: - raise ValueError("requirements list only works with inline code") - else: - requirements_list = None - requirements_path = requirements - - commands = commands or [] - if with_mlrun: - # mlrun prerequisite - upgrade pip - upgrade_pip_command = resolve_upgrade_pip_command(commands) - if upgrade_pip_command: - commands.append(upgrade_pip_command) - - mlrun_command = resolve_mlrun_install_command( - mlrun_version_specifier, client_version, commands - ) - if mlrun_command: - commands.append(mlrun_command) + commands, requirements_list, requirements_path = _resolve_build_requirements( + requirements, commands, with_mlrun, mlrun_version_specifier, client_version + ) - if not inline_code and not source and not commands: - logger.info("skipping build, nothing to add") + if not inline_code and not source and not commands and not requirements: + mlrun.utils.logger.info("skipping build, nothing to add") return "skipped" context = "/context" to_mount = False - v3io = ( - source.startswith("v3io://") or source.startswith("v3ios://") - if source - else None - ) + is_v3io_source = False + if source: + is_v3io_source = source.startswith("v3io://") or source.startswith("v3ios://") + access_key = builder_env.get( "V3IO_ACCESS_KEY", auth_info.data_session or auth_info.access_key ) @@ -366,7 +366,8 @@ def build_image( if inline_code or runtime_spec.build.load_source_on_run or not source: 
context = "/empty" - elif source and "://" in source and not v3io: + # source is remote + elif source and "://" in source and not is_v3io_source: if source.startswith("git://"): # if the user provided branch (w/o refs/..) we add the "refs/.." fragment = parsed_url.fragment or "" @@ -377,19 +378,34 @@ def build_image( context = source source_to_copy = "." + # source is local / v3io else: - if v3io: + if is_v3io_source: source = parsed_url.path to_mount = True source_dir_to_mount, source_to_copy = path.split(source) - else: + + # source is a path without a scheme, we allow to copy absolute paths assuming they are valid paths + # in the image, however, it is recommended to use `workdir` instead in such cases + # which is set during runtime (mlrun.runtimes.local.LocalRuntime._pre_run). + # relative paths are not supported at build time + # "." and "./" are considered as 'project context' + # TODO: enrich with project context if pulling on build time + elif path.isabs(source): source_to_copy = source + else: + raise mlrun.errors.MLRunInvalidArgumentError( + f"Load of relative source ({source}) is not supported at build time " + "see 'mlrun.runtimes.kubejob.KubejobRuntime.with_source_archive' or " + "'mlrun.projects.project.MlrunProject.set_source' for more details" + ) + user_unix_id = None enriched_group_id = None if ( mlrun.mlconf.function.spec.security_context.enrichment_mode - != mlrun.api.schemas.SecurityContextEnrichmentModes.disabled.value + != mlrun.common.schemas.SecurityContextEnrichmentModes.disabled.value ): from mlrun.api.api.utils import ensure_function_security_context @@ -398,24 +414,28 @@ def build_image( enriched_group_id = runtime.spec.security_context.run_as_group if source_to_copy and ( - not runtime.spec.workdir or not path.isabs(runtime.spec.workdir) + not runtime.spec.clone_target_dir + or not os.path.isabs(runtime.spec.clone_target_dir) ): - # the user may give a relative workdir to the source where the code is located - # add the relative 
workdir to the target source copy path + # use a temp dir for permissions and set it as the workdir tmpdir = tempfile.mkdtemp() - relative_workdir = runtime.spec.workdir or "" - _, _, relative_workdir = relative_workdir.partition("./") - runtime.spec.workdir = path.join(tmpdir, "mlrun", relative_workdir) + relative_workdir = runtime.spec.clone_target_dir or "" + if relative_workdir.startswith("./"): + # TODO: use 'removeprefix' when we drop python 3.7 support + # relative_workdir.removeprefix("./") + relative_workdir = relative_workdir[2:] + + runtime.spec.clone_target_dir = path.join(tmpdir, "mlrun", relative_workdir) dock = make_dockerfile( base_image, commands, source=source_to_copy, - requirements=requirements_path, + requirements_path=requirements_path, extra=extra, user_unix_id=user_unix_id, enriched_group_id=enriched_group_id, - workdir=runtime.spec.workdir, + workdir=runtime.spec.clone_target_dir, ) kpod = make_kaniko_pod( @@ -426,6 +446,7 @@ def build_image( inline_code=inline_code, inline_path=inline_path, requirements=requirements_list, + requirements_path=requirements_path, secret_name=secret_name, name=name, verbose=verbose, @@ -442,14 +463,16 @@ def build_image( user=username, ) - k8s = get_k8s_helper() + k8s = mlrun.api.utils.singletons.k8s.get_k8s_helper(silent=False) kpod.namespace = k8s.resolve_namespace(namespace) if interactive: return k8s.run_job(kpod) else: pod, ns = k8s.create_pod(kpod) - logger.info(f'started build, to watch build logs use "mlrun watch {pod} {ns}"') + mlrun.utils.logger.info( + "Build started", pod=pod, namespace=ns, project=project, image=image_target + ) return f"build:{pod}" @@ -465,7 +488,7 @@ def get_kaniko_spec_attributes_from_runtime(): ] -def resolve_mlrun_install_command( +def resolve_mlrun_install_command_version( mlrun_version_specifier=None, client_version=None, commands=None ): commands = commands or [] @@ -495,7 +518,7 @@ def resolve_mlrun_install_command( mlrun_version_specifier = ( 
f"{config.package_path}[complete]=={config.version}" ) - return f'python -m pip install "{mlrun_version_specifier}"' + return mlrun_version_specifier def resolve_upgrade_pip_command(commands=None): @@ -509,7 +532,7 @@ def resolve_upgrade_pip_command(commands=None): def build_runtime( - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, runtime, with_mlrun=True, mlrun_version_specifier=None, @@ -523,7 +546,7 @@ def build_runtime( namespace = runtime.metadata.namespace project = runtime.metadata.project if skip_deployed and runtime.is_deployed(): - runtime.status.state = mlrun.api.schemas.FunctionState.ready + runtime.status.state = mlrun.common.schemas.FunctionState.ready return True if build.base_image: mlrun_images = [ @@ -535,7 +558,13 @@ def build_runtime( # if the base is one of mlrun images - no need to install mlrun if any([image in build.base_image for image in mlrun_images]): with_mlrun = False - if not build.source and not build.commands and not build.extra and not with_mlrun: + if ( + not build.source + and not build.commands + and not build.requirements + and not build.extra + and not with_mlrun + ): if not runtime.spec.image: if build.base_image: runtime.spec.image = build.base_image @@ -548,7 +577,7 @@ def build_runtime( "The deployment was not successful because no image was specified or there are missing build parameters" " (commands/source)" ) - runtime.status.state = mlrun.api.schemas.FunctionState.ready + runtime.status.state = mlrun.common.schemas.FunctionState.ready return True build.image = mlrun.runtimes.utils.resolve_function_image_name(runtime, build.image) @@ -561,17 +590,23 @@ def build_runtime( raise mlrun.errors.MLRunInvalidArgumentError( "build spec must have a target image, set build.image = " ) - logger.info(f"building image ({build.image})") + name = mlrun.utils.normalize_name(f"mlrun-build-{runtime.metadata.name}") - name = normalize_name(f"mlrun-build-{runtime.metadata.name}") base_image: str = ( 
build.base_image or runtime.spec.image or config.default_base_image ) - enriched_base_image = enrich_image_url( + enriched_base_image = mlrun.utils.enrich_image_url( base_image, client_version, client_python_version, ) + mlrun.utils.logger.info( + "Building runtime image", + base_image=enriched_base_image, + image=build.image, + project=project, + name=name, + ) status = build_image( auth_info, @@ -579,8 +614,8 @@ def build_runtime( image_target=build.image, base_image=enriched_base_image, commands=build.commands, + requirements=build.requirements, namespace=namespace, - # inline_code=inline, source=build.source, secret_name=build.secret, interactive=interactive, @@ -598,11 +633,11 @@ def build_runtime( # using enriched base image for the runtime spec image, because this will be the image that the function will # run with runtime.spec.image = enriched_base_image - runtime.status.state = mlrun.api.schemas.FunctionState.ready + runtime.status.state = mlrun.common.schemas.FunctionState.ready return True if status.startswith("build:"): - runtime.status.state = mlrun.api.schemas.FunctionState.deploying + runtime.status.state = mlrun.common.schemas.FunctionState.deploying runtime.status.build_pod = status[6:] # using the base_image, and not the enriched one so we won't have the client version in the image, useful for # exports and other cases where we don't want to have the client version in the image, but rather enriched on @@ -610,48 +645,34 @@ def build_runtime( runtime.spec.build.base_image = base_image return False - logger.info(f"build completed with {status}") + mlrun.utils.logger.info(f"build completed with {status}") if status in ["failed", "error"]: - runtime.status.state = mlrun.api.schemas.FunctionState.error + runtime.status.state = mlrun.common.schemas.FunctionState.error return False local = "" if build.secret or build.image.startswith(".") else "." 
runtime.spec.image = local + build.image - runtime.status.state = mlrun.api.schemas.FunctionState.ready + runtime.status.state = mlrun.common.schemas.FunctionState.ready return True -def _generate_builder_env(project, builder_env): - k8s = get_k8s_helper() - secret_name = k8s.get_project_secret_name(project) - existing_secret_keys = k8s.get_project_secret_keys(project, filter_internal=True) - - # generate env list from builder env and project secrets - env = [] - for key in existing_secret_keys: - if key not in builder_env: - value_from = client.V1EnvVarSource( - secret_key_ref=client.V1SecretKeySelector(name=secret_name, key=key) - ) - env.append(client.V1EnvVar(name=key, value_from=value_from)) - for key, value in builder_env.items(): - env.append(client.V1EnvVar(name=key, value=value)) - return env - - -def _resolve_image_target_and_registry_secret( +def resolve_image_target_and_registry_secret( image_target: str, registry: str = None, secret_name: str = None ) -> (str, str): if registry: return "/".join([registry, image_target]), secret_name # if dest starts with a dot, we add the configured registry to the start of the dest - if image_target.startswith(IMAGE_NAME_ENRICH_REGISTRY_PREFIX): + if image_target.startswith( + mlrun.common.constants.IMAGE_NAME_ENRICH_REGISTRY_PREFIX + ): # remove prefix from image name - image_target = image_target[len(IMAGE_NAME_ENRICH_REGISTRY_PREFIX) :] + image_target = image_target[ + len(mlrun.common.constants.IMAGE_NAME_ENRICH_REGISTRY_PREFIX) : + ] - registry, repository = get_parsed_docker_registry() + registry, repository = mlrun.utils.get_parsed_docker_registry() secret_name = secret_name or config.httpdb.builder.docker_registry_secret if not registry: raise ValueError( @@ -664,4 +685,64 @@ def _resolve_image_target_and_registry_secret( return "/".join(image_target_components), secret_name + image_target = remove_image_protocol_prefix(image_target) + return image_target, secret_name + + +def _generate_builder_env(project, 
builder_env): + k8s = mlrun.api.utils.singletons.k8s.get_k8s_helper(silent=False) + secret_name = k8s.get_project_secret_name(project) + existing_secret_keys = k8s.get_project_secret_keys(project, filter_internal=True) + + # generate env list from builder env and project secrets + env = [] + for key in existing_secret_keys: + if key not in builder_env: + value_from = client.V1EnvVarSource( + secret_key_ref=client.V1SecretKeySelector(name=secret_name, key=key) + ) + env.append(client.V1EnvVar(name=key, value_from=value_from)) + for key, value in builder_env.items(): + env.append(client.V1EnvVar(name=key, value=value)) + return env + + +def _resolve_build_requirements( + requirements: typing.Union[typing.List, str], + commands: typing.List, + with_mlrun: bool, + mlrun_version_specifier: typing.Optional[str], + client_version: typing.Optional[str], +): + """ + Resolve build requirements list, requirements path and commands. + If mlrun requirement is needed, we add a pip upgrade command to the commands list (prerequisite). 
+ """ + requirements_path = "/empty/requirements.txt" + if requirements and isinstance(requirements, list): + requirements_list = requirements + else: + requirements_list = [] + requirements_path = requirements or requirements_path + commands = commands or [] + + if with_mlrun: + # mlrun prerequisite - upgrade pip + upgrade_pip_command = resolve_upgrade_pip_command(commands) + if upgrade_pip_command: + commands.append(upgrade_pip_command) + + mlrun_version = resolve_mlrun_install_command_version( + mlrun_version_specifier, client_version, commands + ) + + # mlrun must be installed with other python requirements in the same pip command to avoid version conflicts + if mlrun_version: + requirements_list.insert(0, mlrun_version) + + if not requirements_list: + # no requirements, we don't need a requirements file + requirements_path = "" + + return commands, requirements_list, requirements_path diff --git a/mlrun/api/utils/clients/chief.py b/mlrun/api/utils/clients/chief.py index ece862c6bfe2..87e77a75728c 100644 --- a/mlrun/api/utils/clients/chief.py +++ b/mlrun/api/utils/clients/chief.py @@ -21,8 +21,8 @@ import aiohttp import fastapi -import mlrun.api.schemas import mlrun.api.utils.projects.remotes.follower +import mlrun.common.schemas import mlrun.errors import mlrun.utils.singleton from mlrun.utils import logger @@ -157,7 +157,7 @@ async def delete_project(self, name, request: fastapi.Request) -> fastapi.Respon async def get_clusterization_spec( self, return_fastapi_response: bool = True, raise_on_failure: bool = False - ) -> typing.Union[fastapi.Response, mlrun.api.schemas.ClusterizationSpec]: + ) -> typing.Union[fastapi.Response, mlrun.common.schemas.ClusterizationSpec]: """ This method is used both for proxying requests from worker to chief and for aligning the worker state with the clusterization spec brought from the chief @@ -172,7 +172,22 @@ async def get_clusterization_spec( chief_response ) - return mlrun.api.schemas.ClusterizationSpec(**(await 
chief_response.json())) + return mlrun.common.schemas.ClusterizationSpec( + **(await chief_response.json()) + ) + + async def set_schedule_notifications( + self, project: str, schedule_name: str, request: fastapi.Request, json: dict + ) -> fastapi.Response: + """ + Schedules are running only on chief + """ + return await self._proxy_request_to_chief( + "PUT", + f"projects/{project}/schedules/{schedule_name}/notifications", + request, + json, + ) async def _proxy_request_to_chief( self, diff --git a/mlrun/api/utils/clients/iguazio.py b/mlrun/api/utils/clients/iguazio.py index 754d530bffa9..5bd931de7ff7 100644 --- a/mlrun/api/utils/clients/iguazio.py +++ b/mlrun/api/utils/clients/iguazio.py @@ -24,11 +24,12 @@ import aiohttp import fastapi +import igz_mgmt.schemas.manual_events import requests.adapters from fastapi.concurrency import run_in_threadpool -import mlrun.api.schemas import mlrun.api.utils.projects.remotes.leader +import mlrun.common.schemas import mlrun.errors import mlrun.utils.helpers import mlrun.utils.singleton @@ -79,7 +80,7 @@ def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._session = mlrun.utils.HTTPSessionWithRetry( retry_on_exception=mlrun.mlconf.httpdb.projects.retry_leader_request_on_exception - == mlrun.api.schemas.HTTPSessionRetryMode.enabled.value, + == mlrun.common.schemas.HTTPSessionRetryMode.enabled.value, verbose=True, ) self._api_url = mlrun.mlconf.iguazio_api_url @@ -89,13 +90,15 @@ def __init__(self, *args, **kwargs) -> None: [[1, 10], [5, None]] ) self._wait_for_project_terminal_state_retry_interval = 5 + self._logger = logger.get_child("iguazio-client") + self._igz_clients = {} def try_get_grafana_service_url(self, session: str) -> typing.Optional[str]: """ Try to find a ready grafana app service, and return its URL If nothing found, returns None """ - logger.debug("Getting grafana service url from Iguazio") + self._logger.debug("Getting grafana service url from Iguazio") response = 
self._send_request_to_api( "GET", "app_services_manifests", @@ -123,7 +126,7 @@ def try_get_grafana_service_url(self, session: str) -> typing.Optional[str]: def verify_request_session( self, request: fastapi.Request - ) -> mlrun.api.schemas.AuthInfo: + ) -> mlrun.common.schemas.AuthInfo: """ Proxy the request to one of the session verification endpoints (which will verify the session of the request) """ @@ -140,7 +143,7 @@ def verify_request_session( response.headers, response.json() ) - def verify_session(self, session: str) -> mlrun.api.schemas.AuthInfo: + def verify_session(self, session: str) -> mlrun.common.schemas.AuthInfo: response = self._send_request_to_api( "POST", mlrun.mlconf.httpdb.authentication.iguazio.session_verification_endpoint, @@ -183,16 +186,16 @@ def get_or_create_access_key( json=body, ) if response.status_code == http.HTTPStatus.CREATED.value: - logger.debug("Created access key in Iguazio", planes=planes) + self._logger.debug("Created access key in Iguazio", planes=planes) return response.json()["data"]["id"] def create_project( self, session: str, - project: mlrun.api.schemas.Project, + project: mlrun.common.schemas.Project, wait_for_completion: bool = True, ) -> bool: - logger.debug("Creating project in Iguazio", project=project) + self._logger.debug("Creating project in Iguazio", project=project) body = self._transform_mlrun_project_to_iguazio_project(project) return self._create_project_in_iguazio( session, project.metadata.name, body, wait_for_completion @@ -202,9 +205,9 @@ def update_project( self, session: str, name: str, - project: mlrun.api.schemas.Project, + project: mlrun.common.schemas.Project, ): - logger.debug("Updating project in Iguazio", name=name, project=project) + self._logger.debug("Updating project in Iguazio", name=name) body = self._transform_mlrun_project_to_iguazio_project(project) self._put_project_to_iguazio(session, name, body) @@ -212,17 +215,17 @@ def delete_project( self, session: str, name: str, - 
deletion_strategy: mlrun.api.schemas.DeletionStrategy = mlrun.api.schemas.DeletionStrategy.default(), + deletion_strategy: mlrun.common.schemas.DeletionStrategy = mlrun.common.schemas.DeletionStrategy.default(), wait_for_completion: bool = True, ) -> bool: - logger.debug( + self._logger.debug( "Deleting project in Iguazio", name=name, deletion_strategy=deletion_strategy, ) body = self._transform_mlrun_project_to_iguazio_project( - mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=name) + mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=name) ) ) headers = { @@ -240,7 +243,7 @@ def delete_project( except requests.HTTPError as exc: if exc.response.status_code != http.HTTPStatus.NOT_FOUND.value: raise - logger.debug( + self._logger.debug( "Project not found in Iguazio. Considering deletion as successful", name=name, deletion_strategy=deletion_strategy, @@ -249,7 +252,7 @@ def delete_project( else: if wait_for_completion: job_id = response.json()["data"]["id"] - logger.debug( + self._logger.debug( "Waiting for project deletion job in Iguazio", name=name, job_id=job_id, @@ -266,48 +269,25 @@ def list_projects( updated_after: typing.Optional[datetime.datetime] = None, page_size: typing.Optional[int] = None, ) -> typing.Tuple[ - typing.List[mlrun.api.schemas.Project], typing.Optional[datetime.datetime] + typing.List[mlrun.common.schemas.Project], typing.Optional[datetime.datetime] ]: - params = {} - if updated_after is not None: - time_string = updated_after.isoformat().split("+")[0] - params = {"filter[updated_at]": f"[$gt]{time_string}Z"} - if page_size is None: - page_size = ( - mlrun.mlconf.httpdb.projects.iguazio_list_projects_default_page_size - ) - if page_size is not None: - params["page[size]"] = int(page_size) - - params["include"] = "owner" - response = self._send_request_to_api( - "GET", - "projects", - "Failed listing projects from Iguazio", - session, - params=params, + project_names, 
latest_updated_at = self._list_project_names( + session, updated_after, page_size ) - response_body = response.json() - projects = [] - for iguazio_project in response_body["data"]: - projects.append( - self._transform_iguazio_project_to_mlrun_project(iguazio_project) - ) - latest_updated_at = self._find_latest_updated_at(response_body) - return projects, latest_updated_at + return self._list_projects_data(session, project_names), latest_updated_at def get_project( self, session: str, name: str, - ) -> mlrun.api.schemas.Project: + ) -> mlrun.common.schemas.Project: return self._get_project_from_iguazio(session, name) def get_project_owner( self, session: str, name: str, - ) -> mlrun.api.schemas.ProjectOwner: + ) -> mlrun.common.schemas.ProjectOwner: response = self._get_project_from_iguazio_without_parsing( session, name, enrich_owner_access_key=True ) @@ -323,15 +303,15 @@ def get_project_owner( f"Unable to enrich project owner for project {name}," f" because project has no owner configured" ) - return mlrun.api.schemas.ProjectOwner( + return mlrun.common.schemas.ProjectOwner( username=owner_username, access_key=owner_access_key, ) def format_as_leader_project( - self, project: mlrun.api.schemas.Project - ) -> mlrun.api.schemas.IguazioProject: - return mlrun.api.schemas.IguazioProject( + self, project: mlrun.common.schemas.Project + ) -> mlrun.common.schemas.IguazioProject: + return mlrun.common.schemas.IguazioProject( data=self._transform_mlrun_project_to_iguazio_project(project)["data"] ) @@ -342,6 +322,65 @@ def is_sync(self): """ return True + def emit_manual_event( + self, access_key: str, event: igz_mgmt.schemas.manual_events.ManualEventSchema + ): + """ + Emit a manual event to Iguazio + """ + client = self._get_igz_client(access_key) + igz_mgmt.ManualEvents.emit( + http_client=client, event=event, audit_tenant_id=client.tenant_id + ) + + def _get_igz_client(self, access_key: str) -> igz_mgmt.Client: + if not self._igz_clients.get(access_key): + 
self._igz_clients[access_key] = igz_mgmt.Client( + endpoint=self._api_url, + access_key=access_key, + ) + return self._igz_clients[access_key] + + def _list_project_names( + self, + session: str, + updated_after: typing.Optional[datetime.datetime] = None, + page_size: typing.Optional[int] = None, + ) -> typing.Tuple[typing.List[str], typing.Optional[datetime.datetime]]: + params = {} + if updated_after is not None: + time_string = updated_after.isoformat().split("+")[0] + params = {"filter[updated_at]": f"[$gt]{time_string}Z"} + if page_size is None: + page_size = ( + mlrun.mlconf.httpdb.projects.iguazio_list_projects_default_page_size + ) + if page_size is not None: + params["page[size]"] = int(page_size) + + response = self._send_request_to_api( + "GET", + "projects", + "Failed listing projects from Iguazio", + session, + params=params, + ) + response_body = response.json() + project_names = [ + iguazio_project["attributes"]["name"] + for iguazio_project in response_body["data"] + ] + latest_updated_at = self._find_latest_updated_at(response_body) + return project_names, latest_updated_at + + def _list_projects_data( + self, session: str, project_names: typing.List[str] + ) -> typing.List[mlrun.common.schemas.Project]: + return [ + self._get_project_from_iguazio(session, project_name) + for project_name in project_names + ] + def _find_latest_updated_at( self, response_body: dict ) -> typing.Optional[datetime.datetime]: @@ -359,7 +398,7 @@ def _create_project_in_iguazio( ) -> bool: _, job_id = self._post_project_to_iguazio(session, body) if wait_for_completion: - logger.debug( + self._logger.debug( "Waiting for project creation job in Iguazio", name=name, job_id=job_id, @@ -367,12 +406,17 @@ def _create_project_in_iguazio( self._wait_for_job_completion( session, job_id, "Project creation job failed" ) + self._logger.debug( + "Successfully created project in Iguazio", + name=name, + job_id=job_id, + ) return False return True def _post_project_to_iguazio( self, 
session: str, body: dict - ) -> typing.Tuple[mlrun.api.schemas.Project, str]: + ) -> typing.Tuple[mlrun.common.schemas.Project, str]: response = self._send_request_to_api( "POST", "projects", "Failed creating project in Iguazio", session, json=body ) @@ -384,7 +428,7 @@ def _post_project_to_iguazio( def _put_project_to_iguazio( self, session: str, name: str, body: dict - ) -> mlrun.api.schemas.Project: + ) -> mlrun.common.schemas.Project: response = self._send_request_to_api( "PUT", f"projects/__name__/{name}", @@ -410,7 +454,7 @@ def _get_project_from_iguazio_without_parsing( def _get_project_from_iguazio( self, session: str, name: str, include_owner_session: bool = False - ) -> mlrun.api.schemas.Project: + ) -> mlrun.common.schemas.Project: response = self._get_project_from_iguazio_without_parsing(session, name) return self._transform_iguazio_project_to_mlrun_project(response.json()["data"]) @@ -428,7 +472,7 @@ def _verify_job_in_terminal_state(): job_state, job_result = mlrun.utils.helpers.retry_until_successful( self._wait_for_job_completion_retry_interval, 360, - logger, + self._logger, False, _verify_job_in_terminal_state, ) @@ -445,6 +489,7 @@ def _verify_job_in_terminal_state(): if not status_code: raise mlrun.errors.MLRunRuntimeError(error_message) raise mlrun.errors.raise_for_status_code(status_code, error_message) + self._logger.debug("Job completed successfully", job_id=job_id) def _send_request_to_api( self, method, path, error_message: str, session=None, **kwargs @@ -466,7 +511,7 @@ def _generate_auth_info_from_session_verification_response( self, response_headers: typing.Mapping[str, typing.Any], response_body: typing.Mapping[typing.Any, typing.Any], - ) -> mlrun.api.schemas.AuthInfo: + ) -> mlrun.common.schemas.AuthInfo: ( username, @@ -487,7 +532,7 @@ def _generate_auth_info_from_session_verification_response( user_id = user_id_from_body or user_id group_ids = group_ids_from_body or group_ids - auth_info = mlrun.api.schemas.AuthInfo( + auth_info = 
mlrun.common.schemas.AuthInfo( username=username, session=session, user_id=user_id, @@ -546,7 +591,7 @@ def _resolve_params_from_response_body( @staticmethod def _transform_mlrun_project_to_iguazio_project( - project: mlrun.api.schemas.Project, + project: mlrun.common.schemas.Project, ) -> dict: body = { "data": { @@ -583,7 +628,7 @@ def _transform_mlrun_project_to_iguazio_project( @staticmethod def _transform_mlrun_project_to_iguazio_mlrun_project_attribute( - project: mlrun.api.schemas.Project, + project: mlrun.common.schemas.Project, ): project_dict = project.dict( exclude_unset=True, @@ -617,7 +662,7 @@ def _transform_iguazio_labels_to_mlrun_labels( @staticmethod def _transform_iguazio_project_to_mlrun_project( iguazio_project, - ) -> mlrun.api.schemas.Project: + ) -> mlrun.common.schemas.Project: mlrun_project_without_common_fields = json.loads( iguazio_project["attributes"].get("mlrun_project", "{}") ) @@ -625,14 +670,16 @@ def _transform_iguazio_project_to_mlrun_project( mlrun_project_without_common_fields.setdefault("metadata", {})[ "name" ] = iguazio_project["attributes"]["name"] - mlrun_project = mlrun.api.schemas.Project(**mlrun_project_without_common_fields) + mlrun_project = mlrun.common.schemas.Project( + **mlrun_project_without_common_fields + ) mlrun_project.metadata.created = datetime.datetime.fromisoformat( iguazio_project["attributes"]["created_at"] ) - mlrun_project.spec.desired_state = mlrun.api.schemas.ProjectDesiredState( + mlrun_project.spec.desired_state = mlrun.common.schemas.ProjectDesiredState( iguazio_project["attributes"]["admin_status"] ) - mlrun_project.status.state = mlrun.api.schemas.ProjectState( + mlrun_project.status.state = mlrun.common.schemas.ProjectState( iguazio_project["attributes"]["operational_status"] ) if iguazio_project["attributes"].get("description"): @@ -677,11 +724,11 @@ def _prepare_request_kwargs(self, session, path, *, kwargs): if kwargs.get("timeout") is None: kwargs["timeout"] = 20 if "projects" in path: - if 
mlrun.api.schemas.HeaderNames.projects_role not in kwargs.get( + if mlrun.common.schemas.HeaderNames.projects_role not in kwargs.get( "headers", {} ): kwargs.setdefault("headers", {})[ - mlrun.api.schemas.HeaderNames.projects_role + mlrun.common.schemas.HeaderNames.projects_role ] = "mlrun" # requests no longer supports header values to be enum (https://github.com/psf/requests/pull/6154) @@ -708,7 +755,7 @@ def _handle_error_response( if errors or ctx: log_kwargs.update({"ctx": ctx, "errors": errors}) - logger.warning("Request to iguazio failed", **log_kwargs) + self._logger.warning("Request to iguazio failed", **log_kwargs) mlrun.errors.raise_for_status(response, error_message) @@ -755,7 +802,7 @@ def wrapper(*args, **kwargs): async def verify_request_session( self, request: fastapi.Request - ) -> mlrun.api.schemas.AuthInfo: + ) -> mlrun.common.schemas.AuthInfo: """ Proxy the request to one of the session verification endpoints (which will verify the session of the request) """ @@ -772,7 +819,7 @@ async def verify_request_session( response.headers, await response.json() ) - async def verify_session(self, session: str) -> mlrun.api.schemas.AuthInfo: + async def verify_session(self, session: str) -> mlrun.common.schemas.AuthInfo: async with self._send_request_to_api_async( "POST", mlrun.mlconf.httpdb.authentication.iguazio.session_verification_endpoint, @@ -812,6 +859,6 @@ async def _ensure_async_session(self): if not self._async_session: self._async_session = mlrun.utils.AsyncClientWithRetry( retry_on_exception=mlrun.mlconf.httpdb.projects.retry_leader_request_on_exception - == mlrun.api.schemas.HTTPSessionRetryMode.enabled.value, + == mlrun.common.schemas.HTTPSessionRetryMode.enabled.value, logger=logger, ) diff --git a/mlrun/api/utils/clients/log_collector.py b/mlrun/api/utils/clients/log_collector.py index 764376def26a..8c6576c19f9d 100644 --- a/mlrun/api/utils/clients/log_collector.py +++ b/mlrun/api/utils/clients/log_collector.py @@ -14,6 +14,7 @@ import 
asyncio import enum import http +import re import typing import mlrun.api.utils.clients.protocols.grpc @@ -52,6 +53,20 @@ def map_error_code_to_mlrun_error( return mlrun_error_class(message) +class LogCollectorErrorRegex: + # when multiple routines in the log collector service try to search the same directory, + # one of them can fail with this error + readdirent_resource_temporarily_unavailable = ( + "readdirent.*resource temporarily unavailable" + ) + + @classmethod + def has_logs_retryable_errors(cls): + return [ + cls.readdirent_resource_temporarily_unavailable, + ] + + class LogCollectorClient( mlrun.api.utils.clients.protocols.grpc.BaseGRPCClient, metaclass=mlrun.utils.singleton.Singleton, @@ -137,6 +152,11 @@ async def get_logs( try: has_logs = await self.has_logs(run_uid, project, verbose, raise_on_error) if not has_logs: + logger.debug( + "Run has no logs to collect", + run_uid=run_uid, + project=project, + ) # run has no logs - return empty logs and exit so caller won't wait for logs or retry yield b"" @@ -209,6 +229,18 @@ async def has_logs( response = await self._call("HasLogs", request) if not response.success: + if self._retryable_error( + response.errorMessage, + LogCollectorErrorRegex.has_logs_retryable_errors(), + ): + if verbose: + logger.warning( + "Failed to check if run has logs to collect, retrying", + run_uid=run_uid, + error=response.errorMessage, + ) + return False + msg = f"Failed to check if run has logs to collect for {run_uid}" if verbose: logger.warning(msg, error=response.errorMessage) @@ -233,7 +265,6 @@ async def stop_logs( :param raise_on_error: Whether to raise an exception on error :return: None """ - request = self._log_collector_pb2.StopLogsRequest( project=project, runUIDs=run_uids ) @@ -277,3 +308,14 @@ async def delete_logs( ) if verbose: logger.warning(msg, error=response.errorMessage) + + def _retryable_error(self, error_message, retryable_error_patterns) -> bool: + """ + Check if the error is retryable + :param 
error_message: The error message + :param retryable_error_patterns: The retryable error regex patterns + :return: Whether the error is retryable + """ + if any(re.match(regex, error_message) for regex in retryable_error_patterns): + return True + return False diff --git a/mlrun/api/utils/clients/nuclio.py b/mlrun/api/utils/clients/nuclio.py index f71f63821ee9..1e43ff14aebe 100644 --- a/mlrun/api/utils/clients/nuclio.py +++ b/mlrun/api/utils/clients/nuclio.py @@ -20,8 +20,8 @@ import requests.adapters import sqlalchemy.orm -import mlrun.api.schemas import mlrun.api.utils.projects.remotes.follower +import mlrun.common.schemas import mlrun.errors import mlrun.utils.singleton from mlrun.utils import logger @@ -37,7 +37,7 @@ def __init__(self) -> None: self._api_url = mlrun.config.config.nuclio_dashboard_url def create_project( - self, session: sqlalchemy.orm.Session, project: mlrun.api.schemas.Project + self, session: sqlalchemy.orm.Session, project: mlrun.common.schemas.Project ): logger.debug("Creating project in Nuclio", project=project) body = self._generate_request_body(project) @@ -47,7 +47,7 @@ def store_project( self, session: sqlalchemy.orm.Session, name: str, - project: mlrun.api.schemas.Project, + project: mlrun.common.schemas.Project, ): logger.debug("Storing project in Nuclio", name=name, project=project) body = self._generate_request_body(project) @@ -65,7 +65,7 @@ def patch_project( session: sqlalchemy.orm.Session, name: str, project: dict, - patch_mode: mlrun.api.schemas.PatchMode = mlrun.api.schemas.PatchMode.replace, + patch_mode: mlrun.common.schemas.PatchMode = mlrun.common.schemas.PatchMode.replace, ): logger.debug( "Patching project in Nuclio", @@ -93,14 +93,14 @@ def delete_project( self, session: sqlalchemy.orm.Session, name: str, - deletion_strategy: mlrun.api.schemas.DeletionStrategy = mlrun.api.schemas.DeletionStrategy.default(), + deletion_strategy: mlrun.common.schemas.DeletionStrategy = mlrun.common.schemas.DeletionStrategy.default(), ): 
logger.debug( "Deleting project in Nuclio", name=name, deletion_strategy=deletion_strategy ) body = self._generate_request_body( - mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=name) + mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=name) ) ) headers = { @@ -119,7 +119,7 @@ def delete_project( def get_project( self, session: sqlalchemy.orm.Session, name: str - ) -> mlrun.api.schemas.Project: + ) -> mlrun.common.schemas.Project: response = self._get_project_from_nuclio(name) response_body = response.json() return self._transform_nuclio_project_to_schema(response_body) @@ -128,11 +128,11 @@ def list_projects( self, session: sqlalchemy.orm.Session, owner: str = None, - format_: mlrun.api.schemas.ProjectsFormat = mlrun.api.schemas.ProjectsFormat.full, + format_: mlrun.common.schemas.ProjectsFormat = mlrun.common.schemas.ProjectsFormat.full, labels: typing.List[str] = None, - state: mlrun.api.schemas.ProjectState = None, + state: mlrun.common.schemas.ProjectState = None, names: typing.Optional[typing.List[str]] = None, - ) -> mlrun.api.schemas.ProjectsOutput: + ) -> mlrun.common.schemas.ProjectsOutput: if owner: raise NotImplementedError( "Listing nuclio projects by owner is currently not supported" @@ -154,10 +154,10 @@ def list_projects( projects = [] for nuclio_project in response_body.values(): projects.append(self._transform_nuclio_project_to_schema(nuclio_project)) - if format_ == mlrun.api.schemas.ProjectsFormat.full: - return mlrun.api.schemas.ProjectsOutput(projects=projects) - elif format_ == mlrun.api.schemas.ProjectsFormat.name_only: - return mlrun.api.schemas.ProjectsOutput( + if format_ == mlrun.common.schemas.ProjectsFormat.full: + return mlrun.common.schemas.ProjectsOutput(projects=projects) + elif format_ == mlrun.common.schemas.ProjectsFormat.name_only: + return mlrun.common.schemas.ProjectsOutput( projects=[project.metadata.name for project in projects] ) else: @@ -170,14 +170,14 @@ def 
list_project_summaries( session: sqlalchemy.orm.Session, owner: str = None, labels: typing.List[str] = None, - state: mlrun.api.schemas.ProjectState = None, + state: mlrun.common.schemas.ProjectState = None, names: typing.Optional[typing.List[str]] = None, - ) -> mlrun.api.schemas.ProjectSummariesOutput: + ) -> mlrun.common.schemas.ProjectSummariesOutput: raise NotImplementedError("Listing project summaries is not supported") def get_project_summary( self, session: sqlalchemy.orm.Session, name: str - ) -> mlrun.api.schemas.ProjectSummary: + ) -> mlrun.common.schemas.ProjectSummary: raise NotImplementedError("Get project summary is not supported") def get_dashboard_version(self) -> str: @@ -226,7 +226,7 @@ def _send_request_to_api(self, method, path, **kwargs): return response @staticmethod - def _generate_request_body(project: mlrun.api.schemas.Project): + def _generate_request_body(project: mlrun.common.schemas.Project): body = { "metadata": {"name": project.metadata.name}, } @@ -240,13 +240,13 @@ def _generate_request_body(project: mlrun.api.schemas.Project): @staticmethod def _transform_nuclio_project_to_schema(nuclio_project): - return mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata( + return mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata( name=nuclio_project["metadata"]["name"], labels=nuclio_project["metadata"].get("labels"), annotations=nuclio_project["metadata"].get("annotations"), ), - spec=mlrun.api.schemas.ProjectSpec( + spec=mlrun.common.schemas.ProjectSpec( description=nuclio_project["spec"].get("description") ), ) diff --git a/mlrun/api/utils/clients/protocols/grpc.py b/mlrun/api/utils/clients/protocols/grpc.py index 311100475acc..5cd82e267b1e 100644 --- a/mlrun/api/utils/clients/protocols/grpc.py +++ b/mlrun/api/utils/clients/protocols/grpc.py @@ -15,7 +15,7 @@ import google.protobuf.reflection import grpc -import mlrun.api.schemas +import mlrun.common.schemas import mlrun.config import 
mlrun.errors diff --git a/mlrun/api/utils/db/mysql.py b/mlrun/api/utils/db/mysql.py index c4c80961174d..31a69b53e621 100644 --- a/mlrun/api/utils/db/mysql.py +++ b/mlrun/api/utils/db/mysql.py @@ -32,29 +32,16 @@ class MySQLUtil(object): "functions", ] - def __init__(self): - mysql_dsn_data = self.get_mysql_dsn_data() - if not mysql_dsn_data: - raise RuntimeError(f"Invalid mysql dsn: {self.get_dsn()}") - - @staticmethod - def wait_for_db_liveness(logger, retry_interval=3, timeout=2 * 60): - logger.debug("Waiting for database liveness") - mysql_dsn_data = MySQLUtil.get_mysql_dsn_data() - if not mysql_dsn_data: - dsn = MySQLUtil.get_dsn() - if "sqlite" in dsn: - logger.debug("SQLite DB is used, liveness check not needed") - else: - logger.warn( - f"Invalid mysql dsn: {MySQLUtil.get_dsn()}, assuming live and skipping liveness verification" - ) - return + def __init__(self, logger: mlrun.utils.Logger): + self._logger = logger + def wait_for_db_liveness(self, retry_interval=3, timeout=2 * 60): + self._logger.debug("Waiting for database liveness") + mysql_dsn_data = self.get_mysql_dsn_data() tmp_connection = mlrun.utils.retry_until_successful( retry_interval, timeout, - logger, + self._logger, True, pymysql.connect, host=mysql_dsn_data["host"], @@ -62,7 +49,7 @@ def wait_for_db_liveness(logger, retry_interval=3, timeout=2 * 60): port=int(mysql_dsn_data["port"]), database=mysql_dsn_data["database"], ) - logger.debug("Database ready for connection") + self._logger.debug("Database ready for connection") tmp_connection.close() def check_db_has_tables(self): @@ -78,6 +65,18 @@ def check_db_has_tables(self): finally: connection.close() + def set_modes(self, modes): + if not modes or modes in ["nil", "none"]: + self._logger.debug("No sql modes were given, bailing", modes=modes) + return + connection = self._create_connection() + try: + self._logger.debug("Setting sql modes", modes=modes) + with connection.cursor() as cursor: + cursor.execute("SET GLOBAL sql_mode=%s;", (modes,)) 
+ finally: + connection.close() + def check_db_has_data(self): connection = self._create_connection() try: @@ -101,10 +100,6 @@ def _create_connection(self): database=mysql_dsn_data["database"], ) - @staticmethod - def get_dsn() -> str: - return os.environ.get(MySQLUtil.dsn_env_var, "") - @staticmethod def get_mysql_dsn_data() -> typing.Optional[dict]: match = re.match(MySQLUtil.dsn_regex, MySQLUtil.get_dsn()) @@ -112,3 +107,7 @@ def get_mysql_dsn_data() -> typing.Optional[dict]: return None return match.groupdict() + + @staticmethod + def get_dsn() -> str: + return os.environ.get(MySQLUtil.dsn_env_var, "") diff --git a/mlrun/api/utils/db/sqlite_migration.py b/mlrun/api/utils/db/sqlite_migration.py index c5030798e84d..16492ff635f4 100644 --- a/mlrun/api/utils/db/sqlite_migration.py +++ b/mlrun/api/utils/db/sqlite_migration.py @@ -64,7 +64,7 @@ def __init__(self): self._migrator = self._create_migrator() self._mysql_util = None if self._mysql_dsn_data: - self._mysql_util = MySQLUtil() + self._mysql_util = MySQLUtil(logger) def is_database_migration_needed(self) -> bool: # if some data is missing, don't transfer the data diff --git a/mlrun/api/utils/events/__init__.py b/mlrun/api/utils/events/__init__.py new file mode 100644 index 000000000000..33c5b3d3bd7c --- /dev/null +++ b/mlrun/api/utils/events/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# diff --git a/mlrun/api/utils/events/base.py b/mlrun/api/utils/events/base.py new file mode 100644 index 000000000000..a8ce25ade56a --- /dev/null +++ b/mlrun/api/utils/events/base.py @@ -0,0 +1,85 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import abc +import typing + +import mlrun.common.schemas + + +class BaseEventClient: + @abc.abstractmethod + def emit(self, event): + pass + + def generate_project_auth_secret_event( + self, + username: str, + secret_name: str, + action: mlrun.common.schemas.AuthSecretEventActions, + ): + """ + Generate a project auth secret event + :param username: username + :param secret_name: secret name + :param action: preformed action + :return: event object to emit + """ + pass + + @abc.abstractmethod + def generate_project_auth_secret_created_event( + self, username: str, secret_name: str + ): + pass + + @abc.abstractmethod + def generate_project_auth_secret_updated_event( + self, username: str, secret_name: str + ): + pass + + @abc.abstractmethod + def generate_project_secret_event( + self, + project: str, + secret_name: str, + secret_keys: typing.List[str] = None, + action: mlrun.common.schemas.SecretEventActions = mlrun.common.schemas.SecretEventActions.created, + ): + """ + Generate a project secret event + :param project: project name + :param secret_name: secret name + :param secret_keys: secret keys, optional, only relevant for created/updated events + :param action: preformed action + :return: 
event object to emit + """ + pass + + @abc.abstractmethod + def generate_project_secret_created_event( + self, project: str, secret_name: str, secret_keys: typing.List[str] + ): + pass + + @abc.abstractmethod + def generate_project_secret_updated_event( + self, project: str, secret_name: str, secret_keys: typing.List[str] + ): + pass + + @abc.abstractmethod + def generate_project_secret_deleted_event(self, project: str, secret_name: str): + pass diff --git a/mlrun/api/utils/events/events_factory.py b/mlrun/api/utils/events/events_factory.py new file mode 100644 index 000000000000..a48437401f2c --- /dev/null +++ b/mlrun/api/utils/events/events_factory.py @@ -0,0 +1,41 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import mlrun.api.utils.events.base +import mlrun.api.utils.events.iguazio +import mlrun.api.utils.events.nop +import mlrun.common.schemas +import mlrun.utils.singleton + + +class EventsFactory(object): + @staticmethod + def get_events_client( + kind: mlrun.common.schemas.EventClientKinds = None, **kwargs + ) -> mlrun.api.utils.events.base.BaseEventClient: + if mlrun.mlconf.events.mode == mlrun.common.schemas.EventsModes.disabled: + return mlrun.api.utils.events.nop.NopClient() + + if not kind: + if mlrun.mlconf.get_parsed_igz_version(): + kind = mlrun.common.schemas.EventClientKinds.iguazio + + if kind == mlrun.common.schemas.EventClientKinds.iguazio: + if not mlrun.mlconf.get_parsed_igz_version(): + raise mlrun.errors.MLRunInvalidArgumentError( + "Iguazio events client can only be used in Iguazio environment" + ) + return mlrun.api.utils.events.iguazio.Client(**kwargs) + + return mlrun.api.utils.events.nop.NopClient() diff --git a/mlrun/api/utils/events/iguazio.py b/mlrun/api/utils/events/iguazio.py new file mode 100644 index 000000000000..cdf3dd080866 --- /dev/null +++ b/mlrun/api/utils/events/iguazio.py @@ -0,0 +1,179 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import typing + +import igz_mgmt.schemas.manual_events + +import mlrun.api.utils.clients.iguazio +import mlrun.api.utils.events.base +import mlrun.common.schemas +from mlrun.utils import logger + +PROJECT_AUTH_SECRET_CREATED = "Security.Project.AuthSecret.Created" +PROJECT_AUTH_SECRET_UPDATED = "Security.Project.AuthSecret.Updated" +PROJECT_SECRET_CREATED = "Security.Project.Secret.Created" +PROJECT_SECRET_UPDATED = "Security.Project.Secret.Updated" +PROJECT_SECRET_DELETED = "Security.Project.Secret.Deleted" + + +class Client(mlrun.api.utils.events.base.BaseEventClient): + def __init__(self, access_key: str = None, verbose: bool = None): + self.access_key = ( + access_key + or mlrun.mlconf.events.access_key + or mlrun.mlconf.get_v3io_access_key() + ) + self.verbose = verbose if verbose is not None else mlrun.mlconf.events.verbose + self.source = "mlrun-api" + + def emit(self, event: igz_mgmt.schemas.manual_events.ManualEventSchema): + try: + logger.debug("Emitting event", event=event) + mlrun.api.utils.clients.iguazio.Client().emit_manual_event( + self.access_key, event + ) + except Exception as exc: + if self.verbose: + logger.warning( + "Failed to emit event", + event=event, + exc_info=exc, + ) + + def generate_project_auth_secret_event( + self, + username: str, + secret_name: str, + action: mlrun.common.schemas.AuthSecretEventActions, + ) -> igz_mgmt.schemas.manual_events.ManualEventSchema: + """ + Generate a project auth secret event + :param username: username + :param secret_name: secret name + :param action: preformed action + :return: event object to emit + """ + if action == mlrun.common.schemas.SecretEventActions.created: + return self.generate_project_auth_secret_created_event( + username, secret_name + ) + elif action == mlrun.common.schemas.SecretEventActions.updated: + return self.generate_project_auth_secret_updated_event( + username, secret_name + ) + else: + raise mlrun.errors.MLRunInvalidArgumentError(f"Unsupported action {action}") + + def 
generate_project_auth_secret_created_event( + self, username: str, secret_name: str + ) -> igz_mgmt.schemas.manual_events.ManualEventSchema: + return igz_mgmt.schemas.manual_events.ManualEventSchema( + source=self.source, + kind=PROJECT_AUTH_SECRET_CREATED, + description=f"User {username} created secret {secret_name}", + severity=igz_mgmt.constants.EventSeverity.info, + classification=igz_mgmt.constants.EventClassification.security, + system_event=False, + visibility=igz_mgmt.constants.EventVisibility.external, + ) + + def generate_project_auth_secret_updated_event( + self, username: str, secret_name: str + ) -> igz_mgmt.schemas.manual_events.ManualEventSchema: + return igz_mgmt.schemas.manual_events.ManualEventSchema( + source=self.source, + kind=PROJECT_AUTH_SECRET_UPDATED, + description=f"User {username} updated secret {secret_name}", + severity=igz_mgmt.constants.EventSeverity.info, + classification=igz_mgmt.constants.EventClassification.security, + system_event=False, + visibility=igz_mgmt.constants.EventVisibility.external, + ) + + def generate_project_secret_event( + self, + project: str, + secret_name: str, + secret_keys: typing.List[str] = None, + action: mlrun.common.schemas.SecretEventActions = mlrun.common.schemas.SecretEventActions.created, + ) -> igz_mgmt.schemas.manual_events.ManualEventSchema: + """ + Generate a project secret event + :param project: project name + :param secret_name: secret name + :param secret_keys: secret keys, optional, only relevant for created/updated events + :param action: preformed action + :return: event object to emit + """ + if action == mlrun.common.schemas.SecretEventActions.created: + return self.generate_project_secret_created_event( + project, secret_name, secret_keys + ) + elif action == mlrun.common.schemas.SecretEventActions.updated: + return self.generate_project_secret_updated_event( + project, secret_name, secret_keys + ) + elif action == mlrun.common.schemas.SecretEventActions.deleted: + return 
self.generate_project_secret_deleted_event(project, secret_name) + else: + raise mlrun.errors.MLRunInvalidArgumentError(f"Unsupported action {action}") + + def generate_project_secret_created_event( + self, project: str, secret_name: str, secret_keys: typing.List[str] + ) -> igz_mgmt.schemas.manual_events.ManualEventSchema: + normalized_secret_keys = self._list_to_string(secret_keys) + return igz_mgmt.schemas.manual_events.ManualEventSchema( + source=self.source, + kind=PROJECT_SECRET_CREATED, + description=f"Created project secret {secret_name} with secret keys {normalized_secret_keys}" + f" for project {project}", + severity=igz_mgmt.constants.EventSeverity.info, + classification=igz_mgmt.constants.EventClassification.security, + system_event=False, + visibility=igz_mgmt.constants.EventVisibility.external, + ) + + def generate_project_secret_updated_event( + self, + project: str, + secret_name: str, + secret_keys: typing.List[str], + ) -> igz_mgmt.schemas.manual_events.ManualEventSchema: + normalized_secret_keys = self._list_to_string(secret_keys) + return igz_mgmt.schemas.manual_events.ManualEventSchema( + source=self.source, + kind=PROJECT_SECRET_UPDATED, + description=f"Updated secret keys {normalized_secret_keys} of project secret {secret_name} " + f"for project {project}", + severity=igz_mgmt.constants.EventSeverity.info, + classification=igz_mgmt.constants.EventClassification.security, + system_event=False, + visibility=igz_mgmt.constants.EventVisibility.external, + ) + + def generate_project_secret_deleted_event(self, project: str, secret_name: str): + return igz_mgmt.schemas.manual_events.ManualEventSchema( + source=self.source, + kind=PROJECT_SECRET_DELETED, + description=f"Deleted project secret {secret_name} for project {project}", + severity=igz_mgmt.constants.EventSeverity.info, + classification=igz_mgmt.constants.EventClassification.security, + system_event=False, + visibility=igz_mgmt.constants.EventVisibility.external, + ) + + @staticmethod + def 
_list_to_string(list_to_convert: typing.List[str]) -> str: + return ", ".join(list_to_convert) diff --git a/mlrun/api/utils/events/nop.py b/mlrun/api/utils/events/nop.py new file mode 100644 index 000000000000..181583628ea1 --- /dev/null +++ b/mlrun/api/utils/events/nop.py @@ -0,0 +1,77 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import typing + +import mlrun.api.utils.events.base +import mlrun.common.schemas + + +class NopClient(mlrun.api.utils.events.base.BaseEventClient): + def emit(self, event): + return + + def generate_project_auth_secret_event( + self, + username: str, + secret_name: str, + action: mlrun.common.schemas.AuthSecretEventActions, + ): + """ + Generate a project auth secret event + :param username: username + :param secret_name: secret name + :param action: preformed action + :return: event object to emit + """ + return + + def generate_project_auth_secret_created_event( + self, username: str, secret_name: str + ): + return + + def generate_project_auth_secret_updated_event( + self, username: str, secret_name: str + ): + return + + def generate_project_secret_event( + self, + project: str, + secret_name: str, + secret_keys: typing.List[str] = None, + action: mlrun.common.schemas.SecretEventActions = mlrun.common.schemas.SecretEventActions.created, + ): + """ + Generate a project secret event + :param project: project name + :param secret_name: secret name + :param secret_keys: secret keys, optional, only relevant 
for created/updated events + :param action: preformed action + :return: event object to emit + """ + + def generate_project_secret_created_event( + self, project: str, secret_name: str, secret_keys: typing.List[str] + ): + return + + def generate_project_secret_updated_event( + self, project: str, secret_name: str, secret_keys: typing.List[str] + ): + return + + def generate_project_secret_deleted_event(self, project: str, secret_name: str): + return diff --git a/mlrun/api/utils/helpers.py b/mlrun/api/utils/helpers.py index 43dcdc0c8acd..fd31976ae5e8 100644 --- a/mlrun/api/utils/helpers.py +++ b/mlrun/api/utils/helpers.py @@ -13,22 +13,12 @@ # limitations under the License. # import asyncio -import enum import mlrun -import mlrun.api.schemas +import mlrun.common.schemas from mlrun.utils import logger -# TODO: From python 3.11 StrEnum is built-in and this will not be needed -class StrEnum(str, enum.Enum): - def __str__(self): - return self.value - - def __repr__(self): - return self.value - - def ensure_running_on_chief(function): """ The motivation of this function is to catch development bugs in which we are accidentally using functions / flows @@ -41,7 +31,7 @@ def ensure_running_on_chief(function): def _ensure_running_on_chief(): if ( mlrun.mlconf.httpdb.clusterization.role - != mlrun.api.schemas.ClusterizationRole.chief + != mlrun.common.schemas.ClusterizationRole.chief ): if ( mlrun.mlconf.httpdb.clusterization.ensure_function_running_on_chief_mode @@ -67,4 +57,17 @@ async def async_wrapper(*args, **kwargs): if asyncio.iscoroutinefunction(function): return async_wrapper + + # ensure method name is preserved + wrapper.__name__ = function.__name__ + return wrapper + + +def minimize_project_schema( + project: mlrun.common.schemas.Project, +) -> mlrun.common.schemas.Project: + project.spec.functions = None + project.spec.workflows = None + project.spec.artifacts = None + return project diff --git a/mlrun/api/utils/periodic.py b/mlrun/api/utils/periodic.py index 
fc3f1409d5ec..821283a1829b 100644 --- a/mlrun/api/utils/periodic.py +++ b/mlrun/api/utils/periodic.py @@ -34,9 +34,12 @@ async def _periodic_function_wrapper(interval: int, function, *args, **kwargs): await function(*args, **kwargs) else: await run_in_threadpool(function, *args, **kwargs) - except Exception: + except Exception as exc: logger.warning( - f"Failed during periodic function execution: {function.__name__}, exc: {traceback.format_exc()}" + "Failed during periodic function execution", + func_name=function.__name__, + exc=mlrun.errors.err_to_str(exc), + tb=traceback.format_exc(), ) await asyncio.sleep(interval) diff --git a/mlrun/api/utils/projects/follower.py b/mlrun/api/utils/projects/follower.py index b2c88f181b72..f2eef0b4ccc4 100644 --- a/mlrun/api/utils/projects/follower.py +++ b/mlrun/api/utils/projects/follower.py @@ -23,7 +23,6 @@ import mlrun.api.crud import mlrun.api.db.session -import mlrun.api.schemas import mlrun.api.utils.auth.verifier import mlrun.api.utils.clients.iguazio import mlrun.api.utils.clients.nuclio @@ -31,6 +30,7 @@ import mlrun.api.utils.projects.member import mlrun.api.utils.projects.remotes.leader import mlrun.api.utils.projects.remotes.nop_leader +import mlrun.common.schemas import mlrun.config import mlrun.errors import mlrun.utils @@ -78,7 +78,7 @@ def initialize(self): # we're doing a full_sync on every initialization full_sync = ( mlrun.mlconf.httpdb.clusterization.role - == mlrun.api.schemas.ClusterizationRole.chief + == mlrun.common.schemas.ClusterizationRole.chief ) self._sync_projects(full_sync=full_sync) except Exception as exc: @@ -96,12 +96,12 @@ def shutdown(self): def create_project( self, db_session: sqlalchemy.orm.Session, - project: mlrun.api.schemas.Project, - projects_role: typing.Optional[mlrun.api.schemas.ProjectsRole] = None, + project: mlrun.common.schemas.Project, + projects_role: typing.Optional[mlrun.common.schemas.ProjectsRole] = None, leader_session: typing.Optional[str] = None, wait_for_completion: 
bool = True, commit_before_get: bool = False, - ) -> typing.Tuple[typing.Optional[mlrun.api.schemas.Project], bool]: + ) -> typing.Tuple[typing.Optional[mlrun.common.schemas.Project], bool]: if self._is_request_from_leader(projects_role): mlrun.api.crud.Projects().create_project(db_session, project) return project, False @@ -137,11 +137,11 @@ def store_project( self, db_session: sqlalchemy.orm.Session, name: str, - project: mlrun.api.schemas.Project, - projects_role: typing.Optional[mlrun.api.schemas.ProjectsRole] = None, + project: mlrun.common.schemas.Project, + projects_role: typing.Optional[mlrun.common.schemas.ProjectsRole] = None, leader_session: typing.Optional[str] = None, wait_for_completion: bool = True, - ) -> typing.Tuple[typing.Optional[mlrun.api.schemas.Project], bool]: + ) -> typing.Tuple[typing.Optional[mlrun.common.schemas.Project], bool]: if self._is_request_from_leader(projects_role): mlrun.api.crud.Projects().store_project(db_session, name, project) return project, False @@ -166,11 +166,11 @@ def patch_project( db_session: sqlalchemy.orm.Session, name: str, project: dict, - patch_mode: mlrun.api.schemas.PatchMode = mlrun.api.schemas.PatchMode.replace, - projects_role: typing.Optional[mlrun.api.schemas.ProjectsRole] = None, + patch_mode: mlrun.common.schemas.PatchMode = mlrun.common.schemas.PatchMode.replace, + projects_role: typing.Optional[mlrun.common.schemas.ProjectsRole] = None, leader_session: typing.Optional[str] = None, wait_for_completion: bool = True, - ) -> typing.Tuple[typing.Optional[mlrun.api.schemas.Project], bool]: + ) -> typing.Tuple[typing.Optional[mlrun.common.schemas.Project], bool]: if self._is_request_from_leader(projects_role): # No real scenario for this to be useful currently - in iguazio patch is transformed to store request raise NotImplementedError("Patch operation not supported from leader") @@ -179,7 +179,7 @@ def patch_project( strategy = patch_mode.to_mergedeep_strategy() current_project_dict = 
current_project.dict(exclude_unset=True) mergedeep.merge(current_project_dict, project, strategy=strategy) - patched_project = mlrun.api.schemas.Project(**current_project_dict) + patched_project = mlrun.common.schemas.Project(**current_project_dict) return self.store_project( db_session, name, @@ -193,9 +193,9 @@ def delete_project( self, db_session: sqlalchemy.orm.Session, name: str, - deletion_strategy: mlrun.api.schemas.DeletionStrategy = mlrun.api.schemas.DeletionStrategy.default(), - projects_role: typing.Optional[mlrun.api.schemas.ProjectsRole] = None, - auth_info: mlrun.api.schemas.AuthInfo = mlrun.api.schemas.AuthInfo(), + deletion_strategy: mlrun.common.schemas.DeletionStrategy = mlrun.common.schemas.DeletionStrategy.default(), + projects_role: typing.Optional[mlrun.common.schemas.ProjectsRole] = None, + auth_info: mlrun.common.schemas.AuthInfo = mlrun.common.schemas.AuthInfo(), wait_for_completion: bool = True, ) -> bool: if self._is_request_from_leader(projects_role): @@ -216,30 +216,30 @@ def get_project( db_session: sqlalchemy.orm.Session, name: str, leader_session: typing.Optional[str] = None, - ) -> mlrun.api.schemas.Project: + ) -> mlrun.common.schemas.Project: return mlrun.api.crud.Projects().get_project(db_session, name) def get_project_owner( self, db_session: sqlalchemy.orm.Session, name: str, - ) -> mlrun.api.schemas.ProjectOwner: + ) -> mlrun.common.schemas.ProjectOwner: return self._leader_client.get_project_owner(self._sync_session, name) def list_projects( self, db_session: sqlalchemy.orm.Session, owner: str = None, - format_: mlrun.api.schemas.ProjectsFormat = mlrun.api.schemas.ProjectsFormat.full, + format_: mlrun.common.schemas.ProjectsFormat = mlrun.common.schemas.ProjectsFormat.full, labels: typing.List[str] = None, - state: mlrun.api.schemas.ProjectState = None, + state: mlrun.common.schemas.ProjectState = None, # needed only for external usage when requesting leader format - projects_role: 
typing.Optional[mlrun.api.schemas.ProjectsRole] = None, + projects_role: typing.Optional[mlrun.common.schemas.ProjectsRole] = None, leader_session: typing.Optional[str] = None, names: typing.Optional[typing.List[str]] = None, - ) -> mlrun.api.schemas.ProjectsOutput: + ) -> mlrun.common.schemas.ProjectsOutput: if ( - format_ == mlrun.api.schemas.ProjectsFormat.leader + format_ == mlrun.common.schemas.ProjectsFormat.leader and not self._is_request_from_leader(projects_role) ): raise mlrun.errors.MLRunAccessDeniedError( @@ -249,7 +249,7 @@ def list_projects( projects_output = mlrun.api.crud.Projects().list_projects( db_session, owner, format_, labels, state, names ) - if format_ == mlrun.api.schemas.ProjectsFormat.leader: + if format_ == mlrun.common.schemas.ProjectsFormat.leader: leader_projects = [ self._leader_client.format_as_leader_project(project) for project in projects_output.projects @@ -262,11 +262,11 @@ async def list_project_summaries( db_session: sqlalchemy.orm.Session, owner: str = None, labels: typing.List[str] = None, - state: mlrun.api.schemas.ProjectState = None, - projects_role: typing.Optional[mlrun.api.schemas.ProjectsRole] = None, + state: mlrun.common.schemas.ProjectState = None, + projects_role: typing.Optional[mlrun.common.schemas.ProjectsRole] = None, leader_session: typing.Optional[str] = None, names: typing.Optional[typing.List[str]] = None, - ) -> mlrun.api.schemas.ProjectSummariesOutput: + ) -> mlrun.common.schemas.ProjectSummariesOutput: return await mlrun.api.crud.Projects().list_project_summaries( db_session, owner, labels, state, names ) @@ -276,7 +276,7 @@ async def get_project_summary( db_session: sqlalchemy.orm.Session, name: str, leader_session: typing.Optional[str] = None, - ) -> mlrun.api.schemas.ProjectSummary: + ) -> mlrun.common.schemas.ProjectSummary: return await mlrun.api.crud.Projects().get_project_summary(db_session, name) def _start_periodic_sync(self): @@ -316,14 +316,14 @@ def _sync_projects(self, full_sync=False): 
db_session = mlrun.api.db.session.create_session() try: db_projects = mlrun.api.crud.Projects().list_projects( - db_session, format_=mlrun.api.schemas.ProjectsFormat.name_only + db_session, format_=mlrun.common.schemas.ProjectsFormat.name_only ) # Don't add projects in non terminal state if they didn't exist before to prevent race conditions filtered_projects = [] for leader_project in leader_projects: if ( leader_project.status.state - not in mlrun.api.schemas.ProjectState.terminal_states() + not in mlrun.common.schemas.ProjectState.terminal_states() and leader_project.metadata.name not in db_projects.projects ): continue @@ -349,7 +349,7 @@ def _sync_projects(self, full_sync=False): mlrun.api.crud.Projects().delete_project( db_session, project_to_remove, - mlrun.api.schemas.DeletionStrategy.cascading, + mlrun.common.schemas.DeletionStrategy.cascading, ) if latest_updated_at: @@ -363,7 +363,7 @@ def _sync_projects(self, full_sync=False): mlrun.api.db.session.close_session(db_session) def _is_request_from_leader( - self, projects_role: typing.Optional[mlrun.api.schemas.ProjectsRole] + self, projects_role: typing.Optional[mlrun.common.schemas.ProjectsRole] ) -> bool: if projects_role and projects_role.value == self._leader_name: return True @@ -371,7 +371,7 @@ def _is_request_from_leader( @staticmethod def _is_project_matching_labels( - labels: typing.List[str], project: mlrun.api.schemas.Project + labels: typing.List[str], project: mlrun.common.schemas.Project ): if not project.metadata.labels: return False diff --git a/mlrun/api/utils/projects/leader.py b/mlrun/api/utils/projects/leader.py index 30af77ee6676..e48c85265d23 100644 --- a/mlrun/api/utils/projects/leader.py +++ b/mlrun/api/utils/projects/leader.py @@ -20,12 +20,12 @@ import sqlalchemy.orm import mlrun.api.db.session -import mlrun.api.schemas import mlrun.api.utils.clients.nuclio import mlrun.api.utils.periodic import mlrun.api.utils.projects.member import mlrun.api.utils.projects.remotes.follower 
import mlrun.api.utils.projects.remotes.nop_follower +import mlrun.common.schemas import mlrun.config import mlrun.errors import mlrun.utils @@ -58,12 +58,12 @@ def shutdown(self): def create_project( self, db_session: sqlalchemy.orm.Session, - project: mlrun.api.schemas.Project, - projects_role: typing.Optional[mlrun.api.schemas.ProjectsRole] = None, + project: mlrun.common.schemas.Project, + projects_role: typing.Optional[mlrun.common.schemas.ProjectsRole] = None, leader_session: typing.Optional[str] = None, wait_for_completion: bool = True, commit_before_get: bool = False, - ) -> typing.Tuple[typing.Optional[mlrun.api.schemas.Project], bool]: + ) -> typing.Tuple[typing.Optional[mlrun.common.schemas.Project], bool]: self._enrich_and_validate_before_creation(project) self._run_on_all_followers(True, "create_project", db_session, project) return self.get_project(db_session, project.metadata.name), False @@ -72,11 +72,11 @@ def store_project( self, db_session: sqlalchemy.orm.Session, name: str, - project: mlrun.api.schemas.Project, - projects_role: typing.Optional[mlrun.api.schemas.ProjectsRole] = None, + project: mlrun.common.schemas.Project, + projects_role: typing.Optional[mlrun.common.schemas.ProjectsRole] = None, leader_session: typing.Optional[str] = None, wait_for_completion: bool = True, - ) -> typing.Tuple[typing.Optional[mlrun.api.schemas.Project], bool]: + ) -> typing.Tuple[typing.Optional[mlrun.common.schemas.Project], bool]: self._enrich_project(project) mlrun.projects.ProjectMetadata.validate_project_name(name) self._validate_body_and_path_names_matches(name, project) @@ -88,11 +88,11 @@ def patch_project( db_session: sqlalchemy.orm.Session, name: str, project: dict, - patch_mode: mlrun.api.schemas.PatchMode = mlrun.api.schemas.PatchMode.replace, - projects_role: typing.Optional[mlrun.api.schemas.ProjectsRole] = None, + patch_mode: mlrun.common.schemas.PatchMode = mlrun.common.schemas.PatchMode.replace, + projects_role: 
typing.Optional[mlrun.common.schemas.ProjectsRole] = None, leader_session: typing.Optional[str] = None, wait_for_completion: bool = True, - ) -> typing.Tuple[mlrun.api.schemas.Project, bool]: + ) -> typing.Tuple[mlrun.common.schemas.Project, bool]: self._enrich_project_patch(project) self._validate_body_and_path_names_matches(name, project) self._run_on_all_followers( @@ -104,9 +104,9 @@ def delete_project( self, db_session: sqlalchemy.orm.Session, name: str, - deletion_strategy: mlrun.api.schemas.DeletionStrategy = mlrun.api.schemas.DeletionStrategy.default(), - projects_role: typing.Optional[mlrun.api.schemas.ProjectsRole] = None, - auth_info: mlrun.api.schemas.AuthInfo = mlrun.api.schemas.AuthInfo(), + deletion_strategy: mlrun.common.schemas.DeletionStrategy = mlrun.common.schemas.DeletionStrategy.default(), + projects_role: typing.Optional[mlrun.common.schemas.ProjectsRole] = None, + auth_info: mlrun.common.schemas.AuthInfo = mlrun.common.schemas.AuthInfo(), wait_for_completion: bool = True, ) -> bool: self._projects_in_deletion.add(name) @@ -123,20 +123,20 @@ def get_project( db_session: sqlalchemy.orm.Session, name: str, leader_session: typing.Optional[str] = None, - ) -> mlrun.api.schemas.Project: + ) -> mlrun.common.schemas.Project: return self._leader_follower.get_project(db_session, name) def list_projects( self, db_session: sqlalchemy.orm.Session, owner: str = None, - format_: mlrun.api.schemas.ProjectsFormat = mlrun.api.schemas.ProjectsFormat.full, + format_: mlrun.common.schemas.ProjectsFormat = mlrun.common.schemas.ProjectsFormat.full, labels: typing.List[str] = None, - state: mlrun.api.schemas.ProjectState = None, - projects_role: typing.Optional[mlrun.api.schemas.ProjectsRole] = None, + state: mlrun.common.schemas.ProjectState = None, + projects_role: typing.Optional[mlrun.common.schemas.ProjectsRole] = None, leader_session: typing.Optional[str] = None, names: typing.Optional[typing.List[str]] = None, - ) -> mlrun.api.schemas.ProjectsOutput: + ) -> 
mlrun.common.schemas.ProjectsOutput: return self._leader_follower.list_projects( db_session, owner, format_, labels, state, names ) @@ -146,11 +146,11 @@ async def list_project_summaries( db_session: sqlalchemy.orm.Session, owner: str = None, labels: typing.List[str] = None, - state: mlrun.api.schemas.ProjectState = None, - projects_role: typing.Optional[mlrun.api.schemas.ProjectsRole] = None, + state: mlrun.common.schemas.ProjectState = None, + projects_role: typing.Optional[mlrun.common.schemas.ProjectsRole] = None, leader_session: typing.Optional[str] = None, names: typing.Optional[typing.List[str]] = None, - ) -> mlrun.api.schemas.ProjectSummariesOutput: + ) -> mlrun.common.schemas.ProjectSummariesOutput: return await self._leader_follower.list_project_summaries( db_session, owner, labels, state, names ) @@ -160,14 +160,14 @@ async def get_project_summary( db_session: sqlalchemy.orm.Session, name: str, leader_session: typing.Optional[str] = None, - ) -> mlrun.api.schemas.ProjectSummary: + ) -> mlrun.common.schemas.ProjectSummary: return await self._leader_follower.get_project_summary(db_session, name) def get_project_owner( self, db_session: sqlalchemy.orm.Session, name: str, - ) -> mlrun.api.schemas.ProjectOwner: + ) -> mlrun.common.schemas.ProjectOwner: raise NotImplementedError() def _start_periodic_sync(self): @@ -192,8 +192,8 @@ def _sync_projects(self): db_session = mlrun.api.db.session.create_session() try: # re-generating all of the maps every time since _ensure_follower_projects_synced might cause changes - leader_projects: mlrun.api.schemas.ProjectsOutput - follower_projects_map: typing.Dict[str, mlrun.api.schemas.ProjectsOutput] + leader_projects: mlrun.common.schemas.ProjectsOutput + follower_projects_map: typing.Dict[str, mlrun.common.schemas.ProjectsOutput] leader_projects, follower_projects_map = self._run_on_all_followers( True, "list_projects", db_session ) @@ -245,9 +245,9 @@ def _ensure_project_synced( follower_names: typing.Set[str], 
project_name: str, followers_projects_map: typing.Dict[ - str, typing.Dict[str, mlrun.api.schemas.Project] + str, typing.Dict[str, mlrun.common.schemas.Project] ], - leader_projects_map: typing.Dict[str, mlrun.api.schemas.Project], + leader_projects_map: typing.Dict[str, mlrun.common.schemas.Project], ): # FIXME: This function only handles syncing project existence, i.e. if a user updates a project attribute # through one of the followers this change won't be synced and the projects will be left with this discrepancy @@ -272,7 +272,6 @@ def _ensure_project_synced( logger.warning( "Failed creating missing project in leader", project_follower_name=project_follower_name, - project=project, project_name=project_name, exc=err_to_str(exc), traceback=traceback.format_exc(), @@ -308,14 +307,13 @@ def _store_project_in_followers( db_session: sqlalchemy.orm.Session, follower_names: typing.Set[str], project_name: str, - project: mlrun.api.schemas.Project, + project: mlrun.common.schemas.Project, ): for follower_name in follower_names: logger.debug( "Updating project in follower", follower_name=follower_name, project_name=project_name, - project=project, ) try: self._enrich_and_validate_before_creation(project) @@ -329,7 +327,6 @@ def _store_project_in_followers( "Failed updating project in follower", follower_name=follower_name, project_name=project_name, - project=project, exc=err_to_str(exc), traceback=traceback.format_exc(), ) @@ -341,7 +338,7 @@ def _create_project_in_missing_followers( # the name of the follower which we took the missing project from project_follower_name: str, project_name: str, - project: mlrun.api.schemas.Project, + project: mlrun.common.schemas.Project, ): for missing_follower in missing_followers: logger.debug( @@ -349,7 +346,6 @@ def _create_project_in_missing_followers( missing_follower_name=missing_follower, project_follower_name=project_follower_name, project_name=project_name, - project=project, ) try: 
self._enrich_and_validate_before_creation(project) @@ -363,7 +359,6 @@ def _create_project_in_missing_followers( missing_follower_name=missing_follower, project_follower_name=project_follower_name, project_name=project_name, - project=project, exc=err_to_str(exc), traceback=traceback.format_exc(), ) @@ -428,12 +423,14 @@ def _initialize_follower( raise ValueError(f"Unknown follower name: {name}") return followers_classes_map[name] - def _enrich_and_validate_before_creation(self, project: mlrun.api.schemas.Project): + def _enrich_and_validate_before_creation( + self, project: mlrun.common.schemas.Project + ): self._enrich_project(project) mlrun.projects.ProjectMetadata.validate_project_name(project.metadata.name) @staticmethod - def _enrich_project(project: mlrun.api.schemas.Project): + def _enrich_project(project: mlrun.common.schemas.Project): project.status.state = project.spec.desired_state @staticmethod @@ -457,9 +454,9 @@ def validate_project_name(name: str, raise_on_failure: bool = True) -> bool: @staticmethod def _validate_body_and_path_names_matches( - path_name: str, project: typing.Union[mlrun.api.schemas.Project, dict] + path_name: str, project: typing.Union[mlrun.common.schemas.Project, dict] ): - if isinstance(project, mlrun.api.schemas.Project): + if isinstance(project, mlrun.common.schemas.Project): body_name = project.metadata.name elif isinstance(project, dict): body_name = project.get("metadata", {}).get("name") diff --git a/mlrun/api/utils/projects/member.py b/mlrun/api/utils/projects/member.py index 9f71e80274cd..779cde48c162 100644 --- a/mlrun/api/utils/projects/member.py +++ b/mlrun/api/utils/projects/member.py @@ -19,8 +19,8 @@ import mlrun.api.crud import mlrun.api.db.session -import mlrun.api.schemas import mlrun.api.utils.clients.log_collector +import mlrun.common.schemas import mlrun.utils.singleton from mlrun.utils import logger @@ -39,11 +39,11 @@ def ensure_project( db_session: sqlalchemy.orm.Session, name: str, wait_for_completion: 
bool = True, - auth_info: mlrun.api.schemas.AuthInfo = mlrun.api.schemas.AuthInfo(), + auth_info: mlrun.common.schemas.AuthInfo = mlrun.common.schemas.AuthInfo(), ): project_names = self.list_projects( db_session, - format_=mlrun.api.schemas.ProjectsFormat.name_only, + format_=mlrun.common.schemas.ProjectsFormat.name_only, leader_session=auth_info.session, ) if name not in project_names.projects: @@ -53,12 +53,12 @@ def ensure_project( def create_project( self, db_session: sqlalchemy.orm.Session, - project: mlrun.api.schemas.Project, - projects_role: typing.Optional[mlrun.api.schemas.ProjectsRole] = None, + project: mlrun.common.schemas.Project, + projects_role: typing.Optional[mlrun.common.schemas.ProjectsRole] = None, leader_session: typing.Optional[str] = None, wait_for_completion: bool = True, commit_before_get: bool = False, - ) -> typing.Tuple[typing.Optional[mlrun.api.schemas.Project], bool]: + ) -> typing.Tuple[typing.Optional[mlrun.common.schemas.Project], bool]: pass @abc.abstractmethod @@ -66,11 +66,11 @@ def store_project( self, db_session: sqlalchemy.orm.Session, name: str, - project: mlrun.api.schemas.Project, - projects_role: typing.Optional[mlrun.api.schemas.ProjectsRole] = None, + project: mlrun.common.schemas.Project, + projects_role: typing.Optional[mlrun.common.schemas.ProjectsRole] = None, leader_session: typing.Optional[str] = None, wait_for_completion: bool = True, - ) -> typing.Tuple[typing.Optional[mlrun.api.schemas.Project], bool]: + ) -> typing.Tuple[typing.Optional[mlrun.common.schemas.Project], bool]: pass @abc.abstractmethod @@ -79,11 +79,11 @@ def patch_project( db_session: sqlalchemy.orm.Session, name: str, project: dict, - patch_mode: mlrun.api.schemas.PatchMode = mlrun.api.schemas.PatchMode.replace, - projects_role: typing.Optional[mlrun.api.schemas.ProjectsRole] = None, + patch_mode: mlrun.common.schemas.PatchMode = mlrun.common.schemas.PatchMode.replace, + projects_role: typing.Optional[mlrun.common.schemas.ProjectsRole] = None, 
leader_session: typing.Optional[str] = None, wait_for_completion: bool = True, - ) -> typing.Tuple[mlrun.api.schemas.Project, bool]: + ) -> typing.Tuple[mlrun.common.schemas.Project, bool]: pass @abc.abstractmethod @@ -91,9 +91,9 @@ def delete_project( self, db_session: sqlalchemy.orm.Session, name: str, - deletion_strategy: mlrun.api.schemas.DeletionStrategy = mlrun.api.schemas.DeletionStrategy.default(), - projects_role: typing.Optional[mlrun.api.schemas.ProjectsRole] = None, - auth_info: mlrun.api.schemas.AuthInfo = mlrun.api.schemas.AuthInfo(), + deletion_strategy: mlrun.common.schemas.DeletionStrategy = mlrun.common.schemas.DeletionStrategy.default(), + projects_role: typing.Optional[mlrun.common.schemas.ProjectsRole] = None, + auth_info: mlrun.common.schemas.AuthInfo = mlrun.common.schemas.AuthInfo(), wait_for_completion: bool = True, ) -> bool: pass @@ -104,7 +104,7 @@ def get_project( db_session: sqlalchemy.orm.Session, name: str, leader_session: typing.Optional[str] = None, - ) -> mlrun.api.schemas.Project: + ) -> mlrun.common.schemas.Project: pass @abc.abstractmethod @@ -112,13 +112,13 @@ def list_projects( self, db_session: sqlalchemy.orm.Session, owner: str = None, - format_: mlrun.api.schemas.ProjectsFormat = mlrun.api.schemas.ProjectsFormat.full, + format_: mlrun.common.schemas.ProjectsFormat = mlrun.common.schemas.ProjectsFormat.full, labels: typing.List[str] = None, - state: mlrun.api.schemas.ProjectState = None, - projects_role: typing.Optional[mlrun.api.schemas.ProjectsRole] = None, + state: mlrun.common.schemas.ProjectState = None, + projects_role: typing.Optional[mlrun.common.schemas.ProjectsRole] = None, leader_session: typing.Optional[str] = None, names: typing.Optional[typing.List[str]] = None, - ) -> mlrun.api.schemas.ProjectsOutput: + ) -> mlrun.common.schemas.ProjectsOutput: pass @abc.abstractmethod @@ -127,7 +127,7 @@ async def get_project_summary( db_session: sqlalchemy.orm.Session, name: str, leader_session: typing.Optional[str] = None, 
- ) -> mlrun.api.schemas.ProjectSummary: + ) -> mlrun.common.schemas.ProjectSummary: pass @abc.abstractmethod @@ -136,11 +136,11 @@ async def list_project_summaries( db_session: sqlalchemy.orm.Session, owner: str = None, labels: typing.List[str] = None, - state: mlrun.api.schemas.ProjectState = None, - projects_role: typing.Optional[mlrun.api.schemas.ProjectsRole] = None, + state: mlrun.common.schemas.ProjectState = None, + projects_role: typing.Optional[mlrun.common.schemas.ProjectsRole] = None, leader_session: typing.Optional[str] = None, names: typing.Optional[typing.List[str]] = None, - ) -> mlrun.api.schemas.ProjectSummariesOutput: + ) -> mlrun.common.schemas.ProjectSummariesOutput: pass @abc.abstractmethod @@ -148,7 +148,7 @@ def get_project_owner( self, db_session: sqlalchemy.orm.Session, name: str, - ) -> mlrun.api.schemas.ProjectOwner: + ) -> mlrun.common.schemas.ProjectOwner: pass async def post_delete_project( @@ -157,7 +157,7 @@ async def post_delete_project( ): if ( mlrun.mlconf.log_collector.mode - != mlrun.api.schemas.LogsCollectorMode.legacy + != mlrun.common.schemas.LogsCollectorMode.legacy ): await self._stop_logs_for_project(project_name) await self._delete_project_logs(project_name) diff --git a/mlrun/api/utils/projects/remotes/follower.py b/mlrun/api/utils/projects/remotes/follower.py index 39777156043c..73c679a34d04 100644 --- a/mlrun/api/utils/projects/remotes/follower.py +++ b/mlrun/api/utils/projects/remotes/follower.py @@ -17,13 +17,13 @@ import sqlalchemy.orm -import mlrun.api.schemas +import mlrun.common.schemas class Member(abc.ABC): @abc.abstractmethod def create_project( - self, session: sqlalchemy.orm.Session, project: mlrun.api.schemas.Project + self, session: sqlalchemy.orm.Session, project: mlrun.common.schemas.Project ): pass @@ -32,7 +32,7 @@ def store_project( self, session: sqlalchemy.orm.Session, name: str, - project: mlrun.api.schemas.Project, + project: mlrun.common.schemas.Project, ): pass @@ -42,7 +42,7 @@ def 
patch_project( session: sqlalchemy.orm.Session, name: str, project: dict, - patch_mode: mlrun.api.schemas.PatchMode = mlrun.api.schemas.PatchMode.replace, + patch_mode: mlrun.common.schemas.PatchMode = mlrun.common.schemas.PatchMode.replace, ): pass @@ -51,14 +51,14 @@ def delete_project( self, session: sqlalchemy.orm.Session, name: str, - deletion_strategy: mlrun.api.schemas.DeletionStrategy = mlrun.api.schemas.DeletionStrategy.default(), + deletion_strategy: mlrun.common.schemas.DeletionStrategy = mlrun.common.schemas.DeletionStrategy.default(), ): pass @abc.abstractmethod def get_project( self, session: sqlalchemy.orm.Session, name: str - ) -> mlrun.api.schemas.Project: + ) -> mlrun.common.schemas.Project: pass @abc.abstractmethod @@ -66,11 +66,11 @@ def list_projects( self, session: sqlalchemy.orm.Session, owner: str = None, - format_: mlrun.api.schemas.ProjectsFormat = mlrun.api.schemas.ProjectsFormat.full, + format_: mlrun.common.schemas.ProjectsFormat = mlrun.common.schemas.ProjectsFormat.full, labels: typing.List[str] = None, - state: mlrun.api.schemas.ProjectState = None, + state: mlrun.common.schemas.ProjectState = None, names: typing.Optional[typing.List[str]] = None, - ) -> mlrun.api.schemas.ProjectsOutput: + ) -> mlrun.common.schemas.ProjectsOutput: pass @abc.abstractmethod @@ -79,13 +79,13 @@ def list_project_summaries( session: sqlalchemy.orm.Session, owner: str = None, labels: typing.List[str] = None, - state: mlrun.api.schemas.ProjectState = None, + state: mlrun.common.schemas.ProjectState = None, names: typing.Optional[typing.List[str]] = None, - ) -> mlrun.api.schemas.ProjectSummariesOutput: + ) -> mlrun.common.schemas.ProjectSummariesOutput: pass @abc.abstractmethod def get_project_summary( self, session: sqlalchemy.orm.Session, name: str - ) -> mlrun.api.schemas.ProjectSummary: + ) -> mlrun.common.schemas.ProjectSummary: pass diff --git a/mlrun/api/utils/projects/remotes/leader.py b/mlrun/api/utils/projects/remotes/leader.py index 
3b7522905438..3283f6d6614c 100644 --- a/mlrun/api/utils/projects/remotes/leader.py +++ b/mlrun/api/utils/projects/remotes/leader.py @@ -16,7 +16,7 @@ import datetime import typing -import mlrun.api.schemas +import mlrun.common.schemas class Member(abc.ABC): @@ -24,7 +24,7 @@ class Member(abc.ABC): def create_project( self, session: str, - project: mlrun.api.schemas.Project, + project: mlrun.common.schemas.Project, wait_for_completion: bool = True, ) -> bool: pass @@ -34,7 +34,7 @@ def update_project( self, session: str, name: str, - project: mlrun.api.schemas.Project, + project: mlrun.common.schemas.Project, ): pass @@ -43,7 +43,7 @@ def delete_project( self, session: str, name: str, - deletion_strategy: mlrun.api.schemas.DeletionStrategy = mlrun.api.schemas.DeletionStrategy.default(), + deletion_strategy: mlrun.common.schemas.DeletionStrategy = mlrun.common.schemas.DeletionStrategy.default(), wait_for_completion: bool = True, ) -> bool: pass @@ -54,7 +54,7 @@ def list_projects( session: str, updated_after: typing.Optional[datetime.datetime] = None, ) -> typing.Tuple[ - typing.List[mlrun.api.schemas.Project], typing.Optional[datetime.datetime] + typing.List[mlrun.common.schemas.Project], typing.Optional[datetime.datetime] ]: pass @@ -63,13 +63,13 @@ def get_project( self, session: str, name: str, - ) -> mlrun.api.schemas.Project: + ) -> mlrun.common.schemas.Project: pass @abc.abstractmethod def format_as_leader_project( - self, project: mlrun.api.schemas.Project - ) -> mlrun.api.schemas.IguazioProject: + self, project: mlrun.common.schemas.Project + ) -> mlrun.common.schemas.IguazioProject: pass @abc.abstractmethod @@ -77,5 +77,5 @@ def get_project_owner( self, session: str, name: str, - ) -> mlrun.api.schemas.ProjectOwner: + ) -> mlrun.common.schemas.ProjectOwner: pass diff --git a/mlrun/api/utils/projects/remotes/nop_follower.py b/mlrun/api/utils/projects/remotes/nop_follower.py index d29e6bb3420d..c5a7b4c12fe3 100644 --- 
a/mlrun/api/utils/projects/remotes/nop_follower.py +++ b/mlrun/api/utils/projects/remotes/nop_follower.py @@ -17,18 +17,19 @@ import mergedeep import sqlalchemy.orm -import mlrun.api.schemas +import mlrun.api.utils.helpers import mlrun.api.utils.projects.remotes.follower +import mlrun.common.schemas import mlrun.errors class Member(mlrun.api.utils.projects.remotes.follower.Member): def __init__(self) -> None: super().__init__() - self._projects: typing.Dict[str, mlrun.api.schemas.Project] = {} + self._projects: typing.Dict[str, mlrun.common.schemas.Project] = {} def create_project( - self, session: sqlalchemy.orm.Session, project: mlrun.api.schemas.Project + self, session: sqlalchemy.orm.Session, project: mlrun.common.schemas.Project ): if project.metadata.name in self._projects: raise mlrun.errors.MLRunConflictError("Project already exists") @@ -39,7 +40,7 @@ def store_project( self, session: sqlalchemy.orm.Session, name: str, - project: mlrun.api.schemas.Project, + project: mlrun.common.schemas.Project, ): # deep copy so we won't accidentally get changes from tests self._projects[name] = project.copy(deep=True) @@ -49,25 +50,25 @@ def patch_project( session: sqlalchemy.orm.Session, name: str, project: dict, - patch_mode: mlrun.api.schemas.PatchMode = mlrun.api.schemas.PatchMode.replace, + patch_mode: mlrun.common.schemas.PatchMode = mlrun.common.schemas.PatchMode.replace, ): existing_project_dict = self._projects[name].dict() strategy = patch_mode.to_mergedeep_strategy() mergedeep.merge(existing_project_dict, project, strategy=strategy) - self._projects[name] = mlrun.api.schemas.Project(**existing_project_dict) + self._projects[name] = mlrun.common.schemas.Project(**existing_project_dict) def delete_project( self, session: sqlalchemy.orm.Session, name: str, - deletion_strategy: mlrun.api.schemas.DeletionStrategy = mlrun.api.schemas.DeletionStrategy.default(), + deletion_strategy: mlrun.common.schemas.DeletionStrategy = 
mlrun.common.schemas.DeletionStrategy.default(), ): if name in self._projects: del self._projects[name] def get_project( self, session: sqlalchemy.orm.Session, name: str - ) -> mlrun.api.schemas.Project: + ) -> mlrun.common.schemas.Project: # deep copy so we won't accidentally get changes from tests return self._projects[name].copy(deep=True) @@ -75,11 +76,11 @@ def list_projects( self, session: sqlalchemy.orm.Session, owner: str = None, - format_: mlrun.api.schemas.ProjectsFormat = mlrun.api.schemas.ProjectsFormat.full, + format_: mlrun.common.schemas.ProjectsFormat = mlrun.common.schemas.ProjectsFormat.full, labels: typing.List[str] = None, - state: mlrun.api.schemas.ProjectState = None, + state: mlrun.common.schemas.ProjectState = None, names: typing.Optional[typing.List[str]] = None, - ) -> mlrun.api.schemas.ProjectsOutput: + ) -> mlrun.common.schemas.ProjectsOutput: if owner or labels or state: raise NotImplementedError( "Filtering by owner, labels or state is not supported" @@ -93,11 +94,18 @@ def list_projects( for project_name, project in self._projects.items() if project_name in names ] - if format_ == mlrun.api.schemas.ProjectsFormat.full: - return mlrun.api.schemas.ProjectsOutput(projects=projects) - elif format_ == mlrun.api.schemas.ProjectsFormat.name_only: + if format_ == mlrun.common.schemas.ProjectsFormat.full: + return mlrun.common.schemas.ProjectsOutput(projects=projects) + elif format_ == mlrun.common.schemas.ProjectsFormat.minimal: + return mlrun.common.schemas.ProjectsOutput( + projects=[ + mlrun.api.utils.helpers.minimize_project_schema(project) + for project in projects + ] + ) + elif format_ == mlrun.common.schemas.ProjectsFormat.name_only: project_names = [project.metadata.name for project in projects] - return mlrun.api.schemas.ProjectsOutput(projects=project_names) + return mlrun.common.schemas.ProjectsOutput(projects=project_names) else: raise NotImplementedError( f"Provided format is not supported. 
format={format_}" @@ -108,12 +116,12 @@ def list_project_summaries( session: sqlalchemy.orm.Session, owner: str = None, labels: typing.List[str] = None, - state: mlrun.api.schemas.ProjectState = None, + state: mlrun.common.schemas.ProjectState = None, names: typing.Optional[typing.List[str]] = None, - ) -> mlrun.api.schemas.ProjectSummariesOutput: + ) -> mlrun.common.schemas.ProjectSummariesOutput: raise NotImplementedError("Listing project summaries is not supported") def get_project_summary( self, session: sqlalchemy.orm.Session, name: str - ) -> mlrun.api.schemas.ProjectSummary: + ) -> mlrun.common.schemas.ProjectSummary: raise NotImplementedError("Get project summary is not supported") diff --git a/mlrun/api/utils/projects/remotes/nop_leader.py b/mlrun/api/utils/projects/remotes/nop_leader.py index 961d49148414..92bb717bc0a8 100644 --- a/mlrun/api/utils/projects/remotes/nop_leader.py +++ b/mlrun/api/utils/projects/remotes/nop_leader.py @@ -15,9 +15,9 @@ import datetime import typing -import mlrun.api.schemas import mlrun.api.utils.projects.remotes.leader import mlrun.api.utils.singletons.project_member +import mlrun.common.schemas import mlrun.errors @@ -26,12 +26,12 @@ def __init__(self) -> None: super().__init__() self.db_session = None self.project_owner_access_key = "" - self._project_role = mlrun.api.schemas.ProjectsRole.nop + self._project_role = mlrun.common.schemas.ProjectsRole.nop def create_project( self, session: str, - project: mlrun.api.schemas.Project, + project: mlrun.common.schemas.Project, wait_for_completion: bool = True, ) -> bool: self._update_state(project) @@ -47,7 +47,7 @@ def update_project( self, session: str, name: str, - project: mlrun.api.schemas.Project, + project: mlrun.common.schemas.Project, ): self._update_state(project) mlrun.api.utils.singletons.project_member.get_project_member().store_project( @@ -55,12 +55,13 @@ def update_project( ) @staticmethod - def _update_state(project: mlrun.api.schemas.Project): + def 
_update_state(project: mlrun.common.schemas.Project): if ( not project.status.state - or project.status.state in mlrun.api.schemas.ProjectState.terminal_states() + or project.status.state + in mlrun.common.schemas.ProjectState.terminal_states() ): - project.status.state = mlrun.api.schemas.ProjectState( + project.status.state = mlrun.common.schemas.ProjectState( project.spec.desired_state ) @@ -68,7 +69,7 @@ def delete_project( self, session: str, name: str, - deletion_strategy: mlrun.api.schemas.DeletionStrategy = mlrun.api.schemas.DeletionStrategy.default(), + deletion_strategy: mlrun.common.schemas.DeletionStrategy = mlrun.common.schemas.DeletionStrategy.default(), wait_for_completion: bool = True, ) -> bool: return mlrun.api.utils.singletons.project_member.get_project_member().delete_project( @@ -80,7 +81,7 @@ def list_projects( session: str, updated_after: typing.Optional[datetime.datetime] = None, ) -> typing.Tuple[ - typing.List[mlrun.api.schemas.Project], typing.Optional[datetime.datetime] + typing.List[mlrun.common.schemas.Project], typing.Optional[datetime.datetime] ]: return ( mlrun.api.utils.singletons.project_member.get_project_member() @@ -93,7 +94,7 @@ def get_project( self, session: str, name: str, - ) -> mlrun.api.schemas.Project: + ) -> mlrun.common.schemas.Project: return ( mlrun.api.utils.singletons.project_member.get_project_member().get_project( self.db_session, name @@ -101,16 +102,16 @@ def get_project( ) def format_as_leader_project( - self, project: mlrun.api.schemas.Project - ) -> mlrun.api.schemas.IguazioProject: - return mlrun.api.schemas.IguazioProject(data=project.dict()) + self, project: mlrun.common.schemas.Project + ) -> mlrun.common.schemas.IguazioProject: + return mlrun.common.schemas.IguazioProject(data=project.dict()) def get_project_owner( self, session: str, name: str, - ) -> mlrun.api.schemas.ProjectOwner: + ) -> mlrun.common.schemas.ProjectOwner: project = self.get_project(session, name) - return 
mlrun.api.schemas.ProjectOwner( + return mlrun.common.schemas.ProjectOwner( username=project.spec.owner, access_key=self.project_owner_access_key ) diff --git a/mlrun/api/utils/scheduler.py b/mlrun/api/utils/scheduler.py index 9c6e1f3edcd2..e0f52fe7e42e 100644 --- a/mlrun/api/utils/scheduler.py +++ b/mlrun/api/utils/scheduler.py @@ -27,11 +27,12 @@ from apscheduler.triggers.cron import CronTrigger as APSchedulerCronTrigger from sqlalchemy.orm import Session +import mlrun.api.api.utils import mlrun.api.utils.auth.verifier import mlrun.api.utils.clients.iguazio import mlrun.api.utils.helpers +import mlrun.common.schemas import mlrun.errors -from mlrun.api import schemas from mlrun.api.db.session import close_session, create_session from mlrun.api.utils.singletons.db import get_db from mlrun.config import config @@ -60,7 +61,7 @@ def __init__(self): # we don't allow to schedule a job to run more than one time per X # NOTE this cannot be less than one minute - see _validate_cron_trigger self._min_allowed_interval = config.httpdb.scheduling.min_allowed_interval - self._secrets_provider = schemas.SecretProviderName.kubernetes + self._secrets_provider = mlrun.common.schemas.SecretProviderName.kubernetes async def start(self, db_session: Session): logger.info("Starting scheduler") @@ -73,7 +74,7 @@ async def start(self, db_session: Session): try: if ( mlrun.mlconf.httpdb.clusterization.role - == mlrun.api.schemas.ClusterizationRole.chief + == mlrun.common.schemas.ClusterizationRole.chief ): self._reload_schedules(db_session) except Exception as exc: @@ -93,7 +94,7 @@ def _append_access_key_secret_to_labels(self, labels, secret_name): return labels def _get_access_key_secret_name_from_db_record( - self, db_schedule: schemas.ScheduleRecord + self, db_schedule: mlrun.common.schemas.ScheduleRecord ): schedule_labels = db_schedule.dict()["labels"] for label in schedule_labels: @@ -104,19 +105,21 @@ def _get_access_key_secret_name_from_db_record( def create_schedule( self, 
db_session: Session, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, project: str, name: str, - kind: schemas.ScheduleKinds, + kind: mlrun.common.schemas.ScheduleKinds, scheduled_object: Union[Dict, Callable], - cron_trigger: Union[str, schemas.ScheduleCronTrigger], + cron_trigger: Union[str, mlrun.common.schemas.ScheduleCronTrigger], labels: Dict = None, concurrency_limit: int = None, ): if concurrency_limit is None: concurrency_limit = config.httpdb.scheduling.default_concurrency_limit if isinstance(cron_trigger, str): - cron_trigger = schemas.ScheduleCronTrigger.from_crontab(cron_trigger) + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger.from_crontab( + cron_trigger + ) self._validate_cron_trigger(cron_trigger) @@ -135,6 +138,9 @@ def create_schedule( # We use the schedule labels to keep track of the access-key to use. Note that this is the name of the secret, # not the secret value itself. Therefore, it can be kept in a non-secure field. labels = self._append_access_key_secret_to_labels(labels, secret_name) + + self._enrich_schedule_notifications(project, name, scheduled_object) + get_db().create_schedule( db_session, project, @@ -177,16 +183,18 @@ def update_schedule_next_run_time( def update_schedule( self, db_session: Session, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, project: str, name: str, scheduled_object: Union[Dict, Callable] = None, - cron_trigger: Union[str, schemas.ScheduleCronTrigger] = None, + cron_trigger: Union[str, mlrun.common.schemas.ScheduleCronTrigger] = None, labels: Dict = None, concurrency_limit: int = None, ): if isinstance(cron_trigger, str): - cron_trigger = schemas.ScheduleCronTrigger.from_crontab(cron_trigger) + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger.from_crontab( + cron_trigger + ) if cron_trigger is not None: self._validate_cron_trigger(cron_trigger) @@ -206,6 +214,8 @@ def update_schedule( secret_name = 
self._store_schedule_secrets_using_auth_secret(auth_info) labels = self._append_access_key_secret_to_labels(labels, secret_name) + self._enrich_schedule_notifications(project, name, scheduled_object) + get_db().update_schedule( db_session, project, @@ -241,7 +251,7 @@ def list_schedules( labels: str = None, include_last_run: bool = False, include_credentials: bool = False, - ) -> schemas.SchedulesOutput: + ) -> mlrun.common.schemas.SchedulesOutput: db_schedules = get_db().list_schedules(db_session, project, name, labels, kind) schedules = [] for db_schedule in db_schedules: @@ -249,7 +259,7 @@ def list_schedules( db_session, db_schedule, include_last_run, include_credentials ) schedules.append(schedule) - return schemas.SchedulesOutput(schedules=schedules) + return mlrun.common.schemas.SchedulesOutput(schedules=schedules) def get_schedule( self, @@ -258,7 +268,7 @@ def get_schedule( name: str, include_last_run: bool = False, include_credentials: bool = False, - ) -> schemas.ScheduleOutput: + ) -> mlrun.common.schemas.ScheduleOutput: logger.debug("Getting schedule", project=project, name=name) db_schedule = get_db().get_schedule(db_session, project, name) return self._transform_and_enrich_db_schedule( @@ -273,7 +283,7 @@ def delete_schedule( name: str, ): logger.debug("Deleting schedule", project=project, name=name) - self._remove_schedule_scheduler_resources(project, name) + self._remove_schedule_scheduler_resources(db_session, project, name) get_db().delete_schedule(db_session, project, name) @mlrun.api.utils.helpers.ensure_running_on_chief @@ -288,15 +298,18 @@ def delete_schedules( ) logger.debug("Deleting schedules", project=project) for schedule in schedules.schedules: - self._remove_schedule_scheduler_resources(schedule.project, schedule.name) + self._remove_schedule_scheduler_resources( + db_session, schedule.project, schedule.name + ) get_db().delete_schedules(db_session, project) - def _remove_schedule_scheduler_resources(self, project, name): + def 
_remove_schedule_scheduler_resources(self, db_session: Session, project, name): self._remove_schedule_from_scheduler(project, name) # This is kept for backwards compatibility - if schedule was using the "old" format of storing secrets, then # this is a good opportunity to remove them. Using the new method we don't remove secrets since they are per # access-key and there may be other entities (runtimes, for example) using the same secret. self._remove_schedule_secrets(project, name) + self._remove_schedule_notification_secrets(db_session, project, name) def _remove_schedule_from_scheduler(self, project, name): job_id = self._resolve_job_id(project, name) @@ -309,7 +322,7 @@ def _remove_schedule_from_scheduler(self, project, name): async def invoke_schedule( self, db_session: Session, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, project: str, name: str, ): @@ -333,15 +346,42 @@ async def invoke_schedule( ) return await function(*args, **kwargs) + @mlrun.api.utils.helpers.ensure_running_on_chief + def set_schedule_notifications( + self, + session: Session, + project: str, + identifier: mlrun.common.schemas.ScheduleIdentifier, + notifications: List[mlrun.model.Notification], + auth_info: mlrun.common.schemas.AuthInfo, + ): + """ + Set notifications for a schedule. This will replace any existing notifications. 
+ :param session: DB session + :param project: Project name + :param identifier: Schedule identifier + :param notifications: List of notifications to set + :param auth_info: Authorization info + """ + name = identifier.name + logger.debug("Setting schedule notifications", project=project, name=name) + db_schedule = get_db().get_schedule(session, project, name) + scheduled_object = db_schedule.scheduled_object + if scheduled_object: + scheduled_object.get("task", {}).get("spec", {})["notifications"] = [ + notification.to_dict() for notification in notifications + ] + self.update_schedule(session, auth_info, project, name, scheduled_object) + def _ensure_auth_info_has_access_key( self, - auth_info: mlrun.api.schemas.AuthInfo, - kind: schemas.ScheduleKinds, + auth_info: mlrun.common.schemas.AuthInfo, + kind: mlrun.common.schemas.ScheduleKinds, ): import mlrun.api.crud if ( - kind not in schemas.ScheduleKinds.local_kinds() + kind not in mlrun.common.schemas.ScheduleKinds.local_kinds() and mlrun.api.utils.auth.verifier.AuthVerifier().is_jobs_auth_required() ): if ( @@ -371,7 +411,7 @@ def _ensure_auth_info_has_access_key( def _store_schedule_secrets_using_auth_secret( self, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, ) -> str: # import here to avoid circular imports import mlrun.api.crud @@ -388,8 +428,8 @@ def _store_schedule_secrets_using_auth_secret( auth_info.username = "" secret_name = mlrun.api.crud.Secrets().store_auth_secret( - mlrun.api.schemas.AuthSecretData( - provider=mlrun.api.schemas.SecretProviderName.kubernetes, + mlrun.common.schemas.AuthSecretData( + provider=mlrun.common.schemas.SecretProviderName.kubernetes, username=auth_info.username, access_key=auth_info.access_key, ) @@ -400,7 +440,7 @@ def _store_schedule_secrets_using_auth_secret( # are sure we are far enough that it's no longer going to be used (or keep, and use for other things). 
def _store_schedule_secrets( self, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, project: str, name: str, ): @@ -441,7 +481,7 @@ def _store_schedule_secrets( secrets[username_secret_key] = auth_info.username mlrun.api.crud.Secrets().store_project_secrets( project, - schemas.SecretsData( + mlrun.common.schemas.SecretsData( provider=self._secrets_provider, secrets=secrets, ), @@ -545,7 +585,7 @@ def _get_schedule_secrets( def _validate_cron_trigger( self, - cron_trigger: schemas.ScheduleCronTrigger, + cron_trigger: mlrun.common.schemas.ScheduleCronTrigger, # accepting now from outside for testing purposes now: datetime = None, ): @@ -590,7 +630,7 @@ def _validate_cron_trigger( delta=second_next_run_time - next_run_time, ) raise ValueError( - f"Cron trigger too frequent. no more then one job " + f"Cron trigger too frequent. no more than one job " f"per {self._min_allowed_interval} is allowed" ) @@ -598,11 +638,11 @@ def _create_schedule_in_scheduler( self, project: str, name: str, - kind: schemas.ScheduleKinds, + kind: mlrun.common.schemas.ScheduleKinds, scheduled_object: Any, - cron_trigger: schemas.ScheduleCronTrigger, + cron_trigger: mlrun.common.schemas.ScheduleCronTrigger, concurrency_limit: int, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, ): job_id = self._resolve_job_id(project, name) logger.debug("Adding schedule to scheduler", job_id=job_id) @@ -628,11 +668,11 @@ def _update_schedule_in_scheduler( self, project: str, name: str, - kind: schemas.ScheduleKinds, + kind: mlrun.common.schemas.ScheduleKinds, scheduled_object: Any, - cron_trigger: schemas.ScheduleCronTrigger, + cron_trigger: mlrun.common.schemas.ScheduleCronTrigger, concurrency_limit: int, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, ): job_id = self._resolve_job_id(project, name) logger.debug("Updating schedule in scheduler", job_id=job_id) @@ -706,7 +746,7 @@ def 
_reload_schedules(self, db_session: Session): if access_key: need_to_update_credentials = True - auth_info = mlrun.api.schemas.AuthInfo( + auth_info = mlrun.common.schemas.AuthInfo( username=username, access_key=access_key, # enriching with control plane tag because scheduling a function requires control plane @@ -753,21 +793,21 @@ def _reload_schedules(self, db_session: Session): def _transform_and_enrich_db_schedule( self, db_session: Session, - schedule_record: schemas.ScheduleRecord, + schedule_record: mlrun.common.schemas.ScheduleRecord, include_last_run: bool = False, include_credentials: bool = False, - ) -> schemas.ScheduleOutput: + ) -> mlrun.common.schemas.ScheduleOutput: schedule_dict = schedule_record.dict() schedule_dict["labels"] = { label["name"]: label["value"] for label in schedule_dict["labels"] } - schedule = schemas.ScheduleOutput(**schedule_dict) + schedule = mlrun.common.schemas.ScheduleOutput(**schedule_dict) # Schedules are running only on chief. Therefore, we query next_run_time from the scheduler only when # running on chief. 
if ( mlrun.mlconf.httpdb.clusterization.role - == mlrun.api.schemas.ClusterizationRole.chief + == mlrun.common.schemas.ClusterizationRole.chief ): job_id = self._resolve_job_id(schedule_record.project, schedule_record.name) job = self._scheduler.get_job(job_id) @@ -788,7 +828,7 @@ def _transform_and_enrich_db_schedule( @staticmethod def _enrich_schedule_with_last_run( - db_session: Session, schedule_output: schemas.ScheduleOutput + db_session: Session, schedule_output: mlrun.common.schemas.ScheduleOutput ): if schedule_output.last_run_uri: run_project, run_uid, iteration, _ = RunObject.parse_uri( @@ -798,7 +838,7 @@ def _enrich_schedule_with_last_run( schedule_output.last_run = run_data def _enrich_schedule_with_credentials( - self, schedule_output: schemas.ScheduleOutput + self, schedule_output: mlrun.common.schemas.ScheduleOutput ): secret_name = schedule_output.labels.get(self._db_record_auth_label) if secret_name: @@ -808,18 +848,18 @@ def _enrich_schedule_with_credentials( def _resolve_job_function( self, - scheduled_kind: schemas.ScheduleKinds, + scheduled_kind: mlrun.common.schemas.ScheduleKinds, scheduled_object: Any, project_name: str, schedule_name: str, schedule_concurrency_limit: int, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, ) -> Tuple[Callable, Optional[Union[List, Tuple]], Optional[Dict]]: """ :return: a tuple (function, args, kwargs) to be used with the APScheduler.add_job """ - if scheduled_kind == schemas.ScheduleKinds.job: + if scheduled_kind == mlrun.common.schemas.ScheduleKinds.job: scheduled_object_copy = copy.deepcopy(scheduled_object) return ( Scheduler.submit_run_wrapper, @@ -833,7 +873,7 @@ def _resolve_job_function( ], {}, ) - if scheduled_kind == schemas.ScheduleKinds.local_function: + if scheduled_kind == mlrun.common.schemas.ScheduleKinds.local_function: return scheduled_object, [], {} # sanity @@ -851,6 +891,55 @@ def _resolve_job_id(self, project, name) -> str: """ return 
self._job_id_separator.join([project, name]) + @staticmethod + def _enrich_schedule_notifications( + project: str, schedule_name: str, scheduled_object: Union[Dict, Callable] + ): + if not isinstance(scheduled_object, dict): + return + + schedule_notifications = ( + scheduled_object.get("task", {}).get("spec", {}).get("notifications") + ) + if schedule_notifications: + scheduled_object["task"]["spec"]["notifications"] = [ + notification.to_dict() + for notification in mlrun.api.api.utils.validate_and_mask_notification_list( + schedule_notifications, schedule_name, project + ) + ] + + @staticmethod + def _remove_schedule_notification_secrets( + db_session: Session, project: str, schedule_name: str + ): + try: + db_schedule = get_db().get_schedule( + db_session, + project, + schedule_name, + ) + except mlrun.errors.MLRunNotFoundError: + # we allow deleting a schedule even if it does not exist in the DB + logger.debug( + "Failed to find schedule. Continuing", + project=project, + schedule_name=schedule_name, + ) + return + + if db_schedule and isinstance(db_schedule.scheduled_object, dict): + notifications = ( + db_schedule.scheduled_object.get("task", {}) + .get("spec", {}) + .get("notifications") + ) + if notifications: + for notification in notifications: + mlrun.api.api.utils.delete_notification_params_secret( + project, mlrun.model.Notification.from_dict(notification) + ) + @staticmethod async def submit_run_wrapper( scheduler, @@ -858,7 +947,7 @@ async def submit_run_wrapper( project_name, schedule_name, schedule_concurrency_limit, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, ): # removing the schedule from the body otherwise when the scheduler will submit this task it will go to an @@ -872,7 +961,7 @@ async def submit_run_wrapper( if "task" in scheduled_object and "metadata" in scheduled_object["task"]: scheduled_object["task"]["metadata"].setdefault("labels", {}) scheduled_object["task"]["metadata"]["labels"][ - 
schemas.constants.LabelNames.schedule_name + mlrun.common.schemas.constants.LabelNames.schedule_name ] = schedule_name return await fastapi.concurrency.run_in_threadpool( @@ -887,7 +976,7 @@ async def submit_run_wrapper( @staticmethod def transform_schemas_cron_trigger_to_apscheduler_cron_trigger( - cron_trigger: schemas.ScheduleCronTrigger, + cron_trigger: mlrun.common.schemas.ScheduleCronTrigger, ): return APSchedulerCronTrigger( cron_trigger.year, @@ -927,7 +1016,7 @@ def _submit_run_wrapper( db_session, states=RunStates.non_terminal_states(), project=project_name, - labels=f"{schemas.constants.LabelNames.schedule_name}={schedule_name}", + labels=f"{mlrun.common.schemas.constants.LabelNames.schedule_name}={schedule_name}", ) if len(active_runs) >= schedule_concurrency_limit: logger.warn( @@ -965,7 +1054,7 @@ def _submit_run_wrapper( # Update the schedule with the new auth info so we won't need to do the above again in the next run scheduler.update_schedule( db_session, - mlrun.api.schemas.AuthInfo( + mlrun.common.schemas.AuthInfo( username=project_owner.username, access_key=project_owner.access_key, # enriching with control plane tag because scheduling a function requires control plane diff --git a/mlrun/api/utils/singletons/db.py b/mlrun/api/utils/singletons/db.py index c7b3cbe0d908..d31d7df73008 100644 --- a/mlrun/api/utils/singletons/db.py +++ b/mlrun/api/utils/singletons/db.py @@ -13,7 +13,6 @@ # limitations under the License. 
# from mlrun.api.db.base import DBInterface -from mlrun.api.db.filedb.db import FileDB from mlrun.api.db.sqldb.db import SQLDB from mlrun.api.db.sqldb.session import create_session from mlrun.config import config @@ -33,16 +32,11 @@ def initialize_db(override_db=None): if override_db: db = override_db return - if config.httpdb.db_type == "filedb": - logger.info("Creating file db") - db = FileDB(config.httpdb.dirpath) - db.initialize(None) - else: - logger.info("Creating sql db") - db = SQLDB(config.httpdb.dsn) - db_session = None - try: - db_session = create_session() - db.initialize(db_session) - finally: - db_session.close() + logger.info("Creating sql db") + db = SQLDB(config.httpdb.dsn) + db_session = None + try: + db_session = create_session() + db.initialize(db_session) + finally: + db_session.close() diff --git a/mlrun/api/utils/singletons/k8s.py b/mlrun/api/utils/singletons/k8s.py index 184c3ce2b3e3..80f9cc65c44f 100644 --- a/mlrun/api/utils/singletons/k8s.py +++ b/mlrun/api/utils/singletons/k8s.py @@ -11,9 +11,665 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-# -from mlrun.k8s_utils import K8sHelper, get_k8s_helper +import base64 +import hashlib +import time +import typing + +from kubernetes import client, config +from kubernetes.client.rest import ApiException + +import mlrun.common.schemas +import mlrun.config as mlconfig +import mlrun.errors +import mlrun.platforms.iguazio +from mlrun.utils import logger + +_k8s = None + + +def get_k8s_helper(namespace=None, silent=True, log=False) -> "K8sHelper": + """ + Get a k8s helper singleton object + :param namespace: the namespace to use, if not specified will use the namespace configured in mlrun config + :param silent: set to true if you're calling this function from a code that might run from remotely (outside of a + k8s cluster) + :param log: sometimes we want to avoid logging when executing init_k8s_config + """ + global _k8s + if not _k8s: + _k8s = K8sHelper(namespace, silent=silent, log=log) + return _k8s + + +class SecretTypes: + opaque = "Opaque" + v3io_fuse = "v3io/fuse" + + +class K8sHelper: + def __init__(self, namespace=None, silent=False, log=True): + self.namespace = namespace or mlconfig.config.namespace + self.config_file = mlconfig.config.kubernetes.kubeconfig_path or None + self.running_inside_kubernetes_cluster = False + try: + self._init_k8s_config(log) + self.v1api = client.CoreV1Api() + self.crdapi = client.CustomObjectsApi() + except Exception as exc: + logger.warning( + "cannot initialize kubernetes client", exc=mlrun.errors.err_to_str(exc) + ) + if not silent: + raise + + def resolve_namespace(self, namespace=None): + return namespace or self.namespace + + def _init_k8s_config(self, log=True): + try: + config.load_incluster_config() + self.running_inside_kubernetes_cluster = True + if log: + logger.info("using in-cluster config.") + except Exception: + try: + config.load_kube_config(self.config_file) + if log: + logger.info("using local kubernetes config.") + except Exception: + raise RuntimeError( + "cannot find local kubernetes config file," + " 
place it in ~/.kube/config or specify it in " + "KUBECONFIG env var" + ) + + def is_running_inside_kubernetes_cluster(self): + return self.running_inside_kubernetes_cluster + + def list_pods(self, namespace=None, selector="", states=None): + try: + resp = self.v1api.list_namespaced_pod( + self.resolve_namespace(namespace), label_selector=selector + ) + except ApiException as exc: + logger.error(f"failed to list pods: {mlrun.errors.err_to_str(exc)}") + raise exc + + items = [] + for i in resp.items: + if not states or i.status.phase in states: + items.append(i) + return items + + def create_pod(self, pod, max_retry=3, retry_interval=3): + if "pod" in dir(pod): + pod = pod.pod + pod.metadata.namespace = self.resolve_namespace(pod.metadata.namespace) + + retry_count = 0 + while True: + try: + resp = self.v1api.create_namespaced_pod(pod.metadata.namespace, pod) + except ApiException as exc: + + if retry_count > max_retry: + logger.error( + "failed to create pod after max retries", + retry_count=retry_count, + exc=mlrun.errors.err_to_str(exc), + pod=pod, + ) + raise exc + + logger.error( + "failed to create pod", exc=mlrun.errors.err_to_str(exc), pod=pod + ) + + # known k8s issue, see https://github.com/kubernetes/kubernetes/issues/67761 + if "gke-resource-quotas" in mlrun.errors.err_to_str(exc): + logger.warning( + "failed to create pod due to gke resource error, " + f"sleeping {retry_interval} seconds and retrying" + ) + retry_count += 1 + time.sleep(retry_interval) + continue + + raise exc + else: + logger.info(f"Pod {resp.metadata.name} created") + return resp.metadata.name, resp.metadata.namespace + + def delete_pod(self, name, namespace=None): + try: + api_response = self.v1api.delete_namespaced_pod( + name, + self.resolve_namespace(namespace), + grace_period_seconds=0, + propagation_policy="Background", + ) + return api_response + except ApiException as exc: + # ignore error if pod is already removed + if exc.status != 404: + logger.error( + f"failed to delete 
pod: {mlrun.errors.err_to_str(exc)}", + pod_name=name, + ) + raise exc + + def get_pod(self, name, namespace=None, raise_on_not_found=False): + try: + api_response = self.v1api.read_namespaced_pod( + name=name, namespace=self.resolve_namespace(namespace) + ) + return api_response + except ApiException as exc: + if exc.status != 404: + logger.error(f"failed to get pod: {mlrun.errors.err_to_str(exc)}") + raise exc + else: + if raise_on_not_found: + raise mlrun.errors.MLRunNotFoundError(f"Pod not found: {name}") + return None + + def get_pod_status(self, name, namespace=None): + return self.get_pod( + name, namespace, raise_on_not_found=True + ).status.phase.lower() + + def delete_crd(self, name, crd_group, crd_version, crd_plural, namespace=None): + try: + namespace = self.resolve_namespace(namespace) + self.crdapi.delete_namespaced_custom_object( + crd_group, + crd_version, + namespace, + crd_plural, + name, + ) + logger.info( + "Deleted crd object", + crd_name=name, + namespace=namespace, + ) + except ApiException as exc: + + # ignore error if crd is already removed + if exc.status != 404: + logger.error( + f"failed to delete crd: {mlrun.errors.err_to_str(exc)}", + crd_name=name, + crd_group=crd_group, + crd_version=crd_version, + crd_plural=crd_plural, + ) + raise exc + + def logs(self, name, namespace=None): + try: + resp = self.v1api.read_namespaced_pod_log( + name=name, namespace=self.resolve_namespace(namespace) + ) + except ApiException as exc: + logger.error(f"failed to get pod logs: {mlrun.errors.err_to_str(exc)}") + raise exc + + return resp + + def get_logger_pods(self, project, uid, run_kind, namespace=""): + + # As this file is imported in mlrun.runtimes, we sadly cannot have this import in the top level imports + # as that will create an import loop. + # TODO: Fix the import loops already! 
+ import mlrun.runtimes + + namespace = self.resolve_namespace(namespace) + mpijob_crd_version = mlrun.runtimes.utils.resolve_mpijob_crd_version() + mpijob_role_label = ( + mlrun.runtimes.constants.MPIJobCRDVersions.role_label_by_version( + mpijob_crd_version + ) + ) + extra_selectors = { + "spark": "spark-role=driver", + "mpijob": f"{mpijob_role_label}=launcher", + } + + # TODO: all mlrun labels are sprinkled in a lot of places - they need to all be defined in a central, + # inclusive place. + selectors = [ + "mlrun/class", + f"mlrun/project={project}", + f"mlrun/uid={uid}", + ] + + # In order to make the `list_pods` request return a lighter and quicker result, we narrow the search for + # the relevant pods using the proper label selector according to the run kind + if run_kind in extra_selectors: + selectors.append(extra_selectors[run_kind]) + + selector = ",".join(selectors) + pods = self.list_pods(namespace, selector=selector) + if not pods: + logger.error("no pod matches that uid", uid=uid) + return + + return {p.metadata.name: p.status.phase for p in pods} + + def get_project_vault_secret_name( + self, project, service_account_name, namespace="" + ): + namespace = self.resolve_namespace(namespace) + + try: + service_account = self.v1api.read_namespaced_service_account( + service_account_name, namespace + ) + except ApiException as exc: + # It's valid for the service account to not exist. 
Simply return None + if exc.status != 404: + logger.error( + f"failed to retrieve service accounts: {mlrun.errors.err_to_str(exc)}" + ) + raise exc + return None + + if len(service_account.secrets) > 1: + raise ValueError( + f"Service account {service_account_name} has more than one secret" + ) + + return service_account.secrets[0].name + + def get_project_secret_name(self, project) -> str: + return mlconfig.config.secret_stores.kubernetes.project_secret_name.format( + project=project + ) + + def get_auth_secret_name(self, access_key: str) -> str: + hashed_access_key = self._hash_access_key(access_key) + return mlconfig.config.secret_stores.kubernetes.auth_secret_name.format( + hashed_access_key=hashed_access_key + ) + + @staticmethod + def _hash_access_key(access_key: str): + return hashlib.sha224(access_key.encode()).hexdigest() + + def store_project_secrets(self, project, secrets, namespace="") -> (str, bool): + secret_name = self.get_project_secret_name(project) + created = self.store_secrets(secret_name, secrets, namespace) + return secret_name, created + + def read_auth_secret(self, secret_name, namespace="", raise_on_not_found=False): + namespace = self.resolve_namespace(namespace) + + try: + secret_data = self.v1api.read_namespaced_secret(secret_name, namespace).data + except ApiException as exc: + logger.error( + "Failed to read secret", + secret_name=secret_name, + namespace=namespace, + exc=mlrun.errors.err_to_str(exc), + ) + if exc.status != 404: + raise exc + elif raise_on_not_found: + raise mlrun.errors.MLRunNotFoundError( + f"Secret '{secret_name}' was not found in namespace '{namespace}'" + ) from exc + + return None, None + + def _get_secret_value(key): + if secret_data.get(key): + return base64.b64decode(secret_data[key]).decode("utf-8") + else: + return None + + username = _get_secret_value( + mlrun.common.schemas.AuthSecretData.get_field_secret_key("username") + ) + access_key = _get_secret_value( + 
mlrun.common.schemas.AuthSecretData.get_field_secret_key("access_key") + ) + + return username, access_key + + def store_auth_secret( + self, username: str, access_key: str, namespace="" + ) -> (str, bool): + """ + Store the given access key as a secret in the cluster. The secret name is generated from the access key + :return: returns the secret name and a boolean indicating whether the secret was created or updated + """ + secret_name = self.get_auth_secret_name(access_key) + secret_data = { + mlrun.common.schemas.AuthSecretData.get_field_secret_key( + "username" + ): username, + mlrun.common.schemas.AuthSecretData.get_field_secret_key( + "access_key" + ): access_key, + } + created = self.store_secrets( + secret_name, + secret_data, + namespace, + type_=SecretTypes.v3io_fuse, + labels={"mlrun/username": username}, + ) + return secret_name, created + + def store_secrets( + self, + secret_name, + secrets, + namespace="", + type_=SecretTypes.opaque, + labels: typing.Optional[dict] = None, + ) -> bool: + """ + Store secrets in a kubernetes secret object + :return: returns True if the secret was created, False if it already existed and required an update + """ + namespace = self.resolve_namespace(namespace) + try: + k8s_secret = self.v1api.read_namespaced_secret(secret_name, namespace) + except ApiException as exc: + # If secret doesn't exist, we'll simply create it + if exc.status != 404: + logger.error( + f"failed to retrieve k8s secret: {mlrun.errors.err_to_str(exc)}" + ) + raise exc + k8s_secret = client.V1Secret(type=type_) + k8s_secret.metadata = client.V1ObjectMeta( + name=secret_name, namespace=namespace, labels=labels + ) + k8s_secret.string_data = secrets + self.v1api.create_namespaced_secret(namespace, k8s_secret) + return True + + secret_data = k8s_secret.data.copy() + for key, value in secrets.items(): + secret_data[key] = base64.b64encode(value.encode()).decode("utf-8") + + k8s_secret.data = secret_data + self.v1api.replace_namespaced_secret(secret_name, 
namespace, k8s_secret) + return False + + def load_secret(self, secret_name, namespace=""): + namespace = namespace or self.resolve_namespace(namespace) + + try: + k8s_secret = self.v1api.read_namespaced_secret(secret_name, namespace) + except ApiException: + return None + + return k8s_secret.data + + def delete_project_secrets(self, project, secrets, namespace="") -> (str, bool): + """ + Delete secrets from a kubernetes secret object + :return: returns the secret name and a boolean indicating whether the secret was deleted + """ + secret_name = self.get_project_secret_name(project) + deleted = self.delete_secrets(secret_name, secrets, namespace) + return secret_name, deleted + + def delete_auth_secret(self, secret_ref: str, namespace=""): + self.delete_secrets(secret_ref, {}, namespace) + + def delete_secrets(self, secret_name, secrets, namespace="") -> bool: + """ + Delete secrets from a kubernetes secret object + :return: returns True if the secret was deleted, False if it still exists and only deleted part of the keys + """ + namespace = self.resolve_namespace(namespace) + + try: + k8s_secret = self.v1api.read_namespaced_secret(secret_name, namespace) + except ApiException as exc: + # If secret does not exist, return as if the deletion was successfully + if exc.status == 404: + return + else: + logger.error( + f"failed to retrieve k8s secret: {mlrun.errors.err_to_str(exc)}" + ) + raise exc + + if not secrets: + secret_data = {} + else: + secret_data = k8s_secret.data.copy() + for secret in secrets: + secret_data.pop(secret, None) + + if not secret_data: + self.v1api.delete_namespaced_secret(secret_name, namespace) + return True + else: + k8s_secret.data = secret_data + self.v1api.replace_namespaced_secret(secret_name, namespace, k8s_secret) + return False + + def _get_project_secrets_raw_data(self, project, namespace=""): + secret_name = self.get_project_secret_name(project) + return self._get_secret_raw_data(secret_name, namespace) + + def 
_get_secret_raw_data(self, secret_name, namespace=""): + namespace = self.resolve_namespace(namespace) + + try: + k8s_secret = self.v1api.read_namespaced_secret(secret_name, namespace) + except ApiException: + return None + + return k8s_secret.data + + def get_project_secret_keys(self, project, namespace="", filter_internal=False): + secrets_data = self._get_project_secrets_raw_data(project, namespace) + if not secrets_data: + return [] + + secret_keys = list(secrets_data.keys()) + if filter_internal: + secret_keys = list( + filter(lambda key: not key.startswith("mlrun."), secret_keys) + ) + return secret_keys + + def get_project_secret_data(self, project, secret_keys=None, namespace=""): + secrets_data = self._get_project_secrets_raw_data(project, namespace) + return self._decode_secret_data(secrets_data, secret_keys) + + def get_secret_data(self, secret_name, namespace=""): + secrets_data = self._get_secret_raw_data(secret_name, namespace) + return self._decode_secret_data(secrets_data) + + def _decode_secret_data(self, secrets_data, secret_keys=None): + results = {} + if not secrets_data: + return results + + # If not asking for specific keys, return all + secret_keys = secret_keys or secrets_data.keys() + + for key in secret_keys: + encoded_value = secrets_data.get(key) + if encoded_value: + results[key] = base64.b64decode(secrets_data[key]).decode("utf-8") + return results + + +class BasePod: + def __init__( + self, + task_name="", + image=None, + command=None, + args=None, + namespace="", + kind="job", + project=None, + default_pod_spec_attributes=None, + resources=None, + ): + self.namespace = namespace + self.name = "" + self.task_name = task_name + self.image = image + self.command = command + self.args = args + self._volumes = [] + self._mounts = [] + self.env = None + self.node_selector = None + self.project = project or mlrun.mlconf.default_project + self._labels = { + "mlrun/task-name": task_name, + "mlrun/class": kind, + "mlrun/project": self.project, 
+ } + self._annotations = {} + self._init_containers = [] + # will be applied on the pod spec only when calling .pod(), allows to override spec attributes + self.default_pod_spec_attributes = default_pod_spec_attributes + self.resources = resources + + @property + def pod(self): + return self._get_spec() + + @property + def init_containers(self): + return self._init_containers + + @init_containers.setter + def init_containers(self, containers): + self._init_containers = containers + + def append_init_container( + self, + image, + command=None, + args=None, + env=None, + image_pull_policy="IfNotPresent", + name="init", + ): + if isinstance(env, dict): + env = [client.V1EnvVar(name=k, value=v) for k, v in env.items()] + self._init_containers.append( + client.V1Container( + name=name, + image=image, + env=env, + command=command, + args=args, + image_pull_policy=image_pull_policy, + ) + ) + + def add_label(self, key, value): + self._labels[key] = str(value) + + def add_annotation(self, key, value): + self._annotations[key] = str(value) + + def add_volume(self, volume: client.V1Volume, mount_path, name=None, sub_path=None): + self._mounts.append( + client.V1VolumeMount( + name=name or volume.name, mount_path=mount_path, sub_path=sub_path + ) + ) + self._volumes.append(volume) + + def mount_empty(self, name="empty", mount_path="/empty"): + self.add_volume( + client.V1Volume(name=name, empty_dir=client.V1EmptyDirVolumeSource()), + mount_path=mount_path, + ) + + def mount_v3io( + self, name="v3io", remote="~/", mount_path="/User", access_key="", user="" + ): + self.add_volume( + mlrun.platforms.iguazio.v3io_to_vol(name, remote, access_key, user), + mount_path=mount_path, + name=name, + ) + + def mount_cfgmap(self, name, path="/config"): + self.add_volume( + client.V1Volume( + name=name, config_map=client.V1ConfigMapVolumeSource(name=name) + ), + mount_path=path, + ) + + def mount_secret(self, name, path="/secret", items=None, sub_path=None): + self.add_volume( + 
client.V1Volume( + name=name, + secret=client.V1SecretVolumeSource( + secret_name=name, + items=items, + ), + ), + mount_path=path, + sub_path=sub_path, + ) + + def set_node_selector(self, node_selector: typing.Optional[typing.Dict[str, str]]): + self.node_selector = node_selector + + def _get_spec(self, template=False): + + pod_obj = client.V1PodTemplate if template else client.V1Pod + + if self.env and isinstance(self.env, dict): + env = [client.V1EnvVar(name=k, value=v) for k, v in self.env.items()] + else: + env = self.env + container = client.V1Container( + name="base", + image=self.image, + env=env, + command=self.command, + args=self.args, + volume_mounts=self._mounts, + resources=self.resources, + ) + + pod_spec = client.V1PodSpec( + containers=[container], + restart_policy="Never", + volumes=self._volumes, + node_selector=self.node_selector, + ) + + # if attribute isn't defined use default pod spec attributes + for key, val in self.default_pod_spec_attributes.items(): + if not getattr(pod_spec, key, None): + setattr(pod_spec, key, val) + for init_containers in self._init_containers: + init_containers.volume_mounts = self._mounts + pod_spec.init_containers = self._init_containers -def get_k8s() -> K8sHelper: - return get_k8s_helper(silent=True) + pod = pod_obj( + metadata=client.V1ObjectMeta( + generate_name=f"{self.task_name}-", + namespace=self.namespace, + labels=self._labels, + annotations=self._annotations, + ), + spec=pod_spec, + ) + return pod diff --git a/mlrun/artifacts/__init__.py b/mlrun/artifacts/__init__.py index 627947c4579d..3e08428bbcb9 100644 --- a/mlrun/artifacts/__init__.py +++ b/mlrun/artifacts/__init__.py @@ -17,7 +17,7 @@ # Don't remove this, used by sphinx documentation __all__ = ["get_model", "update_model"] -from .base import Artifact, get_artifact_meta +from .base import Artifact, ArtifactMetadata, ArtifactSpec, get_artifact_meta from .dataset import DatasetArtifact, TableArtifact, update_dataset_meta from .manager import 
ArtifactManager, ArtifactProducer, dict_to_artifact from .model import ModelArtifact, get_model, update_model diff --git a/mlrun/artifacts/base.py b/mlrun/artifacts/base.py index 642550eefbaf..250a27ba3dc2 100644 --- a/mlrun/artifacts/base.py +++ b/mlrun/artifacts/base.py @@ -83,6 +83,7 @@ class ArtifactSpec(ModelObj): "size", "db_key", "extra_data", + "unpackaging_instructions", ] _extra_fields = ["annotations", "producer", "sources", "license", "encoding"] @@ -98,6 +99,7 @@ def __init__( db_key=None, extra_data=None, body=None, + unpackaging_instructions: dict = None, ): self.src_path = src_path self.target_path = target_path @@ -107,6 +109,7 @@ def __init__( self.size = size self.db_key = db_key self.extra_data = extra_data or {} + self.unpackaging_instructions = unpackaging_instructions self._body = body self.encoding = None diff --git a/mlrun/artifacts/dataset.py b/mlrun/artifacts/dataset.py index b7b3d1743b37..47483a417e10 100644 --- a/mlrun/artifacts/dataset.py +++ b/mlrun/artifacts/dataset.py @@ -22,10 +22,11 @@ from pandas.io.json import build_table_schema import mlrun +import mlrun.common.schemas import mlrun.utils.helpers from ..datastore import is_store_uri, store_manager -from .base import Artifact, ArtifactSpec, LegacyArtifact +from .base import Artifact, ArtifactSpec, LegacyArtifact, StorePrefix default_preview_rows_length = 20 max_preview_columns = 100 @@ -122,9 +123,10 @@ def __init__(self): class DatasetArtifact(Artifact): - kind = "dataset" + kind = mlrun.common.schemas.ArtifactCategories.dataset # List of all the supported saving formats of a DataFrame: SUPPORTED_FORMATS = ["csv", "parquet", "pq", "tsdb", "kv"] + _store_prefix = StorePrefix.Dataset def __init__( self, diff --git a/mlrun/artifacts/manager.py b/mlrun/artifacts/manager.py index 647ecba49cf0..fed8e36a55eb 100644 --- a/mlrun/artifacts/manager.py +++ b/mlrun/artifacts/manager.py @@ -191,6 +191,12 @@ def log_artifact( if db_key is None: # set the default artifact db key if 
producer.kind == "run": + # When the producer's type is "run," + # we generate a different db_key than the one we obtained in the request. + # As a result, a new artifact for the requested key will be created, + # which will contain the new db_key and will represent the current run. + # We implement this so that the user can query an artifact, + # and receive back all the runs that are associated with his search result. db_key = producer.name + "_" + key else: db_key = key diff --git a/tests/notebooks.yml b/mlrun/common/__init__.py similarity index 83% rename from tests/notebooks.yml rename to mlrun/common/__init__.py index 3887daddbc63..b3085be1eb56 100644 --- a/tests/notebooks.yml +++ b/mlrun/common/__init__.py @@ -12,9 +12,3 @@ # See the License for the specific language governing permissions and # limitations under the License. # -- nb: mlrun_db.ipynb -- nb: mlrun_basics.ipynb - env: - MLRUN_DBPATH: /tmp/mlrun-db - pip: - - matplotlib diff --git a/mlrun/common/constants.py b/mlrun/common/constants.py new file mode 100644 index 000000000000..380ec1b97ab5 --- /dev/null +++ b/mlrun/common/constants.py @@ -0,0 +1,15 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +IMAGE_NAME_ENRICH_REGISTRY_PREFIX = "." 
# prefix for image name to enrich with registry diff --git a/mlrun/common/model_monitoring.py b/mlrun/common/model_monitoring.py new file mode 100644 index 000000000000..4093cd6cab6b --- /dev/null +++ b/mlrun/common/model_monitoring.py @@ -0,0 +1,209 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import enum +import hashlib +from dataclasses import dataclass +from typing import Optional + +import mlrun.utils + + +class EventFieldType: + FUNCTION_URI = "function_uri" + FUNCTION = "function" + MODEL_URI = "model_uri" + MODEL = "model" + VERSION = "version" + VERSIONED_MODEL = "versioned_model" + MODEL_CLASS = "model_class" + TIMESTAMP = "timestamp" + # `endpoint_id` is deprecated as a field in the model endpoint schema since 1.3.1, replaced by `uid`. 
+ ENDPOINT_ID = "endpoint_id" + UID = "uid" + ENDPOINT_TYPE = "endpoint_type" + REQUEST_ID = "request_id" + RECORD_TYPE = "record_type" + FEATURES = "features" + FEATURE_NAMES = "feature_names" + NAMED_FEATURES = "named_features" + LABELS = "labels" + LATENCY = "latency" + LABEL_NAMES = "label_names" + PREDICTION = "prediction" + PREDICTIONS = "predictions" + NAMED_PREDICTIONS = "named_predictions" + ERROR_COUNT = "error_count" + ENTITIES = "entities" + FIRST_REQUEST = "first_request" + LAST_REQUEST = "last_request" + METRICS = "metrics" + TIME_FORMAT = "%Y-%m-%d %H:%M:%S.%f" + BATCH_INTERVALS_DICT = "batch_intervals_dict" + DEFAULT_BATCH_INTERVALS = "default_batch_intervals" + MINUTES = "minutes" + HOURS = "hours" + DAYS = "days" + MODEL_ENDPOINTS = "model_endpoints" + STATE = "state" + PROJECT = "project" + STREAM_PATH = "stream_path" + ACTIVE = "active" + MONITORING_MODE = "monitoring_mode" + FEATURE_STATS = "feature_stats" + CURRENT_STATS = "current_stats" + CHILDREN = "children" + CHILDREN_UIDS = "children_uids" + DRIFT_MEASURES = "drift_measures" + DRIFT_STATUS = "drift_status" + MONITOR_CONFIGURATION = "monitor_configuration" + FEATURE_SET_URI = "monitoring_feature_set_uri" + ALGORITHM = "algorithm" + + +class EventLiveStats: + LATENCY_AVG_5M = "latency_avg_5m" + LATENCY_AVG_1H = "latency_avg_1h" + PREDICTIONS_PER_SECOND = "predictions_per_second" + PREDICTIONS_COUNT_5M = "predictions_count_5m" + PREDICTIONS_COUNT_1H = "predictions_count_1h" + + +class EventKeyMetrics: + BASE_METRICS = "base_metrics" + CUSTOM_METRICS = "custom_metrics" + ENDPOINT_FEATURES = "endpoint_features" + GENERIC = "generic" + REAL_TIME = "real_time" + + +class TimeSeriesTarget: + TSDB = "tsdb" + + +class ModelEndpointTarget: + V3IO_NOSQL = "v3io-nosql" + SQL = "sql" + + +class ProjectSecretKeys: + ENDPOINT_STORE_CONNECTION = "MODEL_MONITORING_ENDPOINT_STORE_CONNECTION" + ACCESS_KEY = "MODEL_MONITORING_ACCESS_KEY" + KAFKA_BOOTSTRAP_SERVERS = "KAFKA_BOOTSTRAP_SERVERS" + STREAM_PATH = 
"STREAM_PATH" + + +class ModelMonitoringStoreKinds: + ENDPOINTS = "endpoints" + EVENTS = "events" + + +class FileTargetKind: + ENDPOINTS = "endpoints" + EVENTS = "events" + STREAM = "stream" + PARQUET = "parquet" + LOG_STREAM = "log_stream" + + +class ModelMonitoringMode(str, enum.Enum): + enabled = "enabled" + disabled = "disabled" + + +class EndpointType(enum.IntEnum): + NODE_EP = 1 # end point that is not a child of a router + ROUTER = 2 # endpoint that is router + LEAF_EP = 3 # end point that is a child of a router + + +def create_model_endpoint_uid(function_uri: str, versioned_model: str): + function_uri = FunctionURI.from_string(function_uri) + versioned_model = VersionedModel.from_string(versioned_model) + + if ( + not function_uri.project + or not function_uri.function + or not versioned_model.model + ): + raise ValueError("Both function_uri and versioned_model have to be initialized") + + uid = EndpointUID( + function_uri.project, + function_uri.function, + function_uri.tag, + function_uri.hash_key, + versioned_model.model, + versioned_model.version, + ) + + return uid + + +@dataclass +class FunctionURI: + project: str + function: str + tag: Optional[str] = None + hash_key: Optional[str] = None + + @classmethod + def from_string(cls, function_uri): + project, uri, tag, hash_key = mlrun.utils.parse_versioned_object_uri( + function_uri + ) + return cls( + project=project, + function=uri, + tag=tag or None, + hash_key=hash_key or None, + ) + + +@dataclass +class VersionedModel: + model: str + version: Optional[str] + + @classmethod + def from_string(cls, model): + try: + model, version = model.split(":") + except ValueError: + model, version = model, None + + return cls(model, version) + + +@dataclass +class EndpointUID: + project: str + function: str + function_tag: str + function_hash_key: str + model: str + model_version: str + uid: Optional[str] = None + + def __post_init__(self): + function_ref = ( + f"{self.function}_{self.function_tag or 
self.function_hash_key or 'N/A'}" + ) + versioned_model = f"{self.model}_{self.model_version or 'N/A'}" + unique_string = f"{self.project}_{function_ref}_{versioned_model}" + self.uid = hashlib.sha1(unique_string.encode("utf-8")).hexdigest() + + def __str__(self): + return self.uid diff --git a/mlrun/common/schemas/__init__.py b/mlrun/common/schemas/__init__.py new file mode 100644 index 000000000000..b067fd6b79f2 --- /dev/null +++ b/mlrun/common/schemas/__init__.py @@ -0,0 +1,166 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx + +from .artifact import ArtifactCategories, ArtifactIdentifier, ArtifactsFormat +from .auth import ( + AuthInfo, + AuthorizationAction, + AuthorizationResourceTypes, + AuthorizationVerificationInput, + Credentials, + ProjectsRole, +) +from .background_task import ( + BackgroundTask, + BackgroundTaskMetadata, + BackgroundTaskSpec, + BackgroundTaskState, + BackgroundTaskStatus, +) +from .client_spec import ClientSpec +from .clusterization_spec import ( + ClusterizationSpec, + WaitForChiefToReachOnlineStateFeatureFlag, +) +from .constants import ( + APIStates, + ClusterizationRole, + DeletionStrategy, + FeatureStorePartitionByField, + HeaderNames, + LogsCollectorMode, + OrderType, + PatchMode, + RunPartitionByField, + SortField, +) +from .events import ( + AuthSecretEventActions, + EventClientKinds, + EventsModes, + SecretEventActions, +) +from .feature_store import ( + EntitiesOutput, + Entity, + EntityListOutput, + EntityRecord, + Feature, + FeatureListOutput, + FeatureRecord, + FeatureSet, + FeatureSetDigestOutput, + FeatureSetDigestSpec, + FeatureSetIngestInput, + FeatureSetIngestOutput, + FeatureSetRecord, + FeatureSetsOutput, + FeatureSetSpec, + FeatureSetsTagsOutput, + FeaturesOutput, + FeatureVector, + FeatureVectorRecord, + FeatureVectorsOutput, + FeatureVectorsTagsOutput, +) +from .frontend_spec import ( + AuthenticationFeatureFlag, + FeatureFlags, + FrontendSpec, + NuclioStreamsFeatureFlag, + PreemptionNodesFeatureFlag, + ProjectMembershipFeatureFlag, +) +from .function import FunctionState, PreemptionModes, SecurityContextEnrichmentModes +from .http import HTTPSessionRetryMode +from .hub import ( + HubCatalog, + HubItem, + HubObjectMetadata, + HubSource, + HubSourceSpec, + IndexedHubSource, + last_source_index, +) +from .k8s import NodeSelectorOperator, Resources, ResourceSpec +from .memory_reports import MostCommonObjectTypesReport, ObjectTypeReport 
+from .model_endpoints import ( + Features, + FeatureValues, + GrafanaColumn, + GrafanaDataPoint, + GrafanaNumberColumn, + GrafanaStringColumn, + GrafanaTable, + GrafanaTimeSeriesTarget, + ModelEndpoint, + ModelEndpointList, + ModelEndpointMetadata, + ModelEndpointSpec, + ModelEndpointStatus, + ModelMonitoringStoreKinds, +) +from .notification import ( + Notification, + NotificationKind, + NotificationSeverity, + NotificationStatus, + SetNotificationRequest, +) +from .object import ObjectKind, ObjectMetadata, ObjectSpec, ObjectStatus +from .pipeline import PipelinesFormat, PipelinesOutput, PipelinesPagination +from .project import ( + IguazioProject, + Project, + ProjectDesiredState, + ProjectMetadata, + ProjectOwner, + ProjectsFormat, + ProjectsOutput, + ProjectSpec, + ProjectState, + ProjectStatus, + ProjectSummariesOutput, + ProjectSummary, +) +from .runs import RunIdentifier +from .runtime_resource import ( + GroupedByJobRuntimeResourcesOutput, + GroupedByProjectRuntimeResourcesOutput, + KindRuntimeResources, + ListRuntimeResourcesGroupByField, + RuntimeResource, + RuntimeResources, + RuntimeResourcesOutput, +) +from .schedule import ( + ScheduleCronTrigger, + ScheduleIdentifier, + ScheduleInput, + ScheduleKinds, + ScheduleOutput, + ScheduleRecord, + SchedulesOutput, + ScheduleUpdate, +) +from .secret import ( + AuthSecretData, + SecretKeysData, + SecretProviderName, + SecretsData, + UserSecretCreationRequest, +) +from .tag import Tag, TagObjects diff --git a/mlrun/api/schemas/artifact.py b/mlrun/common/schemas/artifact.py similarity index 64% rename from mlrun/api/schemas/artifact.py rename to mlrun/common/schemas/artifact.py index 1474ff567060..a61200661cc9 100644 --- a/mlrun/api/schemas/artifact.py +++ b/mlrun/common/schemas/artifact.py @@ -16,31 +16,30 @@ import pydantic -import mlrun.api.utils.helpers +import mlrun.common.types -class ArtifactCategories(mlrun.api.utils.helpers.StrEnum): +class ArtifactCategories(mlrun.common.types.StrEnum): model = "model" 
dataset = "dataset" other = "other" - def to_kinds_filter(self) -> typing.Tuple[typing.List[str], bool]: - # FIXME: these artifact definitions (or at least the kinds enum) should sit in a dedicated module - # import here to prevent import cycle - import mlrun.artifacts.dataset - import mlrun.artifacts.model + # we define the link as a category to prevent import cycles, but it's not a real category + # and should not be used as such + link = "link" - link_kind = mlrun.artifacts.base.LinkArtifact.kind + def to_kinds_filter(self) -> typing.Tuple[typing.List[str], bool]: + link_kind = ArtifactCategories.link.value if self.value == ArtifactCategories.model.value: - return [mlrun.artifacts.model.ModelArtifact.kind, link_kind], False + return [ArtifactCategories.model.value, link_kind], False if self.value == ArtifactCategories.dataset.value: - return [mlrun.artifacts.dataset.DatasetArtifact.kind, link_kind], False + return [ArtifactCategories.dataset.value, link_kind], False if self.value == ArtifactCategories.other.value: return ( [ - mlrun.artifacts.model.ModelArtifact.kind, - mlrun.artifacts.dataset.DatasetArtifact.kind, + ArtifactCategories.model.value, + ArtifactCategories.dataset.value, ], True, ) @@ -56,6 +55,6 @@ class ArtifactIdentifier(pydantic.BaseModel): # hash: typing.Optional[str] -class ArtifactsFormat(mlrun.api.utils.helpers.StrEnum): +class ArtifactsFormat(mlrun.common.types.StrEnum): full = "full" legacy = "legacy" diff --git a/mlrun/api/schemas/auth.py b/mlrun/common/schemas/auth.py similarity index 91% rename from mlrun/api/schemas/auth.py rename to mlrun/common/schemas/auth.py index e6968525779b..c27ef378c844 100644 --- a/mlrun/api/schemas/auth.py +++ b/mlrun/common/schemas/auth.py @@ -18,17 +18,17 @@ from nuclio.auth import AuthInfo as NuclioAuthInfo from nuclio.auth import AuthKinds as NuclioAuthKinds -import mlrun.api.utils.helpers +import mlrun.common.types -class ProjectsRole(mlrun.api.utils.helpers.StrEnum): +class 
ProjectsRole(mlrun.common.types.StrEnum): iguazio = "iguazio" mlrun = "mlrun" nuclio = "nuclio" nop = "nop" -class AuthorizationAction(mlrun.api.utils.helpers.StrEnum): +class AuthorizationAction(mlrun.common.types.StrEnum): read = "read" create = "create" update = "update" @@ -39,7 +39,7 @@ class AuthorizationAction(mlrun.api.utils.helpers.StrEnum): store = "store" -class AuthorizationResourceTypes(mlrun.api.utils.helpers.StrEnum): +class AuthorizationResourceTypes(mlrun.common.types.StrEnum): project = "project" log = "log" runtime_resource = "runtime-resource" @@ -56,7 +56,7 @@ class AuthorizationResourceTypes(mlrun.api.utils.helpers.StrEnum): run = "run" model_endpoint = "model-endpoint" pipeline = "pipeline" - marketplace_source = "marketplace-source" + hub_source = "hub-source" def to_resource_string( self, @@ -85,8 +85,8 @@ def to_resource_string( AuthorizationResourceTypes.runtime_resource: "/projects/{project_name}/runtime-resources", AuthorizationResourceTypes.model_endpoint: "/projects/{project_name}/model-endpoints/{resource_name}", AuthorizationResourceTypes.pipeline: "/projects/{project_name}/pipelines/{resource_name}", - # Marketplace sources are not project-scoped, and auth is globally on the sources endpoint. - AuthorizationResourceTypes.marketplace_source: "/marketplace/sources", + # Hub sources are not project-scoped, and auth is globally on the sources endpoint. 
+ AuthorizationResourceTypes.hub_source: "/hub/sources", }[self].format(project_name=project_name, resource_name=resource_name) diff --git a/mlrun/api/schemas/background_task.py b/mlrun/common/schemas/background_task.py similarity index 94% rename from mlrun/api/schemas/background_task.py rename to mlrun/common/schemas/background_task.py index 1a174cbb9459..a9fa1f25af2c 100644 --- a/mlrun/api/schemas/background_task.py +++ b/mlrun/common/schemas/background_task.py @@ -17,12 +17,12 @@ import pydantic -import mlrun.api.utils.helpers +import mlrun.common.types from .object import ObjectKind -class BackgroundTaskState(mlrun.api.utils.helpers.StrEnum): +class BackgroundTaskState(mlrun.common.types.StrEnum): succeeded = "succeeded" failed = "failed" running = "running" diff --git a/mlrun/api/schemas/client_spec.py b/mlrun/common/schemas/client_spec.py similarity index 100% rename from mlrun/api/schemas/client_spec.py rename to mlrun/common/schemas/client_spec.py diff --git a/mlrun/api/schemas/clusterization_spec.py b/mlrun/common/schemas/clusterization_spec.py similarity index 87% rename from mlrun/api/schemas/clusterization_spec.py rename to mlrun/common/schemas/clusterization_spec.py index 9f77d90e953d..1d9ed1bc7bb8 100644 --- a/mlrun/api/schemas/clusterization_spec.py +++ b/mlrun/common/schemas/clusterization_spec.py @@ -16,7 +16,7 @@ import pydantic -import mlrun.api.utils.helpers +import mlrun.common.types class ClusterizationSpec(pydantic.BaseModel): @@ -24,6 +24,6 @@ class ClusterizationSpec(pydantic.BaseModel): chief_version: typing.Optional[str] -class WaitForChiefToReachOnlineStateFeatureFlag(mlrun.api.utils.helpers.StrEnum): +class WaitForChiefToReachOnlineStateFeatureFlag(mlrun.common.types.StrEnum): enabled = "enabled" disabled = "disabled" diff --git a/mlrun/api/schemas/constants.py b/mlrun/common/schemas/constants.py similarity index 82% rename from mlrun/api/schemas/constants.py rename to mlrun/common/schemas/constants.py index 31d3897dfb46..2170af9453b7 
100644 --- a/mlrun/api/schemas/constants.py +++ b/mlrun/common/schemas/constants.py @@ -14,11 +14,11 @@ # import mergedeep -import mlrun.api.utils.helpers +import mlrun.common.types import mlrun.errors -class PatchMode(mlrun.api.utils.helpers.StrEnum): +class PatchMode(mlrun.common.types.StrEnum): replace = "replace" additive = "additive" @@ -33,7 +33,7 @@ def to_mergedeep_strategy(self) -> mergedeep.Strategy: ) -class DeletionStrategy(mlrun.api.utils.helpers.StrEnum): +class DeletionStrategy(mlrun.common.types.StrEnum): restrict = "restrict" restricted = "restricted" cascade = "cascade" @@ -97,7 +97,7 @@ class HeaderNames: ui_clear_cache = f"{headers_prefix}ui-clear-cache" -class FeatureStorePartitionByField(mlrun.api.utils.helpers.StrEnum): +class FeatureStorePartitionByField(mlrun.common.types.StrEnum): name = "name" # Supported for feature-store objects def to_partition_by_db_field(self, db_cls): @@ -109,7 +109,7 @@ def to_partition_by_db_field(self, db_cls): ) -class RunPartitionByField(mlrun.api.utils.helpers.StrEnum): +class RunPartitionByField(mlrun.common.types.StrEnum): name = "name" # Supported for runs objects def to_partition_by_db_field(self, db_cls): @@ -121,7 +121,7 @@ def to_partition_by_db_field(self, db_cls): ) -class SortField(mlrun.api.utils.helpers.StrEnum): +class SortField(mlrun.common.types.StrEnum): created = "created" updated = "updated" @@ -139,7 +139,7 @@ def to_db_field(self, db_cls): ) -class OrderType(mlrun.api.utils.helpers.StrEnum): +class OrderType(mlrun.common.types.StrEnum): asc = "asc" desc = "desc" @@ -170,6 +170,19 @@ class APIStates: def terminal_states(): return [APIStates.online, APIStates.offline] + @staticmethod + def description(state: str): + return { + APIStates.online: "API is online", + APIStates.waiting_for_migrations: "API is waiting for migrations to be triggered. 
" + "Send POST request to /api/operations/migrations to trigger it", + APIStates.migrations_in_progress: "Migrations are in progress", + APIStates.migrations_failed: "Migrations failed, API can't be started", + APIStates.migrations_completed: "Migrations completed, API is waiting to become online", + APIStates.offline: "API is offline", + APIStates.waiting_for_chief: "API is waiting for chief to be ready", + }.get(state, f"Unknown API state '{state}'") + class ClusterizationRole: chief = "chief" diff --git a/mlrun/common/schemas/events.py b/mlrun/common/schemas/events.py new file mode 100644 index 000000000000..966d4078e9c9 --- /dev/null +++ b/mlrun/common/schemas/events.py @@ -0,0 +1,36 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import mlrun.common.types + + +class EventsModes(mlrun.common.types.StrEnum): + enabled = "enabled" + disabled = "disabled" + + +class EventClientKinds(mlrun.common.types.StrEnum): + iguazio = "iguazio" + nop = "nop" + + +class SecretEventActions(mlrun.common.types.StrEnum): + created = "created" + updated = "updated" + deleted = "deleted" + + +class AuthSecretEventActions(mlrun.common.types.StrEnum): + created = "created" + updated = "updated" diff --git a/mlrun/api/schemas/feature_store.py b/mlrun/common/schemas/feature_store.py similarity index 100% rename from mlrun/api/schemas/feature_store.py rename to mlrun/common/schemas/feature_store.py diff --git a/mlrun/api/schemas/frontend_spec.py b/mlrun/common/schemas/frontend_spec.py similarity index 88% rename from mlrun/api/schemas/frontend_spec.py rename to mlrun/common/schemas/frontend_spec.py index 35ff1c2febfd..d8821292bbc0 100644 --- a/mlrun/api/schemas/frontend_spec.py +++ b/mlrun/common/schemas/frontend_spec.py @@ -16,29 +16,29 @@ import pydantic -import mlrun.api.utils.helpers +import mlrun.common.types from .k8s import Resources -class ProjectMembershipFeatureFlag(mlrun.api.utils.helpers.StrEnum): +class ProjectMembershipFeatureFlag(mlrun.common.types.StrEnum): enabled = "enabled" disabled = "disabled" -class PreemptionNodesFeatureFlag(mlrun.api.utils.helpers.StrEnum): +class PreemptionNodesFeatureFlag(mlrun.common.types.StrEnum): enabled = "enabled" disabled = "disabled" -class AuthenticationFeatureFlag(mlrun.api.utils.helpers.StrEnum): +class AuthenticationFeatureFlag(mlrun.common.types.StrEnum): none = "none" basic = "basic" bearer = "bearer" iguazio = "iguazio" -class NuclioStreamsFeatureFlag(mlrun.api.utils.helpers.StrEnum): +class NuclioStreamsFeatureFlag(mlrun.common.types.StrEnum): enabled = "enabled" disabled = "disabled" diff --git a/mlrun/api/schemas/function.py b/mlrun/common/schemas/function.py similarity index 93% rename from mlrun/api/schemas/function.py rename to 
mlrun/common/schemas/function.py index 078f53bafdf1..ca5fd24421a6 100644 --- a/mlrun/api/schemas/function.py +++ b/mlrun/common/schemas/function.py @@ -16,10 +16,10 @@ import pydantic -import mlrun.api.utils.helpers +import mlrun.common.types -# Ideally we would want this to be class FunctionState(mlrun.api.utils.helpers.StrEnum) which is the +# Ideally we would want this to be class FunctionState(mlrun.common.types.StrEnum) which is the # "FastAPI-compatible" way of creating schemas # But, when we save a function to the DB, we pickle the body, which saves the state as an instance of this class (and # not just a string), then if for some reason we downgrade to 0.6.4, before we had this class, we fail reading (pickle @@ -46,7 +46,7 @@ class FunctionState: build = "build" -class PreemptionModes(mlrun.api.utils.helpers.StrEnum): +class PreemptionModes(mlrun.common.types.StrEnum): # makes function pods be able to run on preemptible nodes allow = "allow" # makes the function pods run on preemptible nodes only @@ -59,7 +59,7 @@ class PreemptionModes(mlrun.api.utils.helpers.StrEnum): # used when running in Iguazio (otherwise use disabled mode) # populates mlrun.mlconf.function.spec.security_context.enrichment_mode -class SecurityContextEnrichmentModes(mlrun.api.utils.helpers.StrEnum): +class SecurityContextEnrichmentModes(mlrun.common.types.StrEnum): # always use the user id of the user that triggered the 1st run / created the function # NOTE: this mode is incomplete and not fully supported yet retain = "retain" diff --git a/mlrun/api/schemas/http.py b/mlrun/common/schemas/http.py similarity index 87% rename from mlrun/api/schemas/http.py rename to mlrun/common/schemas/http.py index 640d75613df0..0b95a1e84f84 100644 --- a/mlrun/api/schemas/http.py +++ b/mlrun/common/schemas/http.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import mlrun.api.utils.helpers +import mlrun.common.types -class HTTPSessionRetryMode(mlrun.api.utils.helpers.StrEnum): +class HTTPSessionRetryMode(mlrun.common.types.StrEnum): enabled = "enabled" disabled = "disabled" diff --git a/mlrun/api/schemas/marketplace.py b/mlrun/common/schemas/hub.py similarity index 57% rename from mlrun/api/schemas/marketplace.py rename to mlrun/common/schemas/hub.py index fda43d3deff7..49b7ab1e9b1b 100644 --- a/mlrun/api/schemas/marketplace.py +++ b/mlrun/common/schemas/hub.py @@ -17,15 +17,15 @@ from pydantic import BaseModel, Extra, Field -import mlrun.api.utils.helpers +import mlrun.common.types import mlrun.errors -from mlrun.api.schemas.object import ObjectKind, ObjectSpec, ObjectStatus +from mlrun.common.schemas.object import ObjectKind, ObjectSpec, ObjectStatus from mlrun.config import config # Defining a different base class (not ObjectMetadata), as there's no project, and it differs enough to # justify a new class -class MarketplaceObjectMetadata(BaseModel): +class HubObjectMetadata(BaseModel): name: str description: str = "" labels: Optional[dict] = {} @@ -37,24 +37,22 @@ class Config: # Currently only functions are supported. Will add more in the future. -class MarketplaceSourceType(mlrun.api.utils.helpers.StrEnum): +class HubSourceType(mlrun.common.types.StrEnum): functions = "functions" # Sources-related objects -class MarketplaceSourceSpec(ObjectSpec): +class HubSourceSpec(ObjectSpec): path: str # URL to base directory, should include schema (s3://, etc...) 
channel: str credentials: Optional[dict] = {} - object_type: MarketplaceSourceType = Field( - MarketplaceSourceType.functions, const=True - ) + object_type: HubSourceType = Field(HubSourceType.functions, const=True) -class MarketplaceSource(BaseModel): - kind: ObjectKind = Field(ObjectKind.marketplace_source, const=True) - metadata: MarketplaceObjectMetadata - spec: MarketplaceSourceSpec +class HubSource(BaseModel): + kind: ObjectKind = Field(ObjectKind.hub_source, const=True) + metadata: HubObjectMetadata + spec: HubSourceSpec status: Optional[ObjectStatus] = ObjectStatus(state="created") def get_full_uri(self, relative_path): @@ -66,28 +64,26 @@ def get_full_uri(self, relative_path): ) def get_catalog_uri(self): - return self.get_full_uri(config.marketplace.catalog_filename) + return self.get_full_uri(config.hub.catalog_filename) @classmethod def generate_default_source(cls): - if not config.marketplace.default_source.create: + if not config.hub.default_source.create: return None now = datetime.now(timezone.utc) - hub_metadata = MarketplaceObjectMetadata( - name=config.marketplace.default_source.name, - description=config.marketplace.default_source.description, + hub_metadata = HubObjectMetadata( + name=config.hub.default_source.name, + description=config.hub.default_source.description, created=now, updated=now, ) return cls( metadata=hub_metadata, - spec=MarketplaceSourceSpec( - path=config.marketplace.default_source.url, - channel=config.marketplace.default_source.channel, - object_type=MarketplaceSourceType( - config.marketplace.default_source.object_type - ), + spec=HubSourceSpec( + path=config.hub.default_source.url, + channel=config.hub.default_source.channel, + object_type=HubSourceType(config.hub.default_source.object_type), ), status=ObjectStatus(state="created"), ) @@ -96,43 +92,43 @@ def generate_default_source(cls): last_source_index = -1 -class IndexedMarketplaceSource(BaseModel): +class IndexedHubSource(BaseModel): index: int = last_source_index # 
Default last. Otherwise, must be > 0 - source: MarketplaceSource + source: HubSource # Item-related objects -class MarketplaceItemMetadata(MarketplaceObjectMetadata): - source: MarketplaceSourceType = Field(MarketplaceSourceType.functions, const=True) +class HubItemMetadata(HubObjectMetadata): + source: HubSourceType = Field(HubSourceType.functions, const=True) version: str tag: Optional[str] def get_relative_path(self) -> str: - if self.source == MarketplaceSourceType.functions: - # This is needed since the marketplace deployment script modifies the paths to use _ instead of -. + if self.source == HubSourceType.functions: + # This is needed since the hub deployment script modifies the paths to use _ instead of -. modified_name = self.name.replace("-", "_") # Prefer using the tag if exists. Otherwise, use version. version = self.tag or self.version return f"{modified_name}/{version}/" else: raise mlrun.errors.MLRunInvalidArgumentError( - f"Bad source for marketplace item - {self.source}" + f"Bad source for hub item - {self.source}" ) -class MarketplaceItemSpec(ObjectSpec): +class HubItemSpec(ObjectSpec): item_uri: str + assets: Dict[str, str] = {} -class MarketplaceItem(BaseModel): - kind: ObjectKind = Field(ObjectKind.marketplace_item, const=True) - metadata: MarketplaceItemMetadata - spec: MarketplaceItemSpec +class HubItem(BaseModel): + kind: ObjectKind = Field(ObjectKind.hub_item, const=True) + metadata: HubItemMetadata + spec: HubItemSpec status: ObjectStatus - assets: Dict[str, str] = {} -class MarketplaceCatalog(BaseModel): - kind: ObjectKind = Field(ObjectKind.marketplace_catalog, const=True) +class HubCatalog(BaseModel): + kind: ObjectKind = Field(ObjectKind.hub_catalog, const=True) channel: str - catalog: List[MarketplaceItem] + catalog: List[HubItem] diff --git a/mlrun/api/schemas/k8s.py b/mlrun/common/schemas/k8s.py similarity index 93% rename from mlrun/api/schemas/k8s.py rename to mlrun/common/schemas/k8s.py index 3ab15cd3090b..ca93b16c340c 100644 --- 
a/mlrun/api/schemas/k8s.py +++ b/mlrun/common/schemas/k8s.py @@ -16,7 +16,7 @@ import pydantic -import mlrun.api.utils.helpers +import mlrun.common.types class ResourceSpec(pydantic.BaseModel): @@ -30,7 +30,7 @@ class Resources(pydantic.BaseModel): limits: ResourceSpec = ResourceSpec() -class NodeSelectorOperator(mlrun.api.utils.helpers.StrEnum): +class NodeSelectorOperator(mlrun.common.types.StrEnum): """ A node selector operator is the set of operators that can be used in a node selector requirement https://github.com/kubernetes/api/blob/b754a94214be15ffc8d648f9fe6481857f1fc2fe/core/v1/types.go#L2765 diff --git a/mlrun/api/schemas/memory_reports.py b/mlrun/common/schemas/memory_reports.py similarity index 100% rename from mlrun/api/schemas/memory_reports.py rename to mlrun/common/schemas/memory_reports.py diff --git a/mlrun/common/schemas/model_endpoints.py b/mlrun/common/schemas/model_endpoints.py new file mode 100644 index 000000000000..ec36c738601c --- /dev/null +++ b/mlrun/common/schemas/model_endpoints.py @@ -0,0 +1,342 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import enum +import json +import typing +from typing import Any, Dict, List, Optional, Tuple, Union + +from pydantic import BaseModel, Field +from pydantic.main import Extra + +import mlrun.common.model_monitoring +from mlrun.common.schemas.object import ObjectKind, ObjectSpec, ObjectStatus + + +class ModelMonitoringStoreKinds: + # TODO: do changes in examples & demos In 1.5.0 remove + ENDPOINTS = "endpoints" + EVENTS = "events" + + +class ModelEndpointMetadata(BaseModel): + project: Optional[str] = "" + labels: Optional[dict] = {} + uid: Optional[str] = "" + + class Config: + extra = Extra.allow + + @classmethod + def from_flat_dict(cls, endpoint_dict: dict, json_parse_values: typing.List = None): + """Create a `ModelEndpointMetadata` object from an endpoint dictionary + + :param endpoint_dict: Model endpoint dictionary. + :param json_parse_values: List of dictionary keys with a JSON string value that will be parsed into a + dictionary using json.loads(). + """ + new_object = cls() + if json_parse_values is None: + json_parse_values = [mlrun.common.model_monitoring.EventFieldType.LABELS] + + return _mapping_attributes( + base_model=new_object, + flattened_dictionary=endpoint_dict, + json_parse_values=json_parse_values, + ) + + +class ModelEndpointSpec(ObjectSpec): + function_uri: Optional[str] = "" # /: + model: Optional[str] = "" # : + model_class: Optional[str] = "" + model_uri: Optional[str] = "" + feature_names: Optional[List[str]] = [] + label_names: Optional[List[str]] = [] + stream_path: Optional[str] = "" + algorithm: Optional[str] = "" + monitor_configuration: Optional[dict] = {} + active: Optional[bool] = True + monitoring_mode: Optional[ + mlrun.common.model_monitoring.ModelMonitoringMode + ] = mlrun.common.model_monitoring.ModelMonitoringMode.disabled.value + + @classmethod + def from_flat_dict(cls, endpoint_dict: dict, json_parse_values: typing.List = None): + """Create a `ModelEndpointSpec` object from an endpoint dictionary + + :param 
endpoint_dict: Model endpoint dictionary. + :param json_parse_values: List of dictionary keys with a JSON string value that will be parsed into a + dictionary using json.loads(). + """ + new_object = cls() + if json_parse_values is None: + json_parse_values = [ + mlrun.common.model_monitoring.EventFieldType.FEATURE_NAMES, + mlrun.common.model_monitoring.EventFieldType.LABEL_NAMES, + mlrun.common.model_monitoring.EventFieldType.MONITOR_CONFIGURATION, + ] + return _mapping_attributes( + base_model=new_object, + flattened_dictionary=endpoint_dict, + json_parse_values=json_parse_values, + ) + + +class Histogram(BaseModel): + buckets: List[float] + counts: List[int] + + +class FeatureValues(BaseModel): + min: float + mean: float + max: float + histogram: Histogram + + @classmethod + def from_dict(cls, stats: Optional[dict]): + if stats: + return FeatureValues( + min=stats["min"], + mean=stats["mean"], + max=stats["max"], + histogram=Histogram(buckets=stats["hist"][1], counts=stats["hist"][0]), + ) + else: + return None + + +class Features(BaseModel): + name: str + weight: float + expected: Optional[FeatureValues] + actual: Optional[FeatureValues] + + @classmethod + def new( + cls, + feature_name: str, + feature_stats: Optional[dict], + current_stats: Optional[dict], + ): + return cls( + name=feature_name, + weight=-1.0, + expected=FeatureValues.from_dict(feature_stats), + actual=FeatureValues.from_dict(current_stats), + ) + + +class ModelEndpointStatus(ObjectStatus): + feature_stats: Optional[dict] = {} + current_stats: Optional[dict] = {} + first_request: Optional[str] = "" + last_request: Optional[str] = "" + error_count: Optional[int] = 0 + drift_status: Optional[str] = "" + drift_measures: Optional[dict] = {} + metrics: Optional[Dict[str, Dict[str, Any]]] = { + mlrun.common.model_monitoring.EventKeyMetrics.GENERIC: { + mlrun.common.model_monitoring.EventLiveStats.LATENCY_AVG_1H: 0, + mlrun.common.model_monitoring.EventLiveStats.PREDICTIONS_PER_SECOND: 0, + } + } + 
features: Optional[List[Features]] = [] + children: Optional[List[str]] = [] + children_uids: Optional[List[str]] = [] + endpoint_type: Optional[ + mlrun.common.model_monitoring.EndpointType + ] = mlrun.common.model_monitoring.EndpointType.NODE_EP.value + monitoring_feature_set_uri: Optional[str] = "" + state: Optional[str] = "" + + class Config: + extra = Extra.allow + + @classmethod + def from_flat_dict(cls, endpoint_dict: dict, json_parse_values: typing.List = None): + """Create a `ModelEndpointStatus` object from an endpoint dictionary + + :param endpoint_dict: Model endpoint dictionary. + :param json_parse_values: List of dictionary keys with a JSON string value that will be parsed into a + dictionary using json.loads(). + """ + new_object = cls() + if json_parse_values is None: + json_parse_values = [ + mlrun.common.model_monitoring.EventFieldType.FEATURE_STATS, + mlrun.common.model_monitoring.EventFieldType.CURRENT_STATS, + mlrun.common.model_monitoring.EventFieldType.DRIFT_MEASURES, + mlrun.common.model_monitoring.EventFieldType.METRICS, + mlrun.common.model_monitoring.EventFieldType.CHILDREN, + mlrun.common.model_monitoring.EventFieldType.CHILDREN_UIDS, + mlrun.common.model_monitoring.EventFieldType.ENDPOINT_TYPE, + ] + return _mapping_attributes( + base_model=new_object, + flattened_dictionary=endpoint_dict, + json_parse_values=json_parse_values, + ) + + +class ModelEndpoint(BaseModel): + kind: ObjectKind = Field(ObjectKind.model_endpoint, const=True) + metadata: ModelEndpointMetadata = ModelEndpointMetadata() + spec: ModelEndpointSpec = ModelEndpointSpec() + status: ModelEndpointStatus = ModelEndpointStatus() + + class Config: + extra = Extra.allow + + def __init__(self, **data: Any): + super().__init__(**data) + if self.metadata.uid is None: + uid = mlrun.common.model_monitoring.create_model_endpoint_uid( + function_uri=self.spec.function_uri, + versioned_model=self.spec.model, + ) + self.metadata.uid = str(uid) + + def flat_dict(self): + """Generate a 
flattened `ModelEndpoint` dictionary. The flattened dictionary result is important for storing + the model endpoint object in the database. + + :return: Flattened `ModelEndpoint` dictionary. + """ + # Convert the ModelEndpoint object into a dictionary using BaseModel dict() function + # In addition, remove the BaseModel kind as it is not required by the DB schema + model_endpoint_dictionary = self.dict(exclude={"kind"}) + + # Initialize a flattened dictionary that will be filled with the model endpoint dictionary attributes + flatten_dict = {} + for k_object in model_endpoint_dictionary: + for key in model_endpoint_dictionary[k_object]: + # Extract the value of the current field + current_value = model_endpoint_dictionary[k_object][key] + + # If the value is not from type str or bool (e.g. dict), convert it into a JSON string + # for matching the database required format + if not isinstance(current_value, (str, bool, int)) or isinstance( + current_value, enum.IntEnum + ): + flatten_dict[key] = json.dumps(current_value) + else: + flatten_dict[key] = current_value + + if mlrun.common.model_monitoring.EventFieldType.METRICS not in flatten_dict: + # Initialize metrics dictionary + flatten_dict[mlrun.common.model_monitoring.EventFieldType.METRICS] = { + mlrun.common.model_monitoring.EventKeyMetrics.GENERIC: { + mlrun.common.model_monitoring.EventLiveStats.LATENCY_AVG_1H: 0, + mlrun.common.model_monitoring.EventLiveStats.PREDICTIONS_PER_SECOND: 0, + } + } + + # Remove the features from the dictionary as this field will be filled only within the feature analysis process + flatten_dict.pop(mlrun.common.model_monitoring.EventFieldType.FEATURES, None) + return flatten_dict + + @classmethod + def from_flat_dict(cls, endpoint_dict: dict): + """Create a `ModelEndpoint` object from an endpoint flattened dictionary. Because the provided dictionary + is flattened, we pass it as is to the subclasses without splitting the keys into spec, metadata, and status. 
+ + :param endpoint_dict: Model endpoint dictionary. + """ + + return cls( + metadata=ModelEndpointMetadata.from_flat_dict(endpoint_dict=endpoint_dict), + spec=ModelEndpointSpec.from_flat_dict(endpoint_dict=endpoint_dict), + status=ModelEndpointStatus.from_flat_dict(endpoint_dict=endpoint_dict), + ) + + +class ModelEndpointList(BaseModel): + endpoints: List[ModelEndpoint] = [] + + +class GrafanaColumn(BaseModel): + text: str + type: str + + +class GrafanaNumberColumn(GrafanaColumn): + text: str + type: str = "number" + + +class GrafanaStringColumn(GrafanaColumn): + text: str + type: str = "string" + + +class GrafanaTable(BaseModel): + columns: List[GrafanaColumn] + rows: List[List[Optional[Union[float, int, str]]]] = [] + type: str = "table" + + def add_row(self, *args): + self.rows.append(list(args)) + + +class GrafanaDataPoint(BaseModel): + value: float + timestamp: int # Unix timestamp in milliseconds + + +class GrafanaTimeSeriesTarget(BaseModel): + target: str + datapoints: List[Tuple[float, int]] = [] + + def add_data_point(self, data_point: GrafanaDataPoint): + self.datapoints.append((data_point.value, data_point.timestamp)) + + +def _mapping_attributes( + base_model: BaseModel, + flattened_dictionary: dict, + json_parse_values: typing.List = None, +): + """Generate a `BaseModel` object with the provided dictionary attributes. + + :param base_model: `BaseModel` object (e.g. `ModelEndpointMetadata`). + :param flattened_dictionary: Flattened dictionary that contains the model endpoint attributes. + :param json_parse_values: List of dictionary keys with a JSON string value that will be parsed into a + dictionary using json.loads(). + """ + # Get the fields of the provided base model object. These fields will be used to filter to relevent keys + # from the flattened dictionary. 
+ wanted_keys = base_model.__fields__.keys() + + # Generate a filtered flattened dictionary that will be parsed into the BaseModel object + dict_to_parse = {} + for field_key in wanted_keys: + if field_key in flattened_dictionary: + if field_key in json_parse_values: + # Parse the JSON value into a valid dictionary + dict_to_parse[field_key] = _json_loads_if_not_none( + flattened_dictionary[field_key] + ) + else: + dict_to_parse[field_key] = flattened_dictionary[field_key] + + return base_model.parse_obj(dict_to_parse) + + +def _json_loads_if_not_none(field: Any) -> Any: + return ( + json.loads(field) if field and field != "null" and field is not None else None + ) diff --git a/mlrun/common/schemas/notification.py b/mlrun/common/schemas/notification.py new file mode 100644 index 000000000000..cc489c97f32b --- /dev/null +++ b/mlrun/common/schemas/notification.py @@ -0,0 +1,57 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import datetime +import typing + +import pydantic + +import mlrun.common.types + + +class NotificationKind(mlrun.common.types.StrEnum): + console = "console" + git = "git" + ipython = "ipython" + slack = "slack" + + +class NotificationSeverity(mlrun.common.types.StrEnum): + INFO = "info" + DEBUG = "debug" + VERBOSE = "verbose" + WARNING = "warning" + ERROR = "error" + + +class NotificationStatus(mlrun.common.types.StrEnum): + PENDING = "pending" + SENT = "sent" + ERROR = "error" + + +class Notification(pydantic.BaseModel): + kind: NotificationKind = None + name: str = None + message: str = None + severity: NotificationSeverity = None + when: typing.List[str] = None + condition: str = None + params: typing.Dict[str, typing.Any] = None + status: NotificationStatus = None + sent_time: typing.Union[str, datetime.datetime] = None + + +class SetNotificationRequest(pydantic.BaseModel): + notifications: typing.List[Notification] = None diff --git a/mlrun/api/schemas/object.py b/mlrun/common/schemas/object.py similarity index 89% rename from mlrun/api/schemas/object.py rename to mlrun/common/schemas/object.py index e5f34746a7d6..f0cad67021a3 100644 --- a/mlrun/api/schemas/object.py +++ b/mlrun/common/schemas/object.py @@ -17,7 +17,7 @@ from pydantic import BaseModel, Extra -import mlrun.api.utils.helpers +import mlrun.common.types class ObjectMetadata(BaseModel): @@ -69,12 +69,12 @@ class Config: orm_mode = True -class ObjectKind(mlrun.api.utils.helpers.StrEnum): +class ObjectKind(mlrun.common.types.StrEnum): project = "project" feature_set = "FeatureSet" background_task = "BackgroundTask" feature_vector = "FeatureVector" model_endpoint = "model-endpoint" - marketplace_source = "MarketplaceSource" - marketplace_item = "MarketplaceItem" - marketplace_catalog = "MarketplaceCatalog" + hub_source = "HubSource" + hub_item = "HubItem" + hub_catalog = "HubCatalog" diff --git a/mlrun/api/schemas/pipeline.py b/mlrun/common/schemas/pipeline.py similarity index 92% rename from 
mlrun/api/schemas/pipeline.py rename to mlrun/common/schemas/pipeline.py index 30211c158c8c..e1e3815794dc 100644 --- a/mlrun/api/schemas/pipeline.py +++ b/mlrun/common/schemas/pipeline.py @@ -16,10 +16,10 @@ import pydantic -import mlrun.api.utils.helpers +import mlrun.common.types -class PipelinesFormat(mlrun.api.utils.helpers.StrEnum): +class PipelinesFormat(mlrun.common.types.StrEnum): full = "full" metadata_only = "metadata_only" summary = "summary" diff --git a/mlrun/api/schemas/project.py b/mlrun/common/schemas/project.py similarity index 89% rename from mlrun/api/schemas/project.py rename to mlrun/common/schemas/project.py index 6d81446d077b..40afe67792ed 100644 --- a/mlrun/api/schemas/project.py +++ b/mlrun/common/schemas/project.py @@ -17,14 +17,17 @@ import pydantic -import mlrun.api.utils.helpers +import mlrun.common.types from .object import ObjectKind, ObjectStatus -class ProjectsFormat(mlrun.api.utils.helpers.StrEnum): +class ProjectsFormat(mlrun.common.types.StrEnum): full = "full" name_only = "name_only" + # minimal format removes large fields from the response (e.g. 
functions, workflows, artifacts) + # and is used for faster response times (in the UI) + minimal = "minimal" # internal - allowed only in follower mode, only for the leader for upgrade purposes leader = "leader" @@ -39,13 +42,13 @@ class Config: extra = pydantic.Extra.allow -class ProjectDesiredState(mlrun.api.utils.helpers.StrEnum): +class ProjectDesiredState(mlrun.common.types.StrEnum): online = "online" offline = "offline" archived = "archived" -class ProjectState(mlrun.api.utils.helpers.StrEnum): +class ProjectState(mlrun.common.types.StrEnum): unknown = "unknown" creating = "creating" deleting = "deleting" @@ -80,6 +83,7 @@ class ProjectSpec(pydantic.BaseModel): subpath: typing.Optional[str] = None origin_url: typing.Optional[str] = None desired_state: typing.Optional[ProjectDesiredState] = ProjectDesiredState.online + custom_packagers: typing.Optional[typing.List[typing.Tuple[str, bool]]] = None class Config: extra = pydantic.Extra.allow diff --git a/mlrun/common/schemas/runs.py b/mlrun/common/schemas/runs.py new file mode 100644 index 000000000000..4c8abfd0758d --- /dev/null +++ b/mlrun/common/schemas/runs.py @@ -0,0 +1,30 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing + +# TODO: When we remove support for python 3.7, we can use Literal from the typing package. +# Remove the following try/except block with import from typing_extensions. 
+try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +import pydantic + + +class RunIdentifier(pydantic.BaseModel): + kind: Literal["run"] = "run" + uid: typing.Optional[str] + iter: typing.Optional[int] diff --git a/mlrun/api/schemas/runtime_resource.py b/mlrun/common/schemas/runtime_resource.py similarity index 93% rename from mlrun/api/schemas/runtime_resource.py rename to mlrun/common/schemas/runtime_resource.py index 3fb9d204b279..332c27b67086 100644 --- a/mlrun/api/schemas/runtime_resource.py +++ b/mlrun/common/schemas/runtime_resource.py @@ -16,10 +16,10 @@ import pydantic -import mlrun.api.utils.helpers +import mlrun.common.types -class ListRuntimeResourcesGroupByField(mlrun.api.utils.helpers.StrEnum): +class ListRuntimeResourcesGroupByField(mlrun.common.types.StrEnum): job = "job" project = "project" diff --git a/mlrun/api/schemas/schedule.py b/mlrun/common/schemas/schedule.py similarity index 88% rename from mlrun/api/schemas/schedule.py rename to mlrun/common/schemas/schedule.py index 08a6df5822d3..3ef981b3989f 100644 --- a/mlrun/api/schemas/schedule.py +++ b/mlrun/common/schemas/schedule.py @@ -15,11 +15,18 @@ from datetime import datetime from typing import Any, List, Optional, Union +# TODO: When we remove support for python 3.7, we can use Literal from the typing package. +# Remove the following try/except block with import from typing_extensions. 
+try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + from pydantic import BaseModel -import mlrun.api.utils.helpers -from mlrun.api.schemas.auth import Credentials -from mlrun.api.schemas.object import LabelRecord +import mlrun.common.types +from mlrun.common.schemas.auth import Credentials +from mlrun.common.schemas.object import LabelRecord class ScheduleCronTrigger(BaseModel): @@ -78,7 +85,7 @@ def to_crontab(self) -> str: return f"{self.minute} {self.hour} {self.day} {self.month} {self.day_of_week}" -class ScheduleKinds(mlrun.api.utils.helpers.StrEnum): +class ScheduleKinds(mlrun.common.types.StrEnum): job = "job" pipeline = "pipeline" @@ -136,3 +143,8 @@ class ScheduleOutput(ScheduleRecord): class SchedulesOutput(BaseModel): schedules: List[ScheduleOutput] + + +class ScheduleIdentifier(BaseModel): + kind: Literal["schedule"] = "schedule" + name: str diff --git a/mlrun/api/schemas/secret.py b/mlrun/common/schemas/secret.py similarity index 93% rename from mlrun/api/schemas/secret.py rename to mlrun/common/schemas/secret.py index 5b842d5dadc9..27cac5d6a62d 100644 --- a/mlrun/api/schemas/secret.py +++ b/mlrun/common/schemas/secret.py @@ -16,10 +16,10 @@ from pydantic import BaseModel, Field -import mlrun.api.utils.helpers +import mlrun.common.types -class SecretProviderName(mlrun.api.utils.helpers.StrEnum): +class SecretProviderName(mlrun.common.types.StrEnum): """Enum containing names of valid providers for secrets.""" vault = "vault" diff --git a/mlrun/api/schemas/tag.py b/mlrun/common/schemas/tag.py similarity index 93% rename from mlrun/api/schemas/tag.py rename to mlrun/common/schemas/tag.py index 90d3bd3a2670..2bcab5ef7be6 100644 --- a/mlrun/api/schemas/tag.py +++ b/mlrun/common/schemas/tag.py @@ -29,4 +29,4 @@ class TagObjects(pydantic.BaseModel): kind: str # TODO: Add more types to the list for new supported tagged objects - identifiers: typing.List[typing.Union[ArtifactIdentifier]] + identifiers: 
typing.List[ArtifactIdentifier] diff --git a/mlrun/api/schemas/notification.py b/mlrun/common/types.py similarity index 64% rename from mlrun/api/schemas/notification.py rename to mlrun/common/types.py index ac5591c0addd..92ce98e61e98 100644 --- a/mlrun/api/schemas/notification.py +++ b/mlrun/common/types.py @@ -11,20 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# +import enum -import mlrun.api.utils.helpers - - -class NotificationSeverity(mlrun.api.utils.helpers.StrEnum): - INFO = "info" - DEBUG = "debug" - VERBOSE = "verbose" - WARNING = "warning" - ERROR = "error" +# TODO: From python 3.11 StrEnum is built-in and this will not be needed +class StrEnum(str, enum.Enum): + def __str__(self): + return self.value -class NotificationStatus(mlrun.api.utils.helpers.StrEnum): - PENDING = "pending" - SENT = "sent" - ERROR = "error" + def __repr__(self): + return self.value diff --git a/mlrun/config.py b/mlrun/config.py index c289eaea0b0d..32765b7c93f7 100644 --- a/mlrun/config.py +++ b/mlrun/config.py @@ -48,6 +48,10 @@ default_config = { "namespace": "", # default kubernetes namespace + "kubernetes": { + "kubeconfig_path": "", # local path to kubeconfig file (for development purposes), + # empty by default as the API already running inside k8s cluster + }, "dbpath": "", # db/api url # url to nuclio dashboard api (can be with user & token, e.g. 
https://username:password@dashboard-url.com) "nuclio_dashboard_url": "", @@ -74,9 +78,10 @@ "spark_app_image_tag": "", # image tag to use for spark operator app runtime "spark_history_server_path": "", # spark logs directory for spark history server "spark_operator_version": "spark-3", # the version of the spark operator in use - "builder_alpine_image": "alpine:3.13.1", # builder alpine image (as kaniko's initContainer) "package_path": "mlrun", # mlrun pip package "default_base_image": "mlrun/mlrun", # default base image when doing .deploy() + # template for project default image name. Parameter {name} will be replaced with project name + "default_project_image_name": ".mlrun-project-image-{name}", "default_project": "default", # default project name "default_archive": "", # default remote archive URL (for build tar.gz) "mpijob_crd_version": "", # mpijob crd version (e.g: "v1alpha1". must be in: mlrun.runtime.MPIJobCRDVersions) @@ -155,7 +160,7 @@ # default security context to be applied to all functions - json string base64 encoded format # in camelCase format: {"runAsUser": 1000, "runAsGroup": 3000} "default": "e30=", # encoded empty dict - # see mlrun.api.schemas.function.SecurityContextEnrichmentModes for available options + # see mlrun.common.schemas.function.SecurityContextEnrichmentModes for available options "enrichment_mode": "disabled", # default 65534 (nogroup), set to -1 to use the user unix id or # function.spec.security_context.pipelines.kfp_pod_user_unix_id for kfp pods @@ -178,7 +183,7 @@ "mpijob": "mlrun/ml-models", }, # see enrich_function_preemption_spec for more info, - # and mlrun.api.schemas.function.PreemptionModes for available options + # and mlrun.common.schemas.function.PreemptionModes for available options "preemption_mode": "prevent", }, "httpdb": { @@ -219,7 +224,7 @@ "allowed_file_paths": "s3://,gcs://,gs://,az://", "db_type": "sqldb", "max_workers": 64, - # See mlrun.api.schemas.APIStates for options + # See 
mlrun.common.schemas.APIStates for options "state": "online", "retry_api_call_on_exception": "enabled", "http_connection_timeout_keep_alive": 11, @@ -230,10 +235,10 @@ "conflict_retry_interval": None, # Whether to perform data migrations on initialization. enabled or disabled "data_migrations_mode": "enabled", - # Whether or not to perform database migration from sqlite to mysql on initialization + # Whether to perform database migration from sqlite to mysql on initialization "database_migration_mode": "enabled", "backup": { - # Whether or not to use db backups on initialization + # Whether to use db backups on initialization "mode": "enabled", "file_format": "db_backup_%Y%m%d%H%M.db", "use_rotation": True, @@ -244,6 +249,14 @@ # None will set this to be equal to the httpdb.max_workers "connections_pool_size": None, "connections_pool_max_overflow": None, + # below is a db-specific configuration + "mysql": { + # comma separated mysql modes (globally) to set on runtime + # optional values (as per https://dev.mysql.com/doc/refman/8.0/en/sql-mode.html#sql-mode-full): + # + # if set to "nil" or "none", nothing would be set + "modes": "STRICT_TRANS_TABLES", + }, }, "jobs": { # whether to allow to run local runtimes in the API - configurable to allow the scheduler testing to work @@ -357,9 +370,12 @@ # git+https://github.com/mlrun/mlrun@development. 
by default uses the version "mlrun_version_specifier": "", "kaniko_image": "gcr.io/kaniko-project/executor:v1.8.0", # kaniko builder image - "kaniko_init_container_image": "alpine:3.13.1", + "kaniko_init_container_image": "alpine:3.18", # image for kaniko init container when docker registry is ECR "kaniko_aws_cli_image": "amazon/aws-cli:2.7.10", + # kaniko sometimes fails to get filesystem from image, this is a workaround to retry the process + # a known issue in Kaniko - https://github.com/GoogleContainerTools/kaniko/issues/1717 + "kaniko_image_fs_extraction_retries": "3", # additional docker build args in json encoded base64 format "build_args": "", "pip_ca_secret_name": "", @@ -372,18 +388,37 @@ }, "v3io_api": "", "v3io_framesd": "", + # If running from sdk and MLRUN_DBPATH is not set, the db will fallback to a nop db which will not preform any + # run db operations. + "nop_db": { + # if set to true, will raise an error for trying to use run db functionality + # if set to false, will use a nop db which will not preform any run db operations + "raise_error": False, + # if set to true, will log a warning for trying to use run db functionality while in nop db mode + "verbose": True, + }, }, "model_endpoint_monitoring": { "serving_stream_args": {"shard_count": 1, "retention_period_hours": 24}, "drift_thresholds": {"default": {"possible_drift": 0.5, "drift_detected": 0.7}}, + # Store prefixes are used to handle model monitoring storing policies based on project and kind, such as events, + # stream, and endpoints. "store_prefixes": { "default": "v3io:///users/pipelines/{project}/model-endpoints/{kind}", "user_space": "v3io:///projects/{project}/model-endpoints/{kind}", + "stream": "", }, + # Offline storage path can be either relative or a full path. 
This path is used for general offline data + # storage such as the parquet file which is generated from the monitoring stream function for the drift analysis + "offline_storage_path": "model-endpoints/{kind}", + # Default http path that points to the monitoring stream nuclio function. Will be used as a stream path + # when the user is working in CE environment and has not provided any stream path. + "default_http_sink": "http://nuclio-{project}-model-monitoring-stream.mlrun.svc.cluster.local:8080", "batch_processing_function_branch": "master", "parquet_batching_max_events": 10000, - # See mlrun.api.schemas.ModelEndpointStoreType for available options - "store_type": "kv", + # See mlrun.common.schemas.ModelEndpointStoreType for available options + "store_type": "v3io-nosql", + "endpoint_store_connection": "", }, "secret_stores": { "vault": { @@ -427,8 +462,8 @@ "projects_prefix": "projects", # The UI link prefix for projects "url": "", # remote/external mlrun UI url (for hyperlinks) }, - "marketplace": { - "k8s_secrets_project_name": "-marketplace-secrets", + "hub": { + "k8s_secrets_project_name": "-hub-secrets", "catalog_filename": "catalog.json", "default_source": { # Set false to avoid creating a global source (for example in a dark site) @@ -508,6 +543,27 @@ # interval for stopping log collection for runs which are in a terminal state "stop_logs_interval": 3600, }, + # Configurations for the `mlrun.package` sub-package involving packagers - logging returned outputs and parsing + # inputs data items: + "packagers": { + # Whether to enable packagers. True will wrap each run in the `mlrun.package.handler` decorator to log and parse + # using packagers. + "enabled": True, + # Whether to treat returned tuples from functions as a tuple and not as multiple returned items. If True, all + # returned values will be packaged together as the tuple they are returned in. Default is False to enable + # logging multiple returned items. 
+ "pack_tuples": False, + }, + # Events are currently (and only) used to audit changes and record access to MLRun entities (such as secrets) + "events": { + # supported modes "enabled", "disabled". + # "enabled" - events are emitted. + # "disabled" - a nop client is used (aka doing nothing). + "mode": "enabled", + "verbose": False, + # used for igz client when emitting events + "access_key": "", + }, } _is_running_as_api = None @@ -518,8 +574,7 @@ def is_running_as_api(): global _is_running_as_api if _is_running_as_api is None: - # os.getenv will load the env var as string, and json.loads will convert it to a bool - _is_running_as_api = json.loads(os.getenv("MLRUN_IS_API_SERVER", "false")) + _is_running_as_api = os.getenv("MLRUN_IS_API_SERVER", "false").lower() == "true" return _is_running_as_api @@ -927,6 +982,68 @@ def get_v3io_access_key(self): # Get v3io access key from the environment return os.environ.get("V3IO_ACCESS_KEY") + def get_model_monitoring_file_target_path( + self, + project: str = "", + kind: str = "", + target: str = "online", + artifact_path: str = None, + ) -> str: + """Get the full path from the configuration based on the provided project and kind. + + :param project: Project name. + :param kind: Kind of target path (e.g. events, log_stream, endpoints, etc.) + :param target: Can be either online or offline. If the target is online, then we try to get a specific + path for the provided kind. If it doesn't exist, use the default path. + If the target path is offline and the offline path is already a full path in the + configuration, then the result will be that path as-is. If the offline path is a + relative path, then the result will be based on the project artifact path and the offline + relative path. If project artifact path wasn't provided, then we use MLRun artifact + path instead. + :param artifact_path: Optional artifact path that will be used as a relative path. 
If not provided, the + relative artifact path will be taken from the global MLRun artifact path. + + :return: Full configured path for the provided kind. + """ + + if target != "offline": + store_prefix_dict = ( + mlrun.mlconf.model_endpoint_monitoring.store_prefixes.to_dict() + ) + if store_prefix_dict.get(kind): + # Target exist in store prefix and has a valid string value + return store_prefix_dict[kind].format(project=project) + return mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default.format( + project=project, kind=kind + ) + + # Get the current offline path from the configuration + file_path = mlrun.mlconf.model_endpoint_monitoring.offline_storage_path.format( + project=project, kind=kind + ) + + # Absolute path + if any(value in file_path for value in ["://", ":///"]) or os.path.isabs( + file_path + ): + return file_path + + # Relative path + else: + artifact_path = artifact_path or config.artifact_path + if artifact_path[-1] != "/": + artifact_path += "/" + + return mlrun.utils.helpers.fill_artifact_path_template( + artifact_path=artifact_path + file_path, project=project + ) + + def is_ce_mode(self) -> bool: + # True if the setup is in CE environment + return isinstance(mlrun.mlconf.ce, mlrun.config.Config) and any( + ver in mlrun.mlconf.ce.mode for ver in ["lite", "full"] + ) + # Global configuration config = Config.from_dict(default_config) @@ -947,7 +1064,7 @@ def _populate(skip_errors=False): def _do_populate(env=None, skip_errors=False): global config - if not os.environ.get("MLRUN_IGNORE_ENV_FILE") and not is_running_as_api(): + if not os.environ.get("MLRUN_IGNORE_ENV_FILE"): if "MLRUN_ENV_FILE" in os.environ: env_file = os.path.expanduser(os.environ["MLRUN_ENV_FILE"]) dotenv.load_dotenv(env_file, override=True) @@ -984,12 +1101,10 @@ def _do_populate(env=None, skip_errors=False): def _validate_config(config): - import mlrun.k8s_utils - try: limits_gpu = config.default_function_pod_resources.limits.gpu requests_gpu = 
config.default_function_pod_resources.requests.gpu - mlrun.k8s_utils.verify_gpu_requests_and_limits( + _verify_gpu_requests_and_limits( requests_gpu=requests_gpu, limits_gpu=limits_gpu, ) @@ -999,6 +1114,19 @@ def _validate_config(config): config.verify_security_context_enrichment_mode_is_allowed() +def _verify_gpu_requests_and_limits(requests_gpu: str = None, limits_gpu: str = None): + # https://kubernetes.io/docs/tasks/manage-gpus/scheduling-gpus/ + if requests_gpu and not limits_gpu: + raise mlrun.errors.MLRunConflictError( + "You cannot specify GPU requests without specifying limits" + ) + if requests_gpu and limits_gpu and requests_gpu != limits_gpu: + raise mlrun.errors.MLRunConflictError( + f"When specifying both GPU requests and limits these two values must be equal, " + f"requests_gpu={requests_gpu}, limits_gpu={limits_gpu}" + ) + + def _convert_resources_to_str(config: dict = None): resources_types = ["cpu", "memory", "gpu"] resource_requirements = ["requests", "limits"] @@ -1049,15 +1177,18 @@ def read_env(env=None, prefix=env_prefix): cfg[path[0]] = value env_dbpath = env.get("MLRUN_DBPATH", "") + # expected format: https://mlrun-api.tenant.default-tenant.app.some-system.some-namespace.com is_remote_mlrun = ( env_dbpath.startswith("https://mlrun-api.") and "tenant." 
in env_dbpath ) + # It's already a standard to set this env var to configure the v3io api, so we're supporting it (instead # of MLRUN_V3IO_API), in remote usage this can be auto detected from the DBPATH v3io_api = env.get("V3IO_API") if v3io_api: config["v3io_api"] = v3io_api elif is_remote_mlrun: + # in remote mlrun we can't use http, so we'll use https config["v3io_api"] = env_dbpath.replace("https://mlrun-api.", "https://webapi.") # It's already a standard to set this env var to configure the v3io framesd, so we're supporting it (instead diff --git a/mlrun/data_types/spark.py b/mlrun/data_types/spark.py index 2e6ef44eef53..9da70288e054 100644 --- a/mlrun/data_types/spark.py +++ b/mlrun/data_types/spark.py @@ -16,6 +16,8 @@ from os import environ import numpy as np +import pytz +from pyspark.sql.functions import to_utc_timestamp from pyspark.sql.types import BooleanType, DoubleType, TimestampType from mlrun.utils import logger @@ -143,6 +145,9 @@ def get_df_stats_spark(df, options, num_bins=20, sample_size=None): is_timestamp = isinstance(field.dataType, TimestampType) is_boolean = isinstance(field.dataType, BooleanType) if is_timestamp: + df_after_type_casts = df_after_type_casts.withColumn( + field.name, to_utc_timestamp(df_after_type_casts[field.name], "UTC") + ) timestamp_columns.add(field.name) if is_boolean: boolean_columns.add(field.name) @@ -210,11 +215,13 @@ def get_df_stats_spark(df, options, num_bins=20, sample_size=None): if col in timestamp_columns: for stat, val in stats.items(): if stat == "mean" or stat in original_type_stats: - stats[stat] = datetime.fromtimestamp(val).isoformat() + stats[stat] = datetime.fromtimestamp(val, tz=pytz.UTC).isoformat() elif stat == "hist": values = stats[stat][1] for i in range(len(values)): - values[i] = datetime.fromtimestamp(values[i]).isoformat() + values[i] = datetime.fromtimestamp( + values[i], tz=pytz.UTC + ).isoformat() # for boolean values, keep mean and histogram values numeric (0 to 1 representation) if 
col in boolean_columns: for stat, val in stats.items(): diff --git a/mlrun/datastore/__init__.py b/mlrun/datastore/__init__.py index 9833fa1495aa..ac39cf1844a8 100644 --- a/mlrun/datastore/__init__.py +++ b/mlrun/datastore/__init__.py @@ -33,7 +33,12 @@ import mlrun.datastore.wasbfs -from ..platforms.iguazio import KafkaOutputStream, OutputStream, parse_path +from ..platforms.iguazio import ( + HTTPOutputStream, + KafkaOutputStream, + OutputStream, + parse_path, +) from ..utils import logger from .base import DataItem from .datastore import StoreManager, in_memory_store, uri_to_ipython @@ -69,7 +74,7 @@ def get_in_memory_items(): def get_stream_pusher(stream_path: str, **kwargs): - """get a stream pusher object from URL, currently only support v3io stream + """get a stream pusher object from URL. common kwargs:: @@ -87,6 +92,8 @@ def get_stream_pusher(stream_path: str, **kwargs): return KafkaOutputStream( topic, bootstrap_servers, kwargs.get("kafka_producer_options") ) + elif stream_path.startswith("http://") or stream_path.startswith("https://"): + return HTTPOutputStream(stream_path=stream_path) elif "://" not in stream_path: return OutputStream(stream_path, **kwargs) elif stream_path.startswith("v3io"): diff --git a/mlrun/datastore/base.py b/mlrun/datastore/base.py index c0f5f51186ed..77888a6dfa5f 100644 --- a/mlrun/datastore/base.py +++ b/mlrun/datastore/base.py @@ -11,21 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import sys import tempfile +import urllib.parse from base64 import b64encode from os import path, remove +from typing import Optional, Union import dask.dataframe as dd import fsspec import orjson import pandas as pd +import pyarrow +import pytz import requests import urllib3 import mlrun.errors from mlrun.errors import err_to_str -from mlrun.utils import is_ipython, logger +from mlrun.utils import StorePrefix, is_ipython, logger + +from .store_resources import is_store_uri, parse_store_uri +from .utils import filter_df_start_end_time, select_columns_from_df verify_ssl = False if not verify_ssl: @@ -63,6 +69,17 @@ def is_structured(self): def is_unstructured(self): return True + @staticmethod + def _sanitize_url(url): + """ + Extract only the schema, netloc, and path from an input URL if they exist, + excluding parameters, query, or fragments. + """ + parsed_url = urllib.parse.urlparse(url) + scheme = f"{parsed_url.scheme}:" if parsed_url.scheme else "" + netloc = f"//{parsed_url.netloc}" if parsed_url.netloc else "//" + return f"{scheme}{netloc}{parsed_url.path}" + @staticmethod def uri_to_kfp(endpoint, subpath): raise ValueError("data store doesnt support KFP URLs") @@ -71,7 +88,7 @@ def uri_to_kfp(endpoint, subpath): def uri_to_ipython(endpoint, subpath): return "" - def get_filesystem(self, silent=True): + def get_filesystem(self, silent=True) -> Optional[fsspec.AbstractFileSystem]: """return fsspec file system object, if supported""" return None @@ -135,6 +152,64 @@ def download(self, key, target_path): def upload(self, key, src_path): pass + @staticmethod + def _parquet_reader(df_module, url, file_system, time_column, start_time, end_time): + from storey.utils import find_filters, find_partitions + + def set_filters( + partitions_time_attributes, start_time_inner, end_time_inner, kwargs + ): + filters = [] + find_filters( + partitions_time_attributes, + start_time_inner, + end_time_inner, + filters, + time_column, + ) + kwargs["filters"] = filters + + def 
reader(*args, **kwargs): + if start_time or end_time: + if time_column is None: + raise mlrun.errors.MLRunInvalidArgumentError( + "When providing start_time or end_time, must provide time_column" + ) + + partitions_time_attributes = find_partitions(url, file_system) + set_filters( + partitions_time_attributes, + start_time, + end_time, + kwargs, + ) + try: + return df_module.read_parquet(*args, **kwargs) + except pyarrow.lib.ArrowInvalid as ex: + if not str(ex).startswith( + "Cannot compare timestamp with timezone to timestamp without timezone" + ): + raise ex + + if start_time.tzinfo: + start_time_inner = start_time.replace(tzinfo=None) + end_time_inner = end_time.replace(tzinfo=None) + else: + start_time_inner = start_time.replace(tzinfo=pytz.utc) + end_time_inner = end_time.replace(tzinfo=pytz.utc) + + set_filters( + partitions_time_attributes, + start_time_inner, + end_time_inner, + kwargs, + ) + return df_module.read_parquet(*args, **kwargs) + else: + return df_module.read_parquet(*args, **kwargs) + + return reader + def as_df( self, url, @@ -148,17 +223,29 @@ def as_df( **kwargs, ): df_module = df_module or pd - if url.endswith(".csv") or format == "csv": + file_url = self._sanitize_url(url) + is_csv, is_json, drop_time_column = False, False, False + file_system = self.get_filesystem() + if file_url.endswith(".csv") or format == "csv": + is_csv = True + drop_time_column = False if columns: + if ( + time_column + and (start_time or end_time) + and time_column not in columns + ): + columns.append(time_column) + drop_time_column = True kwargs["usecols"] = columns + reader = df_module.read_csv - filesystem = self.get_filesystem() - if filesystem: - if filesystem.isdir(url): + if file_system: + if file_system.isdir(file_url): def reader(*args, **kwargs): base_path = args[0] - file_entries = filesystem.listdir(base_path) + file_entries = file_system.listdir(base_path) filenames = [] for file_entry in file_entries: if ( @@ -176,51 +263,31 @@ def reader(*args, 
**kwargs): dfs.append(df_module.read_csv(*updated_args, **kwargs)) return pd.concat(dfs) - elif url.endswith(".parquet") or url.endswith(".pq") or format == "parquet": + elif ( + file_url.endswith(".parquet") + or file_url.endswith(".pq") + or format == "parquet" + ): if columns: kwargs["columns"] = columns - def reader(*args, **kwargs): - if start_time or end_time: - if sys.version_info < (3, 7): - raise ValueError( - f"feature not supported for python version {sys.version_info}" - ) - - if time_column is None: - raise mlrun.errors.MLRunInvalidArgumentError( - "When providing start_time or end_time, must provide time_column" - ) - - from storey.utils import find_filters, find_partitions - - filters = [] - partitions_time_attributes = find_partitions(url, file_system) - - find_filters( - partitions_time_attributes, - start_time, - end_time, - filters, - time_column, - ) - kwargs["filters"] = filters - - return df_module.read_parquet(*args, **kwargs) + reader = self._parquet_reader( + df_module, url, file_system, time_column, start_time, end_time + ) - elif url.endswith(".json") or format == "json": + elif file_url.endswith(".json") or format == "json": + is_json = True reader = df_module.read_json else: raise Exception(f"file type unhandled {url}") - file_system = self.get_filesystem() if file_system: - if self.supports_isdir() and file_system.isdir(url) or df_module == dd: + if self.supports_isdir() and file_system.isdir(file_url) or df_module == dd: storage_options = self.get_storage_options() if storage_options: kwargs["storage_options"] = storage_options - return reader(url, **kwargs) + df = reader(url, **kwargs) else: file = url @@ -230,12 +297,26 @@ def reader(*args, **kwargs): # support the storage_options parameter. 
file = file_system.open(url) - return reader(file, **kwargs) - - temp_file = tempfile.NamedTemporaryFile(delete=False) - self.download(self._join(subpath), temp_file.name) - df = reader(temp_file.name, **kwargs) - remove(temp_file.name) + df = reader(file, **kwargs) + else: + temp_file = tempfile.NamedTemporaryFile(delete=False) + self.download(self._join(subpath), temp_file.name) + df = reader(temp_file.name, **kwargs) + remove(temp_file.name) + + if is_json or is_csv: + # for parquet file the time filtering is executed in `reader` + df = filter_df_start_end_time( + df, + time_column=time_column, + start_time=start_time, + end_time=end_time, + ) + if drop_time_column: + df.drop(columns=[time_column], inplace=True) + if is_json: + # for csv and parquet files the columns select is executed in `reader`. + df = select_columns_from_df(df, columns=columns) return df def to_dict(self): @@ -383,7 +464,7 @@ def listdir(self): return self._store.listdir(self._path) def local(self): - """get the local path of the file, download to tmp first if its a remote object""" + """get the local path of the file, download to tmp first if it's a remote object""" if self.kind == "file": return self._path if self._local_path: @@ -397,27 +478,47 @@ def local(self): self.download(self._local_path) return self._local_path + def remove_local(self): + """remove the local file if it exists and was downloaded from a remote object""" + if self.kind == "file": + return + + if self._local_path: + remove(self._local_path) + self._local_path = "" + def as_df( self, columns=None, df_module=None, format="", + time_column=None, + start_time=None, + end_time=None, **kwargs, ): """return a dataframe object (generated from the dataitem). - :param columns: optional, list of columns to select - :param df_module: optional, py module used to create the DataFrame (e.g. pd, dd, cudf, ..) 
- :param format: file format, if not specified it will be deducted from the suffix + :param columns: optional, list of columns to select + :param df_module: optional, py module used to create the DataFrame (e.g. pd, dd, cudf, ..) + :param format: file format, if not specified it will be deducted from the suffix + :param start_time: filters out data before this time + :param end_time: filters out data after this time + :param time_column: Store timestamp_key will be used if None. + The results will be filtered by this column and start_time & end_time. """ - return self._store.as_df( + df = self._store.as_df( self._url, self._path, columns=columns, df_module=df_module, format=format, + time_column=time_column, + start_time=start_time, + end_time=end_time, **kwargs, ) + return df def show(self, format=None): """show the data object content in Jupyter @@ -451,6 +552,19 @@ def show(self, format=None): else: logger.error(f"unsupported show() format {suffix} for {self.url}") + def get_artifact_type(self) -> Union[str, None]: + """ + Check if the data item represents an Artifact (one of Artifact, DatasetArtifact and ModelArtifact). If it does + it return the store uri prefix (artifacts, datasets or models), otherwise None. + + :return: The store prefix of the artifact if it is an artifact data item and None if not. 
+ """ + if self.artifact_url and is_store_uri(url=self.artifact_url): + store_uri_prefix = parse_store_uri(self.artifact_url)[0] + if StorePrefix.is_artifact(prefix=store_uri_prefix): + return store_uri_prefix + return None + def __str__(self): return self.url @@ -514,7 +628,12 @@ def http_upload(url, file_path, headers=None, auth=None): class HttpStore(DataStore): def __init__(self, parent, schema, name, endpoint="", secrets: dict = None): super().__init__(parent, name, schema, endpoint, secrets) + self._https_auth_token = None + self._schema = schema self.auth = None + self._headers = {} + self._enrich_https_token() + self._validate_https_token() def get_filesystem(self, silent=True): """return fsspec file system object, if supported""" @@ -532,9 +651,22 @@ def put(self, key, data, append=False): raise ValueError("unimplemented") def get(self, key, size=None, offset=0): - data = http_get(self.url + self._join(key), None, self.auth) + data = http_get(self.url + self._join(key), self._headers, self.auth) if offset: data = data[offset:] if size: data = data[:size] return data + + def _enrich_https_token(self): + token = self._get_secret_or_env("HTTPS_AUTH_TOKEN") + if token: + self._https_auth_token = token + self._headers.setdefault("Authorization", f"token {token}") + + def _validate_https_token(self): + if self._https_auth_token and self._schema in ["http"]: + logger.warn( + f"A AUTH TOKEN should not be provided while using {self._schema} " + f"schema as it is not secure and is not recommended." 
+ ) diff --git a/mlrun/datastore/inmem.py b/mlrun/datastore/inmem.py index 3397079843b3..57d9c7be4b61 100644 --- a/mlrun/datastore/inmem.py +++ b/mlrun/datastore/inmem.py @@ -80,5 +80,8 @@ def as_df(self, url, subpath, columns=None, df_module=None, format="", **kwargs) reader = df_module.read_json else: raise mlrun.errors.MLRunInvalidArgumentError(f"file type unhandled {url}") + # InMemoryStore store do not filter on time + for field in ["time_column", "start_time", "end_time"]: + kwargs.pop(field, None) return reader(item, **kwargs) diff --git a/mlrun/datastore/sources.py b/mlrun/datastore/sources.py index be9bb0237df8..a9bd90e3c7be 100644 --- a/mlrun/datastore/sources.py +++ b/mlrun/datastore/sources.py @@ -32,7 +32,12 @@ from ..model import DataSource from ..platforms.iguazio import parse_path from ..utils import get_class -from .utils import store_path_to_spark +from .utils import ( + _generate_sql_query_with_time_filter, + filter_df_start_end_time, + select_columns_from_df, + store_path_to_spark, +) def get_source_from_dict(source): @@ -62,38 +67,59 @@ def _get_store(self): def to_step(self, key_field=None, time_field=None, context=None): import storey + if not self.support_storey: + raise mlrun.errors.MLRunRuntimeError( + f"{type(self).__name__} does not support storey engine" + ) + return storey.SyncEmitSource(context=context) def get_table_object(self): """get storey Table object""" return None - def to_dataframe(self): - return mlrun.store_manager.object(url=self.path).as_df() - - def filter_df_start_end_time(self, df, time_field): - # give priority to source time_field over the feature set's timestamp_key - if self.time_field: - time_field = self.time_field - - if self.start_time or self.end_time: - self.start_time = ( - datetime.min if self.start_time is None else self.start_time - ) - self.end_time = datetime.max if self.end_time is None else self.end_time - df = df.filter( - (df[time_field] > self.start_time) & (df[time_field] <= self.end_time) - ) - 
return df + def to_dataframe( + self, + columns=None, + df_module=None, + entities=None, + start_time=None, + end_time=None, + time_field=None, + ): + """return the source data as dataframe""" + return mlrun.store_manager.object(url=self.path).as_df( + columns=columns, + df_module=df_module, + start_time=start_time or self.start_time, + end_time=end_time or self.end_time, + time_column=time_field or self.time_field, + ) - def to_spark_df(self, session, named_view=False, time_field=None): + def to_spark_df(self, session, named_view=False, time_field=None, columns=None): if self.support_spark: df = session.read.load(**self.get_spark_options()) if named_view: df.createOrReplaceTempView(self.name) - return df + return self._filter_spark_df(df, time_field, columns) raise NotImplementedError() + def _filter_spark_df(self, df, time_field=None, columns=None): + if not (columns or time_field): + return df + + from pyspark.sql.functions import col + + if time_field: + if self.start_time: + df = df.filter(col(time_field) > self.start_time) + if self.end_time: + df = df.filter(col(time_field) <= self.end_time) + + if columns: + df = df.select([col(name) for name in columns]) + return df + def get_spark_options(self): # options used in spark.read.load(**options) raise NotImplementedError() @@ -166,7 +192,6 @@ def to_step(self, key_field=None, time_field=None, context=None): return storey.CSVSource( paths=self.path, - header=True, build_dict=True, key_field=self.key_field or key_field, storage_options=self._get_store().get_storage_options(), @@ -182,7 +207,7 @@ def get_spark_options(self): "inferSchema": "true", } - def to_spark_df(self, session, named_view=False, time_field=None): + def to_spark_df(self, session, named_view=False, time_field=None, columns=None): import pyspark.sql.functions as funcs df = session.read.load(**self.get_spark_options()) @@ -196,15 +221,28 @@ def to_spark_df(self, session, named_view=False, time_field=None): df = df.withColumn(col_name, 
funcs.col(col_name).cast("timestamp")) if named_view: df.createOrReplaceTempView(self.name) - return df + return self._filter_spark_df(df, time_field, columns) - def to_dataframe(self): - kwargs = self.attributes.get("reader_args", {}) - chunksize = self.attributes.get("chunksize") - if chunksize: - kwargs["chunksize"] = chunksize + def to_dataframe( + self, + columns=None, + df_module=None, + entities=None, + start_time=None, + end_time=None, + time_field=None, + ): + reader_args = self.attributes.get("reader_args", {}) return mlrun.store_manager.object(url=self.path).as_df( - parse_dates=self._parse_dates, **kwargs + columns=columns, + df_module=df_module, + format="csv", + start_time=start_time or self.start_time, + end_time=end_time or self.end_time, + time_column=time_field or self.time_field, + parse_dates=self._parse_dates, + chunksize=self.attributes.get("chunksize"), + **reader_args, ) def is_iterator(self): @@ -246,7 +284,6 @@ def __init__( start_time: Optional[Union[datetime, str]] = None, end_time: Optional[Union[datetime, str]] = None, ): - super().__init__( name, path, @@ -312,10 +349,24 @@ def get_spark_options(self): "format": "parquet", } - def to_dataframe(self): - kwargs = self.attributes.get("reader_args", {}) + def to_dataframe( + self, + columns=None, + df_module=None, + entities=None, + start_time=None, + end_time=None, + time_field=None, + ): + reader_args = self.attributes.get("reader_args", {}) return mlrun.store_manager.object(url=self.path).as_df( - format="parquet", **kwargs + columns=columns, + df_module=df_module, + start_time=start_time or self.start_time, + end_time=end_time or self.end_time, + time_column=time_field or self.time_field, + format="parquet", + **reader_args, ) @@ -323,8 +374,13 @@ class BigQuerySource(BaseSourceDriver): """ Reads Google BigQuery query results as input source for a flow. + For authentication, set the GCP_CREDENTIALS project secret to the credentials json string. 
+ example:: + # set the credentials + project.set_secrets({"GCP_CREDENTIALS": gcp_credentials_json}) + # use sql query query_string = "SELECT * FROM `the-psf.pypi.downloads20210328` LIMIT 5000" source = BigQuerySource("bq1", query=query_string, @@ -376,6 +432,15 @@ def __init__( raise mlrun.errors.MLRunInvalidArgumentError( "cannot specify both table and query args" ) + # Otherwise, the client library does not fully respect the limit + if ( + max_results_for_table + and chunksize + and max_results_for_table % chunksize != 0 + ): + raise mlrun.errors.MLRunInvalidArgumentError( + "max_results_for_table must be a multiple of chunksize" + ) attrs = { "query": query, "table": table, @@ -395,7 +460,6 @@ def __init__( start_time=start_time, end_time=end_time, ) - self._rows_iterator = None def _get_credentials_string(self): gcp_project = self.attributes.get("gcp_project", None) @@ -417,7 +481,15 @@ def _get_credentials(self): return credentials, gcp_project or gcp_cred_dict["project_id"] return None, gcp_project - def to_dataframe(self): + def to_dataframe( + self, + columns=None, + df_module=None, + entities=None, + start_time=None, + end_time=None, + time_field=None, + ): from google.cloud import bigquery from google.cloud.bigquery_storage_v1 import BigQueryReadClient @@ -438,39 +510,43 @@ def schema_to_dtypes(schema): if query: query_job = bqclient.query(query) - self._rows_iterator = query_job.result(page_size=chunksize) - dtypes = schema_to_dtypes(self._rows_iterator.schema) - if chunksize: - # passing bqstorage_client greatly improves performance - return self._rows_iterator.to_dataframe_iterable( - bqstorage_client=BigQueryReadClient(), dtypes=dtypes - ) - else: - return self._rows_iterator.to_dataframe(dtypes=dtypes) + rows_iterator = query_job.result(page_size=chunksize) elif table: table = self.attributes.get("table") max_results = self.attributes.get("max_results") - rows = bqclient.list_rows( + rows_iterator = bqclient.list_rows( table, page_size=chunksize, 
max_results=max_results ) - dtypes = schema_to_dtypes(rows.schema) - if chunksize: - # passing bqstorage_client greatly improves performance - return rows.to_dataframe_iterable( - bqstorage_client=BigQueryReadClient(), dtypes=dtypes - ) - else: - return rows.to_dataframe(dtypes=dtypes) else: raise mlrun.errors.MLRunInvalidArgumentError( "table or query args must be specified" ) + dtypes = schema_to_dtypes(rows_iterator.schema) + if chunksize: + # passing bqstorage_client greatly improves performance + df = rows_iterator.to_dataframe_iterable( + bqstorage_client=BigQueryReadClient(), dtypes=dtypes + ) + else: + df = rows_iterator.to_dataframe(dtypes=dtypes) + + # TODO : filter as part of the query + return select_columns_from_df( + filter_df_start_end_time( + df, + time_column=time_field or self.time_field, + start_time=start_time or self.start_time, + end_time=end_time or self.end_time, + ), + columns=columns, + ) + def is_iterator(self): return bool(self.attributes.get("chunksize")) - def to_spark_df(self, session, named_view=False, time_field=None): + def to_spark_df(self, session, named_view=False, time_field=None, columns=None): options = copy(self.attributes.get("spark_options", {})) credentials, gcp_project = self._get_credentials_string() if credentials: @@ -500,7 +576,7 @@ def to_spark_df(self, session, named_view=False, time_field=None): df = session.read.format("bigquery").load(**options) if named_view: df.createOrReplaceTempView(self.name) - return df + return self._filter_spark_df(df, time_field, columns) class SnowflakeSource(BaseSourceDriver): @@ -664,7 +740,7 @@ def to_step(self, key_field=None, time_field=None, context=None): context=self.context or context, ) - def to_dataframe(self): + def to_dataframe(self, **kwargs): return self._df def is_iterator(self): @@ -839,7 +915,15 @@ def __init__( attributes["sasl"] = sasl super().__init__(attributes=attributes, **kwargs) - def to_dataframe(self): + def to_dataframe( + self, + columns=None, + 
df_module=None, + entities=None, + start_time=None, + end_time=None, + time_field=None, + ): raise mlrun.MLRunInvalidArgumentError( "KafkaSource does not support batch processing" ) @@ -880,13 +964,14 @@ def __init__( table_name: str = None, spark_options: dict = None, time_fields: List[str] = None, + parse_dates: List[str] = None, ): """ Reads SqlDB as input source for a flow. example:: - db_path = "mysql+pymysql://:@:/" - source = SqlDBSource( - collection_name='source_name', db_path=self.db, key_field='key' + db_url = "mysql+pymysql://:@:/" + source = SQLSource( + table_name='source_name', db_url=db_url, key_field='key' ) :param name: source name :param chunksize: number of rows per chunk (default large single chunk) @@ -903,19 +988,32 @@ def __init__( from the current database :param spark_options: additional spark read options :param time_fields : all the field to be parsed as timestamp. + :param parse_dates : all the field to be parsed as timestamp. """ - + if time_fields: + warnings.warn( + "'time_fields' is deprecated, use 'parse_dates' instead. 
" + "This will be removed in 1.6.0", + # TODO: Remove this in 1.6.0 + FutureWarning, + ) + parse_dates = time_fields db_url = db_url or mlrun.mlconf.sql.url if db_url is None: raise mlrun.errors.MLRunInvalidArgumentError( "cannot specify without db_path arg or secret MLRUN_SQL__URL" ) + if time_field: + if parse_dates: + time_fields.append(time_field) + else: + parse_dates = [time_field] attrs = { "chunksize": chunksize, "spark_options": spark_options, "table_name": table_name, "db_path": db_url, - "time_fields": time_fields, + "parse_dates": parse_dates, } attrs = {key: value for key, value in attrs.items() if value is not None} super().__init__( @@ -928,22 +1026,40 @@ def __init__( end_time=end_time, ) - def to_dataframe(self): - import sqlalchemy as db + def to_dataframe( + self, + columns=None, + df_module=None, + entities=None, + start_time=None, + end_time=None, + time_field=None, + ): + import sqlalchemy as sqlalchemy - query = self.attributes.get("query", None) db_path = self.attributes.get("db_path") table_name = self.attributes.get("table_name") - if not query: - query = f"SELECT * FROM {table_name}" + parse_dates = self.attributes.get("parse_dates") + time_field = time_field or self.time_field + start_time = start_time or self.start_time + end_time = end_time or self.end_time if table_name and db_path: - engine = db.create_engine(db_path) + engine = sqlalchemy.create_engine(db_path) + query, parse_dates = _generate_sql_query_with_time_filter( + table_name=table_name, + engine=engine, + time_column=time_field, + parse_dates=parse_dates, + start_time=start_time, + end_time=end_time, + ) with engine.connect() as con: return pd.read_sql( query, con=con, chunksize=self.attributes.get("chunksize"), - parse_dates=self.attributes.get("time_fields"), + parse_dates=parse_dates, + columns=columns, ) else: raise mlrun.errors.MLRunInvalidArgumentError( diff --git a/mlrun/datastore/spark_udf.py b/mlrun/datastore/spark_udf.py new file mode 100644 index 
000000000000..f4c31f10d510 --- /dev/null +++ b/mlrun/datastore/spark_udf.py @@ -0,0 +1,44 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import hashlib + +from pyspark.sql.functions import udf +from pyspark.sql.types import StringType + + +def _hash_list(*list_to_hash): + list_to_hash = [str(element) for element in list_to_hash] + str_concatted = "".join(list_to_hash) + sha1 = hashlib.sha1() + sha1.update(str_concatted.encode("utf8")) + return sha1.hexdigest() + + +def _redis_stringify_key(*args): + if len(args) == 1: + key_list = args[0] + else: + key_list = list(args) + suffix = "}:static" + if isinstance(key_list, list): + if len(key_list) >= 3: + return str(key_list[0]) + "." + _hash_list(*key_list[1:]) + suffix + if len(key_list) == 2: + return str(key_list[0]) + "." 
+ str(key_list[1]) + suffix + return str(key_list[0]) + suffix + return str(key_list) + suffix + + +hash_and_concat_v3io_udf = udf(_hash_list, StringType()) +hash_and_concat_redis_udf = udf(_redis_stringify_key, StringType()) diff --git a/mlrun/datastore/store_resources.py b/mlrun/datastore/store_resources.py index d85aae13ae8b..bee8811db97a 100644 --- a/mlrun/datastore/store_resources.py +++ b/mlrun/datastore/store_resources.py @@ -81,7 +81,7 @@ def get_table(self, uri): endpoint, uri = parse_path(uri) self._tabels[uri] = Table( uri, - V3ioDriver(webapi=endpoint), + V3ioDriver(webapi=endpoint or mlrun.mlconf.v3io_api), flush_interval_secs=mlrun.mlconf.feature_store.flush_interval, ) return self._tabels[uri] @@ -101,8 +101,8 @@ def get_table(self, uri): if is_store_uri(uri): resource = get_store_resource(uri) if resource.kind in [ - mlrun.api.schemas.ObjectKind.feature_set.value, - mlrun.api.schemas.ObjectKind.feature_vector.value, + mlrun.common.schemas.ObjectKind.feature_set.value, + mlrun.common.schemas.ObjectKind.feature_vector.value, ]: target = get_online_target(resource) if not target: diff --git a/mlrun/datastore/targets.py b/mlrun/datastore/targets.py index 3624fe91e022..a930a18e9df2 100644 --- a/mlrun/datastore/targets.py +++ b/mlrun/datastore/targets.py @@ -15,7 +15,9 @@ import datetime import os import random +import sys import time +import warnings from collections import Counter from copy import copy from typing import Any, Dict, List, Optional, Union @@ -34,7 +36,13 @@ from .. 
import errors from ..data_types import ValueType from ..platforms.iguazio import parse_path, split_path -from .utils import parse_kafka_url, store_path_to_spark +from .utils import ( + _generate_sql_query_with_time_filter, + filter_df_start_end_time, + parse_kafka_url, + select_columns_from_df, + store_path_to_spark, +) class TargetTypes: @@ -525,8 +533,8 @@ def write_dataframe( ("minute", "%M"), ]: partition_cols.append(unit) - target_df[unit] = getattr( - pd.DatetimeIndex(target_df[timestamp_key]), unit + target_df[unit] = pd.DatetimeIndex(target_df[timestamp_key]).format( + date_format=fmt ) if unit == time_partitioning_granularity: break @@ -986,6 +994,9 @@ def as_df( df_module=df_module, entities=entities, format="csv", + start_time=start_time, + end_time=end_time, + time_column=time_column, **kwargs, ) if entities: @@ -1050,24 +1061,11 @@ def add_writer_step( **self.attributes, ) + def prepare_spark_df(self, df, key_columns): + raise NotImplementedError() + def get_spark_options(self, key_column=None, timestamp_key=None, overwrite=True): - spark_options = { - "path": store_path_to_spark(self.get_target_path()), - "format": "io.iguaz.v3io.spark.sql.kv", - } - if isinstance(key_column, list) and len(key_column) >= 1: - if len(key_column) > 2: - raise mlrun.errors.MLRunInvalidArgumentError( - f"Spark supports maximun of 2 keys and {key_column} are provided" - ) - spark_options["key"] = key_column[0] - if len(key_column) > 1: - spark_options["sorting-key"] = key_column[1] - else: - spark_options["key"] = key_column - if not overwrite: - spark_options["columnUpdate"] = True - return spark_options + raise NotImplementedError() def get_dask_options(self): return {"format": "csv"} @@ -1075,15 +1073,6 @@ def get_dask_options(self): def as_df(self, columns=None, df_module=None, **kwargs): raise NotImplementedError() - def prepare_spark_df(self, df, key_columns): - import pyspark.sql.functions as funcs - - for col_name, col_type in df.dtypes: - if 
col_type.startswith("decimal("): - # V3IO does not support this level of precision - df = df.withColumn(col_name, funcs.col(col_name).cast("double")) - return df - def write_dataframe( self, df, key_column=None, timestamp_key=None, chunk_id=0, **kwargs ): @@ -1123,10 +1112,52 @@ def get_table_object(self): endpoint, uri = parse_path(self.get_target_path()) return Table( uri, - V3ioDriver(webapi=endpoint), + V3ioDriver(webapi=endpoint or mlrun.mlconf.v3io_api), flush_interval_secs=mlrun.mlconf.feature_store.flush_interval, ) + def get_spark_options(self, key_column=None, timestamp_key=None, overwrite=True): + spark_options = { + "path": store_path_to_spark(self.get_target_path()), + "format": "io.iguaz.v3io.spark.sql.kv", + } + if isinstance(key_column, list) and len(key_column) >= 1: + spark_options["key"] = key_column[0] + if len(key_column) > 2: + spark_options["sorting-key"] = "_spark_object_name" + if len(key_column) == 2: + spark_options["sorting-key"] = key_column[1] + else: + spark_options["key"] = key_column + if not overwrite: + spark_options["columnUpdate"] = True + return spark_options + + def prepare_spark_df(self, df, key_columns): + from pyspark.sql.functions import col + + spark_udf_directory = os.path.dirname(os.path.abspath(__file__)) + sys.path.append(spark_udf_directory) + try: + import spark_udf + + df.rdd.context.addFile(spark_udf.__file__) + + for col_name, col_type in df.dtypes: + if col_type.startswith("decimal("): + # V3IO does not support this level of precision + df = df.withColumn(col_name, col(col_name).cast("double")) + if len(key_columns) > 2: + return df.withColumn( + "_spark_object_name", + spark_udf.hash_and_concat_v3io_udf( + *[col(c) for c in key_columns[1:]] + ), + ) + finally: + sys.path.remove(spark_udf_directory) + return df + class RedisNoSqlTarget(NoSqlBaseTarget): kind = TargetTypes.redisnosql @@ -1186,11 +1217,23 @@ def get_target_path_with_credentials(self): return endpoint def prepare_spark_df(self, df, key_columns): - 
from pyspark.sql.functions import udf - from pyspark.sql.types import StringType + from pyspark.sql.functions import col + + spark_udf_directory = os.path.dirname(os.path.abspath(__file__)) + sys.path.append(spark_udf_directory) + try: + import spark_udf + + df.rdd.context.addFile(spark_udf.__file__) - udf1 = udf(lambda x: str(x) + "}:static", StringType()) - return df.withColumn("_spark_object_name", udf1(key_columns[0])) + df = df.withColumn( + "_spark_object_name", + spark_udf.hash_and_concat_redis_udf(*[col(c) for c in key_columns]), + ) + finally: + sys.path.remove(spark_udf_directory) + + return df class StreamTarget(BaseStoreTarget): @@ -1224,7 +1267,7 @@ def add_writer_step( graph_shape="cylinder", class_name="storey.StreamTarget", columns=column_list, - storage=V3ioDriver(webapi=endpoint), + storage=V3ioDriver(webapi=endpoint or mlrun.mlconf.v3io_api), stream_path=uri, **self.attributes, ) @@ -1441,7 +1484,15 @@ def as_df( time_column=None, **kwargs, ): - return self._df + return select_columns_from_df( + filter_df_start_end_time( + self._df, + time_column=time_column, + start_time=start_time, + end_time=end_time, + ), + columns, + ) class SQLTarget(BaseStoreTarget): @@ -1472,14 +1523,15 @@ def __init__( # create_according_to_data: bool = False, time_fields: List[str] = None, varchar_len: int = 50, + parse_dates: List[str] = None, ): """ Write to SqlDB as output target for a flow. example:: - db_path = "sqlite:///stockmarket.db" + db_url = "sqlite:///stockmarket.db" schema = {'time': datetime.datetime, 'ticker': str, 'bid': float, 'ask': float, 'ind': int} - target = SqlDBTarget(table_name=f'{name}-tatget', db_path=db_path, create_table=True, + target = SqlDBTarget(table_name=f'{name}-target', db_url=db_url, create_table=True, schema=schema, primary_key_column=key) :param name: :param path: @@ -1509,8 +1561,17 @@ def __init__( :param create_according_to_data: (not valid) :param time_fields : all the field to be parsed as timestamp. 
:param varchar_len : the defalut len of the all the varchar column (using if needed to create the table). + :param parse_dates : all the field to be parsed as timestamp. """ create_according_to_data = False # TODO: open for user + if time_fields: + warnings.warn( + "'time_fields' is deprecated, use 'parse_dates' instead. " + "This will be removed in 1.6.0", + # TODO: Remove this in 1.6.0 + FutureWarning, + ) + parse_dates = time_fields db_url = db_url or mlrun.mlconf.sql.url if db_url is None or table_name is None: attr = {} @@ -1523,7 +1584,7 @@ def __init__( "db_path": db_url, "create_according_to_data": create_according_to_data, "if_exists": if_exists, - "time_fields": time_fields, + "parse_dates": parse_dates, "varchar_len": varchar_len, } path = ( @@ -1610,16 +1671,24 @@ def as_df( ): db_path, table_name, _, _, _, _ = self._parse_url() engine = sqlalchemy.create_engine(db_path) + parse_dates: Optional[List[str]] = self.attributes.get("parse_dates") with engine.connect() as conn: + query, parse_dates = _generate_sql_query_with_time_filter( + table_name=table_name, + engine=engine, + time_column=time_column, + parse_dates=parse_dates, + start_time=start_time, + end_time=end_time, + ) df = pd.read_sql( - f"SELECT * FROM {self.attributes.get('table_name')}", + query, con=conn, - parse_dates=self.attributes.get("time_fields"), + parse_dates=parse_dates, + columns=columns, ) if self._primary_key_column: df.set_index(self._primary_key_column, inplace=True) - if columns: - df = df[columns] return df def write_dataframe( @@ -1730,12 +1799,12 @@ def _get_target_path(driver, resource, run_id_mode=False): if not suffix: if ( kind == ParquetTarget.kind - and resource.kind == mlrun.api.schemas.ObjectKind.feature_vector + and resource.kind == mlrun.common.schemas.ObjectKind.feature_vector ): suffix = ".parquet" kind_prefix = ( "sets" - if resource.kind == mlrun.api.schemas.ObjectKind.feature_set + if resource.kind == mlrun.common.schemas.ObjectKind.feature_set else "vectors" 
) name = resource.metadata.name diff --git a/mlrun/datastore/utils.py b/mlrun/datastore/utils.py index 9fd7be42fc0e..83444b838d30 100644 --- a/mlrun/datastore/utils.py +++ b/mlrun/datastore/utils.py @@ -12,7 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from urllib.parse import urlparse +import tarfile +import tempfile +import typing +from urllib.parse import parse_qs, urlparse + +import pandas as pd +import sqlalchemy + +import mlrun.datastore def store_path_to_spark(path): @@ -36,11 +44,125 @@ def store_path_to_spark(path): return path -def parse_kafka_url(url, bootstrap_servers=None): +def parse_kafka_url( + url: str, bootstrap_servers: typing.List = None +) -> typing.Tuple[str, typing.List]: + """Generating Kafka topic and adjusting a list of bootstrap servers. + + :param url: URL path to parse using urllib.parse.urlparse. + :param bootstrap_servers: List of bootstrap servers for the kafka brokers. + + :return: A tuple of: + [0] = Kafka topic value + [1] = List of bootstrap servers + """ bootstrap_servers = bootstrap_servers or [] + + # Parse the provided URL into six components according to the general structure of a URL url = urlparse(url) + + # Add the network location to the bootstrap servers list if url.netloc: bootstrap_servers = [url.netloc] + bootstrap_servers - topic = url.path - topic = topic.lstrip("/") + + # Get the topic value from the parsed url + query_dict = parse_qs(url.query) + if "topic" in query_dict: + topic = query_dict["topic"][0] + else: + topic = url.path + topic = topic.lstrip("/") return topic, bootstrap_servers + + +def upload_tarball(source_dir, target, secrets=None): + # will delete the temp file + with tempfile.NamedTemporaryFile(suffix=".tar.gz") as temp_fh: + with tarfile.open(mode="w:gz", fileobj=temp_fh) as tar: + tar.add(source_dir, arcname="") + stores = mlrun.datastore.store_manager.set(secrets) + datastore, subpath = stores.get_or_create_store(target) + 
datastore.upload(subpath, temp_fh.name) + + +def filter_df_start_end_time( + df: typing.Union[pd.DataFrame, typing.Iterator[pd.DataFrame]], + time_column: str = None, + start_time: pd.Timestamp = None, + end_time: pd.Timestamp = None, +) -> typing.Union[pd.DataFrame, typing.Iterator[pd.DataFrame]]: + if not time_column or (not start_time and not end_time): + return df + if isinstance(df, pd.DataFrame): + return _execute_time_filter(df, time_column, start_time, end_time) + else: + filter_df_generator(df, time_column, start_time, end_time) + + +def filter_df_generator( + dfs: typing.Iterator[pd.DataFrame], + time_field: str, + start_time: pd.Timestamp, + end_time: pd.Timestamp, +) -> typing.Iterator[pd.DataFrame]: + for df in dfs: + yield _execute_time_filter(df, time_field, start_time, end_time) + + +def _execute_time_filter( + df: pd.DataFrame, time_column: str, start_time: pd.Timestamp, end_time: pd.Timestamp +): + df[time_column] = pd.to_datetime(df[time_column]) + if start_time: + df = df[df[time_column] > start_time] + if end_time: + df = df[df[time_column] <= end_time] + return df + + +def select_columns_from_df( + df: typing.Union[pd.DataFrame, typing.Iterator[pd.DataFrame]], + columns: typing.List[str], +) -> typing.Union[pd.DataFrame, typing.Iterator[pd.DataFrame]]: + if not columns: + return df + if isinstance(df, pd.DataFrame): + return df[columns] + else: + return select_columns_generator(df, columns) + + +def select_columns_generator( + dfs: typing.Union[pd.DataFrame, typing.Iterator[pd.DataFrame]], + columns: typing.List[str], +) -> typing.Iterator[pd.DataFrame]: + for df in dfs: + yield df[columns] + + +def _generate_sql_query_with_time_filter( + table_name: str, + engine: sqlalchemy.engine.Engine, + time_column: str, + parse_dates: typing.List[str], + start_time: pd.Timestamp, + end_time: pd.Timestamp, +): + table = sqlalchemy.Table( + table_name, + sqlalchemy.MetaData(), + autoload=True, + autoload_with=engine, + ) + query = sqlalchemy.select(table) 
+ if time_column: + if parse_dates and time_column not in parse_dates: + parse_dates.append(time_column) + else: + parse_dates = [time_column] + if start_time: + query = query.filter(getattr(table.c, time_column) > start_time) + if end_time: + query = query.filter(getattr(table.c, time_column) <= end_time) + + return query, parse_dates diff --git a/mlrun/db/__init__.py b/mlrun/db/__init__.py index 3692ac49cdf9..63cadc04c5c5 100644 --- a/mlrun/db/__init__.py +++ b/mlrun/db/__init__.py @@ -18,7 +18,7 @@ from ..platforms import add_or_refresh_credentials from ..utils import logger from .base import RunDBError, RunDBInterface # noqa -from .filedb import FileRunDB +from .nopdb import NopDB from .sqldb import SQLDB @@ -69,12 +69,14 @@ def get_run_db(url="", secrets=None, force_reconnect=False): kwargs = {} if "://" not in str(url) or scheme in ["file", "s3", "v3io", "v3ios"]: logger.warning( - "Could not detect path to API server, Using Deprecated client interface" + "Could not detect path to API server, not connected to API server!" ) logger.warning( - "Please make sure your env variable MLRUN_DBPATH is configured correctly to point to the API server!" + "MLRUN_DBPATH is not set. Set this environment variable to the URL of the API server" + " in order to connect" ) - cls = FileRunDB + cls = NopDB + elif scheme in ("http", "https"): # import here to avoid circular imports from .httpdb import HTTPRunDB diff --git a/mlrun/db/base.py b/mlrun/db/base.py index 28c6c4ed4596..eb47602e45d6 100644 --- a/mlrun/db/base.py +++ b/mlrun/db/base.py @@ -13,12 +13,13 @@ # limitations under the License. 
import datetime +import typing import warnings from abc import ABC, abstractmethod from typing import List, Optional, Union -from mlrun.api import schemas -from mlrun.api.schemas import ModelEndpoint +import mlrun.common.schemas +import mlrun.model_monitoring.model_endpoint class RunDBError(Exception): @@ -49,7 +50,7 @@ def update_run(self, updates: dict, uid, project="", iter=0): pass @abstractmethod - def abort_run(self, uid, project="", iter=0): + def abort_run(self, uid, project="", iter=0, timeout=45): pass @abstractmethod @@ -71,10 +72,12 @@ def list_runs( start_time_to: datetime.datetime = None, last_update_time_from: datetime.datetime = None, last_update_time_to: datetime.datetime = None, - partition_by: Union[schemas.RunPartitionByField, str] = None, + partition_by: Union[mlrun.common.schemas.RunPartitionByField, str] = None, rows_per_partition: int = 1, - partition_sort_by: Union[schemas.SortField, str] = None, - partition_order: Union[schemas.OrderType, str] = schemas.OrderType.desc, + partition_sort_by: Union[mlrun.common.schemas.SortField, str] = None, + partition_order: Union[ + mlrun.common.schemas.OrderType, str + ] = mlrun.common.schemas.OrderType.desc, max_partitions: int = 0, with_notifications: bool = False, ): @@ -108,7 +111,7 @@ def list_artifacts( iter: int = None, best_iteration: bool = False, kind: str = None, - category: Union[str, schemas.ArtifactCategories] = None, + category: Union[str, mlrun.common.schemas.ArtifactCategories] = None, ): pass @@ -148,7 +151,7 @@ def tag_objects( self, project: str, tag_name: str, - tag_objects: schemas.TagObjects, + tag_objects: mlrun.common.schemas.TagObjects, replace: bool = False, ): pass @@ -158,7 +161,7 @@ def delete_objects_tag( self, project: str, tag_name: str, - tag_objects: schemas.TagObjects, + tag_objects: mlrun.common.schemas.TagObjects, ): pass @@ -184,11 +187,11 @@ def delete_artifacts_tags( @staticmethod def _resolve_artifacts_to_tag_objects( artifacts, - ) -> schemas.TagObjects: + ) -> 
mlrun.common.schemas.TagObjects: """ :param artifacts: Can be a list of :py:class:`~mlrun.artifacts.Artifact` objects or dictionaries, or a single object. - :return: :py:class:`~mlrun.api.schemas.TagObjects` + :return: :py:class:`~mlrun.common.schemas.TagObjects` """ # to avoid circular imports we import here import mlrun.artifacts.base @@ -204,7 +207,7 @@ def _resolve_artifacts_to_tag_objects( else artifact ) artifact_identifiers.append( - schemas.ArtifactIdentifier( + mlrun.common.schemas.ArtifactIdentifier( key=mlrun.utils.get_in_artifact(artifact_obj, "key"), # we are passing tree as uid when storing an artifact, so if uid is not defined, # pass the tree as uid @@ -214,13 +217,15 @@ def _resolve_artifacts_to_tag_objects( iter=mlrun.utils.get_in_artifact(artifact_obj, "iter"), ) ) - return schemas.TagObjects(kind="artifact", identifiers=artifact_identifiers) + return mlrun.common.schemas.TagObjects( + kind="artifact", identifiers=artifact_identifiers + ) @abstractmethod def delete_project( self, name: str, - deletion_strategy: schemas.DeletionStrategy = schemas.DeletionStrategy.default(), + deletion_strategy: mlrun.common.schemas.DeletionStrategy = mlrun.common.schemas.DeletionStrategy.default(), ): pass @@ -228,8 +233,8 @@ def delete_project( def store_project( self, name: str, - project: schemas.Project, - ) -> schemas.Project: + project: mlrun.common.schemas.Project, + ) -> mlrun.common.schemas.Project: pass @abstractmethod @@ -237,40 +242,45 @@ def patch_project( self, name: str, project: dict, - patch_mode: schemas.PatchMode = schemas.PatchMode.replace, - ) -> schemas.Project: + patch_mode: mlrun.common.schemas.PatchMode = mlrun.common.schemas.PatchMode.replace, + ) -> mlrun.common.schemas.Project: pass @abstractmethod def create_project( self, - project: schemas.Project, - ) -> schemas.Project: + project: mlrun.common.schemas.Project, + ) -> mlrun.common.schemas.Project: pass @abstractmethod def list_projects( self, owner: str = None, - format_: 
schemas.ProjectsFormat = schemas.ProjectsFormat.full, + format_: mlrun.common.schemas.ProjectsFormat = mlrun.common.schemas.ProjectsFormat.full, labels: List[str] = None, - state: schemas.ProjectState = None, - ) -> schemas.ProjectsOutput: + state: mlrun.common.schemas.ProjectState = None, + ) -> mlrun.common.schemas.ProjectsOutput: pass @abstractmethod - def get_project(self, name: str) -> schemas.Project: + def get_project(self, name: str) -> mlrun.common.schemas.Project: pass @abstractmethod def list_artifact_tags( - self, project=None, category: Union[str, schemas.ArtifactCategories] = None + self, + project=None, + category: Union[str, mlrun.common.schemas.ArtifactCategories] = None, ): pass @abstractmethod def create_feature_set( - self, feature_set: Union[dict, schemas.FeatureSet], project="", versioned=True + self, + feature_set: Union[dict, mlrun.common.schemas.FeatureSet], + project="", + versioned=True, ) -> dict: pass @@ -288,7 +298,7 @@ def list_features( tag: str = None, entities: List[str] = None, labels: List[str] = None, - ) -> schemas.FeaturesOutput: + ) -> mlrun.common.schemas.FeaturesOutput: pass @abstractmethod @@ -298,7 +308,7 @@ def list_entities( name: str = None, tag: str = None, labels: List[str] = None, - ) -> schemas.EntitiesOutput: + ) -> mlrun.common.schemas.EntitiesOutput: pass @abstractmethod @@ -311,17 +321,21 @@ def list_feature_sets( entities: List[str] = None, features: List[str] = None, labels: List[str] = None, - partition_by: Union[schemas.FeatureStorePartitionByField, str] = None, + partition_by: Union[ + mlrun.common.schemas.FeatureStorePartitionByField, str + ] = None, rows_per_partition: int = 1, - partition_sort_by: Union[schemas.SortField, str] = None, - partition_order: Union[schemas.OrderType, str] = schemas.OrderType.desc, + partition_sort_by: Union[mlrun.common.schemas.SortField, str] = None, + partition_order: Union[ + mlrun.common.schemas.OrderType, str + ] = mlrun.common.schemas.OrderType.desc, ) -> List[dict]: 
pass @abstractmethod def store_feature_set( self, - feature_set: Union[dict, schemas.FeatureSet], + feature_set: Union[dict, mlrun.common.schemas.FeatureSet], name=None, project="", tag=None, @@ -338,7 +352,9 @@ def patch_feature_set( project="", tag=None, uid=None, - patch_mode: Union[str, schemas.PatchMode] = schemas.PatchMode.replace, + patch_mode: Union[ + str, mlrun.common.schemas.PatchMode + ] = mlrun.common.schemas.PatchMode.replace, ): pass @@ -349,7 +365,7 @@ def delete_feature_set(self, name, project="", tag=None, uid=None): @abstractmethod def create_feature_vector( self, - feature_vector: Union[dict, schemas.FeatureVector], + feature_vector: Union[dict, mlrun.common.schemas.FeatureVector], project="", versioned=True, ) -> dict: @@ -369,17 +385,21 @@ def list_feature_vectors( tag: str = None, state: str = None, labels: List[str] = None, - partition_by: Union[schemas.FeatureStorePartitionByField, str] = None, + partition_by: Union[ + mlrun.common.schemas.FeatureStorePartitionByField, str + ] = None, rows_per_partition: int = 1, - partition_sort_by: Union[schemas.SortField, str] = None, - partition_order: Union[schemas.OrderType, str] = schemas.OrderType.desc, + partition_sort_by: Union[mlrun.common.schemas.SortField, str] = None, + partition_order: Union[ + mlrun.common.schemas.OrderType, str + ] = mlrun.common.schemas.OrderType.desc, ) -> List[dict]: pass @abstractmethod def store_feature_vector( self, - feature_vector: Union[dict, schemas.FeatureVector], + feature_vector: Union[dict, mlrun.common.schemas.FeatureVector], name=None, project="", tag=None, @@ -396,7 +416,9 @@ def patch_feature_vector( project="", tag=None, uid=None, - patch_mode: Union[str, schemas.PatchMode] = schemas.PatchMode.replace, + patch_mode: Union[ + str, mlrun.common.schemas.PatchMode + ] = mlrun.common.schemas.PatchMode.replace, ): pass @@ -413,10 +435,10 @@ def list_pipelines( page_token: str = "", filter_: str = "", format_: Union[ - str, schemas.PipelinesFormat - ] = 
schemas.PipelinesFormat.metadata_only, + str, mlrun.common.schemas.PipelinesFormat + ] = mlrun.common.schemas.PipelinesFormat.metadata_only, page_size: int = None, - ) -> schemas.PipelinesOutput: + ) -> mlrun.common.schemas.PipelinesOutput: pass @abstractmethod @@ -424,8 +446,8 @@ def create_project_secrets( self, project: str, provider: Union[ - str, schemas.SecretProviderName - ] = schemas.SecretProviderName.kubernetes, + str, mlrun.common.schemas.SecretProviderName + ] = mlrun.common.schemas.SecretProviderName.kubernetes, secrets: dict = None, ): pass @@ -436,10 +458,10 @@ def list_project_secrets( project: str, token: str, provider: Union[ - str, schemas.SecretProviderName - ] = schemas.SecretProviderName.kubernetes, + str, mlrun.common.schemas.SecretProviderName + ] = mlrun.common.schemas.SecretProviderName.kubernetes, secrets: List[str] = None, - ) -> schemas.SecretsData: + ) -> mlrun.common.schemas.SecretsData: pass @abstractmethod @@ -447,10 +469,10 @@ def list_project_secret_keys( self, project: str, provider: Union[ - str, schemas.SecretProviderName - ] = schemas.SecretProviderName.kubernetes, + str, mlrun.common.schemas.SecretProviderName + ] = mlrun.common.schemas.SecretProviderName.kubernetes, token: str = None, - ) -> schemas.SecretKeysData: + ) -> mlrun.common.schemas.SecretKeysData: pass @abstractmethod @@ -458,8 +480,8 @@ def delete_project_secrets( self, project: str, provider: Union[ - str, schemas.SecretProviderName - ] = schemas.SecretProviderName.kubernetes, + str, mlrun.common.schemas.SecretProviderName + ] = mlrun.common.schemas.SecretProviderName.kubernetes, secrets: List[str] = None, ): pass @@ -469,8 +491,8 @@ def create_user_secrets( self, user: str, provider: Union[ - str, schemas.SecretProviderName - ] = schemas.SecretProviderName.vault, + str, mlrun.common.schemas.SecretProviderName + ] = mlrun.common.schemas.SecretProviderName.vault, secrets: dict = None, ): pass @@ -480,7 +502,9 @@ def create_model_endpoint( self, project: str, 
endpoint_id: str, - model_endpoint: ModelEndpoint, + model_endpoint: Union[ + mlrun.model_monitoring.model_endpoint.ModelEndpoint, dict + ], ): pass @@ -527,31 +551,33 @@ def patch_model_endpoint( pass @abstractmethod - def create_marketplace_source( - self, source: Union[dict, schemas.IndexedMarketplaceSource] + def create_hub_source( + self, source: Union[dict, mlrun.common.schemas.IndexedHubSource] ): pass @abstractmethod - def store_marketplace_source( - self, source_name: str, source: Union[dict, schemas.IndexedMarketplaceSource] + def store_hub_source( + self, + source_name: str, + source: Union[dict, mlrun.common.schemas.IndexedHubSource], ): pass @abstractmethod - def list_marketplace_sources(self): + def list_hub_sources(self): pass @abstractmethod - def get_marketplace_source(self, source_name: str): + def get_hub_source(self, source_name: str): pass @abstractmethod - def delete_marketplace_source(self, source_name: str): + def delete_hub_source(self, source_name: str): pass @abstractmethod - def get_marketplace_catalog( + def get_hub_catalog( self, source_name: str, version: str = None, @@ -561,7 +587,7 @@ def get_marketplace_catalog( pass @abstractmethod - def get_marketplace_item( + def get_hub_item( self, source_name: str, item_name: str, @@ -573,6 +599,25 @@ def get_marketplace_item( @abstractmethod def verify_authorization( - self, authorization_verification_input: schemas.AuthorizationVerificationInput + self, + authorization_verification_input: mlrun.common.schemas.AuthorizationVerificationInput, + ): + pass + + def get_builder_status( + self, + func: "mlrun.runtimes.BaseRuntime", + offset: int = 0, + logs: bool = True, + last_log_timestamp: float = 0.0, + verbose: bool = False, + ): + pass + + def set_run_notifications( + self, + project: str, + runs: typing.List[mlrun.model.RunObject], + notifications: typing.List[mlrun.model.Notification], ): pass diff --git a/mlrun/db/filedb.py b/mlrun/db/filedb.py deleted file mode 100644 index 
d96a1cdf26be..000000000000 --- a/mlrun/db/filedb.py +++ /dev/null @@ -1,890 +0,0 @@ -# Copyright 2018 Iguazio -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import pathlib -from datetime import datetime, timedelta, timezone -from os import listdir, makedirs, path, remove, scandir -from typing import List, Optional, Union - -import yaml -from dateutil.parser import parse as parse_time - -import mlrun.api.schemas -import mlrun.errors - -from ..api import schemas -from ..api.schemas import ModelEndpoint -from ..config import config -from ..datastore import store_manager -from ..lists import ArtifactList, RunList -from ..utils import ( - dict_to_json, - dict_to_yaml, - fill_function_hash, - generate_object_uri, - get_in, - logger, - match_labels, - match_times, - match_value, - match_value_options, - update_in, -) -from .base import RunDBError, RunDBInterface - -run_logs = "runs" -artifacts_dir = "artifacts" -functions_dir = "functions" -schedules_dir = "schedules" - - -# TODO: remove fileDB, doesn't needs to be used anymore -class FileRunDB(RunDBInterface): - kind = "file" - - def __init__(self, dirpath="", format=".yaml"): - self.format = format - self.dirpath = dirpath - self._datastore = None - self._subpath = None - self._secrets = None - makedirs(self.schedules_dir, exist_ok=True) - - def connect(self, secrets=None): - self._secrets = secrets - return self - - def _connect(self, secrets=None): - sm = store_manager.set(secrets or self._secrets) - 
self._datastore, self._subpath = sm.get_or_create_store(self.dirpath) - return self - - @property - def datastore(self): - if not self._datastore: - self._connect() - return self._datastore - - def store_log(self, uid, project="", body=None, append=False): - filepath = self._filepath(run_logs, project, uid, "") + ".log" - makedirs(path.dirname(filepath), exist_ok=True) - mode = "ab" if append else "wb" - with open(filepath, mode) as fp: - fp.write(body) - fp.close() - - def get_log(self, uid, project="", offset=0, size=0): - filepath = self._filepath(run_logs, project, uid, "") + ".log" - if pathlib.Path(filepath).is_file(): - with open(filepath, "rb") as fp: - if offset: - fp.seek(offset) - if not size: - size = 2**18 - return "", fp.read(size) - return "", None - - def _run_path(self, uid, iter): - if iter: - return f"{uid}-{iter}" - return uid - - def store_run(self, struct, uid, project="", iter=0): - data = self._dumps(struct) - filepath = ( - self._filepath(run_logs, project, self._run_path(uid, iter), "") - + self.format - ) - self.datastore.put(filepath, data) - - def update_run(self, updates: dict, uid, project="", iter=0): - run = self.read_run(uid, project, iter=iter) - if run and updates: - for key, val in updates.items(): - update_in(run, key, val) - self.store_run(run, uid, project, iter=iter) - - def abort_run(self, uid, project="", iter=0): - raise NotImplementedError() - - def read_run(self, uid, project="", iter=0): - filepath = ( - self._filepath(run_logs, project, self._run_path(uid, iter), "") - + self.format - ) - if not pathlib.Path(filepath).is_file(): - raise mlrun.errors.MLRunNotFoundError(uid) - data = self.datastore.get(filepath) - return self._loads(data) - - def list_runs( - self, - name="", - uid: Optional[Union[str, List[str]]] = None, - project="", - labels=None, - state="", - sort=True, - last=1000, - iter=False, - start_time_from: datetime = None, - start_time_to: datetime = None, - last_update_time_from: datetime = None, - 
last_update_time_to: datetime = None, - partition_by: Union[schemas.RunPartitionByField, str] = None, - rows_per_partition: int = 1, - partition_sort_by: Union[schemas.SortField, str] = None, - partition_order: Union[schemas.OrderType, str] = schemas.OrderType.desc, - max_partitions: int = 0, - with_notifications: bool = False, - ): - if partition_by is not None: - raise mlrun.errors.MLRunInvalidArgumentError( - "Runs partitioning not supported" - ) - if uid and isinstance(uid, list): - raise mlrun.errors.MLRunInvalidArgumentError( - "Runs list with multiple uids not supported" - ) - - labels = [] if labels is None else labels - filepath = self._filepath(run_logs, project) - results = RunList() - if isinstance(labels, str): - labels = labels.split(",") - for run, _ in self._load_list(filepath, "*"): - if ( - match_value(name, run, "metadata.name") - and match_labels(get_in(run, "metadata.labels", {}), labels) - and match_value_options(state, run, "status.state") - and match_value(uid, run, "metadata.uid") - and match_times( - start_time_from, - start_time_to, - run, - "status.start_time", - ) - and match_times( - last_update_time_from, - last_update_time_to, - run, - "status.last_update", - ) - and (iter or get_in(run, "metadata.iteration", 0) == 0) - ): - results.append(run) - - if sort or last: - results.sort( - key=lambda i: get_in(i, ["status", "start_time"], ""), reverse=True - ) - if last and len(results) > last: - return RunList(results[:last]) - return results - - def del_run(self, uid, project="", iter=0): - filepath = ( - self._filepath(run_logs, project, self._run_path(uid, iter), "") - + self.format - ) - self._safe_del(filepath) - - def del_runs(self, name="", project="", labels=None, state="", days_ago=0): - - labels = [] if labels is None else labels - if not any([name, state, days_ago, labels]): - raise RunDBError( - "filter is too wide, select name and/or state and/or days_ago" - ) - - filepath = self._filepath(run_logs, project) - if 
isinstance(labels, str): - labels = labels.split(",") - - if days_ago: - days_ago = datetime.now() - timedelta(days=days_ago) - - def date_before(run): - d = get_in(run, "status.start_time", "") - if not d: - return False - return parse_time(d) < days_ago - - for run, p in self._load_list(filepath, "*"): - if ( - match_value(name, run, "metadata.name") - and match_labels(get_in(run, "metadata.labels", {}), labels) - and match_value(state, run, "status.state") - and (not days_ago or date_before(run)) - ): - self._safe_del(p) - - def store_artifact(self, key, artifact, uid, iter=None, tag="", project=""): - if "updated" not in artifact: - artifact["updated"] = datetime.now(timezone.utc).isoformat() - data = self._dumps(artifact) - if iter: - key = f"{iter}-{key}" - filepath = self._filepath(artifacts_dir, project, key, uid) + self.format - self.datastore.put(filepath, data) - filepath = ( - self._filepath(artifacts_dir, project, key, tag or "latest") + self.format - ) - self.datastore.put(filepath, data) - - def read_artifact(self, key, tag="", iter=None, project=""): - tag = tag or "latest" - if iter: - key = f"{iter}-{key}" - filepath = self._filepath(artifacts_dir, project, key, tag) + self.format - - if not pathlib.Path(filepath).is_file(): - raise RunDBError(key) - data = self.datastore.get(filepath) - return self._loads(data) - - def list_artifacts( - self, - name="", - project="", - tag="", - labels=None, - since=None, - until=None, - iter: int = None, - best_iteration: bool = False, - kind: str = None, - category: Union[str, schemas.ArtifactCategories] = None, - ): - if iter or kind or category: - raise NotImplementedError( - "iter/kind/category parameters are not supported for filedb implementation" - ) - - labels = [] if labels is None else labels - tag = tag or "latest" - name = name or "" - logger.info(f"reading artifacts in {project} name/mask: {name} tag: {tag} ...") - filepath = self._filepath(artifacts_dir, project, tag=tag) - results = ArtifactList() 
- results.tag = tag - if isinstance(labels, str): - labels = labels.split(",") - if tag == "*": - mask = "**/*" + name - if name: - mask += "*" - else: - mask = "**/*" - - time_pred = make_time_pred(since, until) - for artifact, p in self._load_list(filepath, mask): - if (name == "" or name in get_in(artifact, "key", "")) and match_labels( - get_in(artifact, "labels", {}), labels - ): - if not time_pred(artifact): - continue - if "artifacts/latest" in p: - artifact["tree"] = "latest" - results.append(artifact) - - return results - - def del_artifact(self, key, tag="", project=""): - tag = tag or "latest" - filepath = self._filepath(artifacts_dir, project, key, tag) + self.format - self._safe_del(filepath) - - def del_artifacts(self, name="", project="", tag="", labels=None): - labels = [] if labels is None else labels - tag = tag or "latest" - filepath = self._filepath(artifacts_dir, project, tag=tag) - - if isinstance(labels, str): - labels = labels.split(",") - if tag == "*": - mask = "**/*" + name - if name: - mask += "*" - else: - mask = "**/*" - - for artifact, p in self._load_list(filepath, mask): - if (name == "" or name == get_in(artifact, "key", "")) and match_labels( - get_in(artifact, "labels", {}), labels - ): - - self._safe_del(p) - - def store_function(self, function, name, project="", tag="", versioned=False): - tag = tag or get_in(function, "metadata.tag") or "latest" - hash_key = fill_function_hash(function, tag) - update_in(function, "metadata.updated", datetime.now(timezone.utc)) - update_in(function, "metadata.tag", "") - data = self._dumps(function) - filepath = ( - path.join( - self.dirpath, - functions_dir, - project or config.default_project, - name, - tag, - ) - + self.format - ) - self.datastore.put(filepath, data) - if versioned: - - # the "hash_key" version should not include the status - function["status"] = None - - # versioned means we want this function to be queryable by its hash key so save another file that the - # hash key is the 
file name - filepath = ( - path.join( - self.dirpath, - functions_dir, - project or config.default_project, - name, - hash_key, - ) - + self.format - ) - data = self._dumps(function) - self.datastore.put(filepath, data) - return hash_key - - def get_function(self, name, project="", tag="", hash_key=""): - tag = tag or "latest" - file_name = hash_key or tag - filepath = ( - path.join( - self.dirpath, - functions_dir, - project or config.default_project, - name, - file_name, - ) - + self.format - ) - if not pathlib.Path(filepath).is_file(): - function_uri = generate_object_uri(project, name, tag, hash_key) - raise mlrun.errors.MLRunNotFoundError(f"Function not found {function_uri}") - data = self.datastore.get(filepath) - parsed_data = self._loads(data) - - # tag should be filled only when queried by tag - parsed_data["metadata"]["tag"] = "" if hash_key else tag - return parsed_data - - def delete_function(self, name: str, project: str = ""): - raise NotImplementedError() - - def list_functions(self, name=None, project="", tag="", labels=None): - labels = labels or [] - logger.info(f"reading functions in {project} name/mask: {name} tag: {tag} ...") - filepath = path.join( - self.dirpath, - functions_dir, - project or config.default_project, - ) - filepath += "/" - - # function name -> tag name -> function dict - functions_with_tag_filename = {} - # function name -> hash key -> function dict - functions_with_hash_key_filename = {} - # function name -> hash keys set - function_with_tag_hash_keys = {} - if isinstance(labels, str): - labels = labels.split(",") - mask = "**/*" - if name: - filepath = f"{filepath}{name}/" - mask = "*" - for func, fullname in self._load_list(filepath, mask): - if match_labels(get_in(func, "metadata.labels", {}), labels): - file_name, _ = path.splitext(path.basename(fullname)) - function_name = path.basename(path.dirname(fullname)) - target_dict = functions_with_tag_filename - - tag_name = file_name - # Heuristic - if tag length is bigger 
than 20 it's probably a hash key - if len(tag_name) > 20: # hash vs tags - tag_name = "" - target_dict = functions_with_hash_key_filename - else: - function_with_tag_hash_keys.setdefault(function_name, set()).add( - func["metadata"]["hash"] - ) - update_in(func, "metadata.tag", tag_name) - target_dict.setdefault(function_name, {})[file_name] = func - - # clean duplicated function e.g. function that was saved both in a hash key filename and tag filename - for ( - function_name, - hash_keys_to_function_dict_map, - ) in functions_with_hash_key_filename.items(): - function_hash_keys_to_remove = [] - for ( - function_hash_key, - function_dict, - ) in hash_keys_to_function_dict_map.items(): - if function_hash_key in function_with_tag_hash_keys.get( - function_name, set() - ): - function_hash_keys_to_remove.append(function_hash_key) - - for function_hash_key in function_hash_keys_to_remove: - del hash_keys_to_function_dict_map[function_hash_key] - - results = [] - for functions_map in [ - functions_with_hash_key_filename, - functions_with_tag_filename, - ]: - for function_name, filename_to_function_map in functions_map.items(): - results.extend(filename_to_function_map.values()) - - return results - - def _filepath(self, table, project, key="", tag=""): - if tag == "*": - tag = "" - if tag: - key = "/" + key - project = project or config.default_project - return path.join(self.dirpath, table, project, tag + key) - - def list_projects( - self, - owner: str = None, - format_: mlrun.api.schemas.ProjectsFormat = mlrun.api.schemas.ProjectsFormat.full, - labels: List[str] = None, - state: mlrun.api.schemas.ProjectState = None, - names: Optional[List[str]] = None, - ) -> mlrun.api.schemas.ProjectsOutput: - if ( - owner - or format_ == mlrun.api.schemas.ProjectsFormat.full - or labels - or state - or names - ): - raise NotImplementedError() - run_dir = path.join(self.dirpath, run_logs) - if not path.isdir(run_dir): - return mlrun.api.schemas.ProjectsOutput(projects=[]) - 
project_names = [ - d for d in listdir(run_dir) if path.isdir(path.join(run_dir, d)) - ] - return mlrun.api.schemas.ProjectsOutput(projects=project_names) - - def tag_objects( - self, - project: str, - tag_name: str, - tag_objects: schemas.TagObjects, - replace: bool = False, - ): - raise NotImplementedError() - - def delete_objects_tag( - self, project: str, tag_name: str, tag_objects: schemas.TagObjects - ): - raise NotImplementedError() - - def tag_artifacts( - self, - artifacts, - project: str, - tag_name: str, - replace: bool = False, - ): - raise NotImplementedError() - - def delete_artifacts_tags( - self, - artifacts, - project: str, - tag_name: str, - ): - raise NotImplementedError() - - def get_project(self, name: str) -> mlrun.api.schemas.Project: - # returns None if project not found, mainly for tests, until we remove fileDB - return None - - def delete_project( - self, - name: str, - deletion_strategy: mlrun.api.schemas.DeletionStrategy = mlrun.api.schemas.DeletionStrategy.default(), - ): - raise NotImplementedError() - - def store_project( - self, - name: str, - project: mlrun.api.schemas.Project, - ) -> mlrun.api.schemas.Project: - raise NotImplementedError() - - def patch_project( - self, - name: str, - project: dict, - patch_mode: mlrun.api.schemas.PatchMode = mlrun.api.schemas.PatchMode.replace, - ) -> mlrun.api.schemas.Project: - raise NotImplementedError() - - def create_project( - self, - project: mlrun.api.schemas.Project, - ) -> mlrun.api.schemas.Project: - raise NotImplementedError() - - @property - def schedules_dir(self): - return path.join(self.dirpath, schedules_dir) - - def store_schedule(self, data): - sched_id = 1 + sum(1 for _ in scandir(self.schedules_dir)) - fname = path.join(self.schedules_dir, f"{sched_id}{self.format}") - with open(fname, "w") as out: - out.write(self._dumps(data)) - - def list_schedules(self): - pattern = f"*{self.format}" - for p in pathlib.Path(self.schedules_dir).glob(pattern): - with p.open() as fp: - yield 
self._loads(fp.read()) - - return [] - - _encodings = { - ".yaml": ("to_yaml", dict_to_yaml), - ".json": ("to_json", dict_to_json), - } - - def _dumps(self, obj): - meth_name, enc_fn = self._encodings.get(self.format, (None, None)) - if meth_name is None: - raise ValueError(f"unsupported format - {self.format}") - - meth = getattr(obj, meth_name, None) - if meth: - return meth() - - return enc_fn(obj) - - def _loads(self, data): - if self.format == ".yaml": - return yaml.load(data, Loader=yaml.FullLoader) - else: - return json.loads(data) - - def _load_list(self, dirpath, mask): - for p in pathlib.Path(dirpath).glob(mask + self.format): - if p.is_file(): - if ".ipynb_checkpoints" in p.parts: - continue - data = self._loads(p.read_text()) - if data: - yield data, str(p) - - def _safe_del(self, filepath): - if path.isfile(filepath): - remove(filepath) - else: - raise RunDBError(f"run file is not found or valid ({filepath})") - - def create_feature_set(self, feature_set, project="", versioned=True): - raise NotImplementedError() - - def get_feature_set( - self, name: str, project: str = "", tag: str = None, uid: str = None - ): - raise NotImplementedError() - - def list_features( - self, - project: str, - name: str = None, - tag: str = None, - entities: List[str] = None, - labels: List[str] = None, - ): - raise NotImplementedError() - - def list_entities( - self, - project: str, - name: str = None, - tag: str = None, - labels: List[str] = None, - ): - raise NotImplementedError() - - def list_feature_sets( - self, - project: str = "", - name: str = None, - tag: str = None, - state: str = None, - entities: List[str] = None, - features: List[str] = None, - labels: List[str] = None, - partition_by: str = None, - rows_per_partition: int = 1, - partition_sort_by: str = None, - partition_order: str = "desc", - ): - raise NotImplementedError() - - def store_feature_set( - self, feature_set, name=None, project="", tag=None, uid=None, versioned=True - ): - raise 
NotImplementedError() - - def patch_feature_set( - self, - name, - feature_set, - project="", - tag=None, - uid=None, - patch_mode="replace", - ): - raise NotImplementedError() - - def delete_feature_set(self, name, project="", tag=None, uid=None): - raise NotImplementedError() - - def create_feature_vector(self, feature_vector, project="", versioned=True) -> dict: - raise NotImplementedError() - - def get_feature_vector( - self, name: str, project: str = "", tag: str = None, uid: str = None - ) -> dict: - raise NotImplementedError() - - def list_feature_vectors( - self, - project: str = "", - name: str = None, - tag: str = None, - state: str = None, - labels: List[str] = None, - partition_by: str = None, - rows_per_partition: int = 1, - partition_sort_by: str = None, - partition_order: str = "desc", - ) -> List[dict]: - raise NotImplementedError() - - def store_feature_vector( - self, - feature_vector, - name=None, - project="", - tag=None, - uid=None, - versioned=True, - ): - raise NotImplementedError() - - def patch_feature_vector( - self, - name, - feature_vector_update: dict, - project="", - tag=None, - uid=None, - patch_mode="replace", - ): - raise NotImplementedError() - - def delete_feature_vector(self, name, project="", tag=None, uid=None): - raise NotImplementedError() - - def list_pipelines( - self, - project: str, - namespace: str = None, - sort_by: str = "", - page_token: str = "", - filter_: str = "", - format_: Union[ - str, mlrun.api.schemas.PipelinesFormat - ] = mlrun.api.schemas.PipelinesFormat.metadata_only, - page_size: int = None, - ) -> mlrun.api.schemas.PipelinesOutput: - raise NotImplementedError() - - def create_project_secrets( - self, - project: str, - provider: str = mlrun.api.schemas.SecretProviderName.kubernetes.value, - secrets: dict = None, - ): - raise NotImplementedError() - - def list_project_secrets( - self, - project: str, - token: str, - provider: str = mlrun.api.schemas.SecretProviderName.kubernetes.value, - secrets: List[str] 
= None, - ) -> mlrun.api.schemas.SecretsData: - raise NotImplementedError() - - def list_project_secret_keys( - self, - project: str, - provider: str = mlrun.api.schemas.SecretProviderName.kubernetes, - token: str = None, - ) -> mlrun.api.schemas.SecretKeysData: - raise NotImplementedError() - - def delete_project_secrets( - self, - project: str, - provider: str = mlrun.api.schemas.SecretProviderName.kubernetes.value, - secrets: List[str] = None, - ): - raise NotImplementedError() - - def create_user_secrets( - self, - user: str, - provider: str = mlrun.api.schemas.secret.SecretProviderName.vault.value, - secrets: dict = None, - ): - raise NotImplementedError() - - def list_artifact_tags(self, project=None, category=None): - raise NotImplementedError() - - def create_model_endpoint( - self, - project: str, - endpoint_id: str, - model_endpoint: ModelEndpoint, - ): - raise NotImplementedError() - - def delete_model_endpoint( - self, - project: str, - endpoint_id: str, - ): - raise NotImplementedError() - - def list_model_endpoints( - self, - project: str, - model: Optional[str] = None, - function: Optional[str] = None, - labels: List[str] = None, - start: str = "now-1h", - end: str = "now", - metrics: Optional[List[str]] = None, - ): - raise NotImplementedError() - - def get_model_endpoint( - self, - project: str, - endpoint_id: str, - start: Optional[str] = None, - end: Optional[str] = None, - metrics: Optional[List[str]] = None, - features: bool = False, - ): - raise NotImplementedError() - - def patch_model_endpoint( - self, - project: str, - endpoint_id: str, - attributes: dict, - ): - raise NotImplementedError() - - def create_marketplace_source( - self, source: Union[dict, schemas.IndexedMarketplaceSource] - ): - raise NotImplementedError() - - def store_marketplace_source( - self, source_name: str, source: Union[dict, schemas.IndexedMarketplaceSource] - ): - raise NotImplementedError() - - def list_marketplace_sources(self): - raise NotImplementedError() - - 
def get_marketplace_source(self, source_name: str): - raise NotImplementedError() - - def delete_marketplace_source(self, source_name: str): - raise NotImplementedError() - - def get_marketplace_catalog( - self, - source_name: str, - version: str = None, - tag: str = None, - force_refresh: bool = False, - ): - raise NotImplementedError() - - def get_marketplace_item( - self, - source_name: str, - item_name: str, - version: str = None, - tag: str = "latest", - force_refresh: bool = False, - ): - raise NotImplementedError() - - def verify_authorization( - self, - authorization_verification_input: mlrun.api.schemas.AuthorizationVerificationInput, - ): - raise NotImplementedError() - - -def make_time_pred(since, until): - if not (since or until): - return lambda artifact: True - - since = since or datetime.min - until = until or datetime.max - - if since.tzinfo is None: - since = since.replace(tzinfo=timezone.utc) - if until.tzinfo is None: - until = until.replace(tzinfo=timezone.utc) - - def pred(artifact): - val = artifact.get("updated") - if not val: - return True - t = parse_time(val).replace(tzinfo=timezone.utc) - return since <= t <= until - - return pred diff --git a/mlrun/db/httpdb.py b/mlrun/db/httpdb.py index ccf97dfd9a38..6b8fa1588975 100644 --- a/mlrun/db/httpdb.py +++ b/mlrun/db/httpdb.py @@ -27,11 +27,12 @@ import semver import mlrun +import mlrun.api.utils.helpers +import mlrun.common.schemas +import mlrun.model_monitoring.model_endpoint import mlrun.projects -from mlrun.api import schemas from mlrun.errors import MLRunInvalidArgumentError, err_to_str -from ..api.schemas import ModelEndpoint from ..artifacts import Artifact from ..config import config from ..feature_store import FeatureSet, FeatureVector @@ -192,13 +193,13 @@ def api_call( if "Authorization" not in kw.setdefault("headers", {}): kw["headers"].update({"Authorization": "Bearer " + self.token}) - if mlrun.api.schemas.HeaderNames.client_version not in kw.setdefault( + if 
mlrun.common.schemas.HeaderNames.client_version not in kw.setdefault( "headers", {} ): kw["headers"].update( { - mlrun.api.schemas.HeaderNames.client_version: self.client_version, - mlrun.api.schemas.HeaderNames.python_version: self.python_version, + mlrun.common.schemas.HeaderNames.client_version: self.client_version, + mlrun.common.schemas.HeaderNames.python_version: self.python_version, } ) @@ -242,7 +243,7 @@ def api_call( def _init_session(self): return mlrun.utils.HTTPSessionWithRetry( retry_on_exception=config.httpdb.retry_api_call_on_exception - == mlrun.api.schemas.HTTPSessionRetryMode.enabled.value + == mlrun.common.schemas.HTTPSessionRetryMode.enabled.value ) def _path_of(self, prefix, project, uid): @@ -499,16 +500,16 @@ def store_run(self, struct, uid, project="", iter=0): body = _as_json(struct) self.api_call("POST", path, error, params=params, body=body) - def update_run(self, updates: dict, uid, project="", iter=0): + def update_run(self, updates: dict, uid, project="", iter=0, timeout=45): """Update the details of a stored run in the DB.""" path = self._path_of("run", project, uid) params = {"iter": iter} error = f"update run {project}/{uid}" body = _as_json(updates) - self.api_call("PATCH", path, error, params=params, body=body) + self.api_call("PATCH", path, error, params=params, body=body, timeout=timeout) - def abort_run(self, uid, project="", iter=0): + def abort_run(self, uid, project="", iter=0, timeout=45): """ Abort a running run - will remove the run's runtime resources and mark its state as aborted """ @@ -517,6 +518,7 @@ def abort_run(self, uid, project="", iter=0): uid, project, iter, + timeout, ) def read_run(self, uid, project="", iter=0): @@ -548,29 +550,33 @@ def del_run(self, uid, project="", iter=0): def list_runs( self, - name=None, + name: Optional[str] = None, uid: Optional[Union[str, List[str]]] = None, - project=None, - labels=None, - state=None, - sort=True, - last=0, - iter=False, - start_time_from: datetime = None, - 
start_time_to: datetime = None, - last_update_time_from: datetime = None, - last_update_time_to: datetime = None, - partition_by: Union[schemas.RunPartitionByField, str] = None, + project: Optional[str] = None, + labels: Optional[Union[str, List[str]]] = None, + state: Optional[str] = None, + sort: bool = True, + last: int = 0, + iter: bool = False, + start_time_from: Optional[datetime] = None, + start_time_to: Optional[datetime] = None, + last_update_time_from: Optional[datetime] = None, + last_update_time_to: Optional[datetime] = None, + partition_by: Optional[ + Union[mlrun.common.schemas.RunPartitionByField, str] + ] = None, rows_per_partition: int = 1, - partition_sort_by: Union[schemas.SortField, str] = None, - partition_order: Union[schemas.OrderType, str] = schemas.OrderType.desc, + partition_sort_by: Optional[Union[mlrun.common.schemas.SortField, str]] = None, + partition_order: Union[ + mlrun.common.schemas.OrderType, str + ] = mlrun.common.schemas.OrderType.desc, max_partitions: int = 0, with_notifications: bool = False, ) -> RunList: """Retrieve a list of runs, filtered by various options. Example:: - runs = db.list_runs(name='download', project='iris', labels='owner=admin') + runs = db.list_runs(name='download', project='iris', labels=['owner=admin', 'kind=job']) # If running in Jupyter, can use the .show() function to display the results db.list_runs(name='', project=project_name).show() @@ -578,8 +584,8 @@ def list_runs( :param name: Name of the run to retrieve. :param uid: Unique ID of the run, or a list of run UIDs. :param project: Project that the runs belongs to. - :param labels: List runs that have a specific label assigned. Currently only a single label filter can be - applied, otherwise result will be empty. + :param labels: List runs that have specific labels assigned. a single or multi label filter can be + applied. :param state: List only runs whose state is specified. :param sort: Whether to sort the result according to their start time. 
Otherwise, results will be returned by their internal order in the DB (order will not be guaranteed). @@ -615,13 +621,13 @@ def list_runs( "start_time_to": datetime_to_iso(start_time_to), "last_update_time_from": datetime_to_iso(last_update_time_from), "last_update_time_to": datetime_to_iso(last_update_time_to), - "with_notifications": with_notifications, + "with-notifications": with_notifications, } if partition_by: params.update( self._generate_partition_by_params( - schemas.RunPartitionByField, + mlrun.common.schemas.RunPartitionByField, partition_by, rows_per_partition, partition_sort_by, @@ -690,7 +696,7 @@ def read_artifact(self, key, tag=None, iter=None, project=""): endpoint_path = f"projects/{project}/artifacts/{key}?tag={tag}" error = f"read artifact {project}/{key}" # explicitly set artifacts format to 'full' since old servers may default to 'legacy' - params = {"format": schemas.ArtifactsFormat.full.value} + params = {"format": mlrun.common.schemas.ArtifactsFormat.full.value} if iter: params["iter"] = str(iter) resp = self.api_call("GET", endpoint_path, error, params=params) @@ -718,7 +724,7 @@ def list_artifacts( iter: int = None, best_iteration: bool = False, kind: str = None, - category: Union[str, schemas.ArtifactCategories] = None, + category: Union[str, mlrun.common.schemas.ArtifactCategories] = None, ) -> ArtifactList: """List artifacts filtered by various parameters. 
@@ -762,7 +768,7 @@ def list_artifacts( "best-iteration": best_iteration, "kind": kind, "category": category, - "format": schemas.ArtifactsFormat.full.value, + "format": mlrun.common.schemas.ArtifactsFormat.full.value, } error = "list artifacts" endpoint_path = f"projects/{project}/artifacts" @@ -795,7 +801,7 @@ def del_artifacts(self, name=None, project=None, tag=None, labels=None, days_ago def list_artifact_tags( self, project=None, - category: Union[str, schemas.ArtifactCategories] = None, + category: Union[str, mlrun.common.schemas.ArtifactCategories] = None, ) -> List[str]: """Return a list of all the tags assigned to artifacts in the scope of the given project.""" @@ -823,7 +829,7 @@ def store_function( params = {"tag": tag, "versioned": versioned} project = project or config.default_project - path = self._path_of("func", project, name) + path = f"projects/{project}/functions/{name}" error = f"store function {project}/{name}" resp = self.api_call( @@ -838,7 +844,7 @@ def get_function(self, name, project="", tag=None, hash_key=""): params = {"tag": tag, "hash_key": hash_key} project = project or config.default_project - path = self._path_of("func", project, name) + path = f"projects/{project}/functions/{name}" error = f"get function {project}/{name}" resp = self.api_call("GET", path, error, params=params) return resp.json()["func"] @@ -860,15 +866,15 @@ def list_functions(self, name=None, project=None, tag=None, labels=None): :param labels: Return functions that have specific labels assigned to them. :returns: List of function objects (as dictionary). 
""" - + project = project or config.default_project params = { - "project": project or config.default_project, "name": name, "tag": tag, "label": labels or [], } error = "list functions" - resp = self.api_call("GET", "funcs", error, params=params) + path = f"projects/{project}/functions" + resp = self.api_call("GET", path, error, params=params) return resp.json()["funcs"] def list_runtime_resources( @@ -877,11 +883,13 @@ def list_runtime_resources( label_selector: Optional[str] = None, kind: Optional[str] = None, object_id: Optional[str] = None, - group_by: Optional[mlrun.api.schemas.ListRuntimeResourcesGroupByField] = None, + group_by: Optional[ + mlrun.common.schemas.ListRuntimeResourcesGroupByField + ] = None, ) -> Union[ - mlrun.api.schemas.RuntimeResourcesOutput, - mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, - mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, + mlrun.common.schemas.RuntimeResourcesOutput, + mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput, + mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, ]: """List current runtime resources, which are usually (but not limited to) Kubernetes pods or CRDs. 
Function applies for runs of type `['dask', 'job', 'spark', 'remote-spark', 'mpijob']`, and will return per @@ -910,25 +918,25 @@ def list_runtime_resources( ) if group_by is None: structured_list = [ - mlrun.api.schemas.KindRuntimeResources(**kind_runtime_resources) + mlrun.common.schemas.KindRuntimeResources(**kind_runtime_resources) for kind_runtime_resources in response.json() ] return structured_list - elif group_by == mlrun.api.schemas.ListRuntimeResourcesGroupByField.job: + elif group_by == mlrun.common.schemas.ListRuntimeResourcesGroupByField.job: structured_dict = {} for project, job_runtime_resources_map in response.json().items(): for job_id, runtime_resources in job_runtime_resources_map.items(): structured_dict.setdefault(project, {})[ job_id - ] = mlrun.api.schemas.RuntimeResources(**runtime_resources) + ] = mlrun.common.schemas.RuntimeResources(**runtime_resources) return structured_dict - elif group_by == mlrun.api.schemas.ListRuntimeResourcesGroupByField.project: + elif group_by == mlrun.common.schemas.ListRuntimeResourcesGroupByField.project: structured_dict = {} for project, kind_runtime_resources_map in response.json().items(): for kind, runtime_resources in kind_runtime_resources_map.items(): structured_dict.setdefault(project, {})[ kind - ] = mlrun.api.schemas.RuntimeResources(**runtime_resources) + ] = mlrun.common.schemas.RuntimeResources(**runtime_resources) return structured_dict else: raise NotImplementedError( @@ -943,7 +951,7 @@ def delete_runtime_resources( object_id: Optional[str] = None, force: bool = False, grace_period: int = None, - ) -> mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput: + ) -> mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput: """Delete all runtime resources which are in terminal state. 
:param project: Delete only runtime resources of a specific project, by default None, which will delete only @@ -958,7 +966,7 @@ def delete_runtime_resources( :param grace_period: Grace period given to the runtime resource before they are actually removed, counted from the moment they moved to terminal state (defaults to mlrun.mlconf.runtime_resources_deletion_grace_period). - :returns: :py:class:`~mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput` listing the runtime resources + :returns: :py:class:`~mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput` listing the runtime resources that were removed. """ if grace_period is None: @@ -988,10 +996,12 @@ def delete_runtime_resources( for kind, runtime_resources in kind_runtime_resources_map.items(): structured_dict.setdefault(project, {})[ kind - ] = mlrun.api.schemas.RuntimeResources(**runtime_resources) + ] = mlrun.common.schemas.RuntimeResources(**runtime_resources) return structured_dict - def create_schedule(self, project: str, schedule: schemas.ScheduleInput): + def create_schedule( + self, project: str, schedule: mlrun.common.schemas.ScheduleInput + ): """Create a new schedule on the given project. The details on the actual object to schedule as well as the schedule itself are within the schedule object provided. 
The :py:class:`~ScheduleCronTrigger` follows the guidelines in @@ -1003,7 +1013,7 @@ def create_schedule(self, project: str, schedule: schemas.ScheduleInput): Example:: - from mlrun.api import schemas + from mlrun.common import schemas # Execute the get_data_func function every Tuesday at 15:30 schedule = schemas.ScheduleInput( @@ -1022,7 +1032,7 @@ def create_schedule(self, project: str, schedule: schemas.ScheduleInput): self.api_call("POST", path, error_message, body=dict_to_json(schedule.dict())) def update_schedule( - self, project: str, name: str, schedule: schemas.ScheduleUpdate + self, project: str, name: str, schedule: mlrun.common.schemas.ScheduleUpdate ): """Update an existing schedule, replace it with the details contained in the schedule object.""" @@ -1034,7 +1044,7 @@ def update_schedule( def get_schedule( self, project: str, name: str, include_last_run: bool = False - ) -> schemas.ScheduleOutput: + ) -> mlrun.common.schemas.ScheduleOutput: """Retrieve details of the schedule in question. Besides returning the details of the schedule object itself, this function also returns the next scheduled run for this specific schedule, as well as potentially the results of the last run executed through this schedule. @@ -1050,15 +1060,15 @@ def get_schedule( resp = self.api_call( "GET", path, error_message, params={"include_last_run": include_last_run} ) - return schemas.ScheduleOutput(**resp.json()) + return mlrun.common.schemas.ScheduleOutput(**resp.json()) def list_schedules( self, project: str, name: str = None, - kind: schemas.ScheduleKinds = None, + kind: mlrun.common.schemas.ScheduleKinds = None, include_last_run: bool = False, - ) -> schemas.SchedulesOutput: + ) -> mlrun.common.schemas.SchedulesOutput: """Retrieve list of schedules of specific name or kind. :param project: Project name. @@ -1073,7 +1083,7 @@ def list_schedules( path = f"projects/{project}/schedules" error_message = f"Failed listing schedules for {project} ? 
{kind} {name}" resp = self.api_call("GET", path, error_message, params=params) - return schemas.SchedulesOutput(**resp.json()) + return mlrun.common.schemas.SchedulesOutput(**resp.json()) def delete_schedule(self, project: str, name: str): """Delete a specific schedule by name.""" @@ -1136,19 +1146,20 @@ def remote_builder( def get_builder_status( self, func: BaseRuntime, - offset=0, - logs=True, - last_log_timestamp=0, - verbose=False, + offset: int = 0, + logs: bool = True, + last_log_timestamp: float = 0.0, + verbose: bool = False, ): """Retrieve the status of a build operation currently in progress. - :param func: Function object that is being built. - :param offset: Offset into the build logs to retrieve logs from. - :param logs: Should build logs be retrieved. - :param last_log_timestamp: Last timestamp of logs that were already retrieved. Function will return only logs - later than this parameter. - :param verbose: Add verbose logs into the output. + :param func: Function object that is being built. + :param offset: Offset into the build logs to retrieve logs from. + :param logs: Should build logs be retrieved. + :param last_log_timestamp: Last timestamp of logs that were already retrieved. Function will return only logs + later than this parameter. + :param verbose: Add verbose logs into the output. + :returns: The following parameters: - Text of builder logs. @@ -1202,7 +1213,7 @@ def get_builder_status( text = resp.content.decode() return text, last_log_timestamp - def remote_start(self, func_url) -> schemas.BackgroundTask: + def remote_start(self, func_url) -> mlrun.common.schemas.BackgroundTask: """Execute a function remotely, Used for ``dask`` functions. :param func_url: URL to the function to be executed. 
@@ -1225,13 +1236,13 @@ def remote_start(self, func_url) -> schemas.BackgroundTask: logger.error(f"bad resp!!\n{resp.text}") raise ValueError("bad function start response") - return schemas.BackgroundTask(**resp.json()) + return mlrun.common.schemas.BackgroundTask(**resp.json()) def get_project_background_task( self, project: str, name: str, - ) -> schemas.BackgroundTask: + ) -> mlrun.common.schemas.BackgroundTask: """Retrieve updated information on a project background task being executed.""" project = project or config.default_project @@ -1240,15 +1251,15 @@ def get_project_background_task( f"Failed getting project background task. project={project}, name={name}" ) response = self.api_call("GET", path, error_message) - return schemas.BackgroundTask(**response.json()) + return mlrun.common.schemas.BackgroundTask(**response.json()) - def get_background_task(self, name: str) -> schemas.BackgroundTask: + def get_background_task(self, name: str) -> mlrun.common.schemas.BackgroundTask: """Retrieve updated information on a background task being executed.""" path = f"background-tasks/{name}" error_message = f"Failed getting background task. name={name}" response = self.api_call("GET", path, error_message) - return schemas.BackgroundTask(**response.json()) + return mlrun.common.schemas.BackgroundTask(**response.json()) def remote_status(self, project, name, kind, selector): """Retrieve status of a function being executed remotely (relevant to ``dask`` functions). @@ -1273,7 +1284,9 @@ def remote_status(self, project, name, kind, selector): return resp.json()["data"] def submit_job( - self, runspec, schedule: Union[str, schemas.ScheduleCronTrigger] = None + self, + runspec, + schedule: Union[str, mlrun.common.schemas.ScheduleCronTrigger] = None, ): """Submit a job for remote execution. 
@@ -1285,7 +1298,7 @@ def submit_job( try: req = {"task": runspec.to_dict()} if schedule: - if isinstance(schedule, schemas.ScheduleCronTrigger): + if isinstance(schedule, mlrun.common.schemas.ScheduleCronTrigger): schedule = schedule.dict() req["schedule"] = schedule timeout = (int(config.submit_timeout) or 120) + 20 @@ -1367,7 +1380,9 @@ def submit_pipeline( if arguments: if not isinstance(arguments, dict): raise ValueError("arguments must be dict type") - headers[schemas.HeaderNames.pipeline_arguments] = str(arguments) + headers[mlrun.common.schemas.HeaderNames.pipeline_arguments] = str( + arguments + ) if not path.isfile(pipe_file): raise OSError(f"file {pipe_file} doesnt exist") @@ -1406,10 +1421,10 @@ def list_pipelines( page_token: str = "", filter_: str = "", format_: Union[ - str, mlrun.api.schemas.PipelinesFormat - ] = mlrun.api.schemas.PipelinesFormat.metadata_only, + str, mlrun.common.schemas.PipelinesFormat + ] = mlrun.common.schemas.PipelinesFormat.metadata_only, page_size: int = None, - ) -> mlrun.api.schemas.PipelinesOutput: + ) -> mlrun.common.schemas.PipelinesOutput: """Retrieve a list of KFP pipelines. This function can be invoked to get all pipelines from all projects, by specifying ``project=*``, in which case pagination can be used and the various sorting and pagination properties can be applied. 
If a specific project is requested, then the pagination options cannot be @@ -1445,7 +1460,7 @@ def list_pipelines( response = self.api_call( "GET", f"projects/{project}/pipelines", error_message, params=params ) - return mlrun.api.schemas.PipelinesOutput(**response.json()) + return mlrun.common.schemas.PipelinesOutput(**response.json()) def get_pipeline( self, @@ -1453,27 +1468,23 @@ def get_pipeline( namespace: str = None, timeout: int = 10, format_: Union[ - str, mlrun.api.schemas.PipelinesFormat - ] = mlrun.api.schemas.PipelinesFormat.summary, + str, mlrun.common.schemas.PipelinesFormat + ] = mlrun.common.schemas.PipelinesFormat.summary, project: str = None, ): """Retrieve details of a specific pipeline using its run ID (as provided when the pipeline was executed).""" - try: - params = {} - if namespace: - params["namespace"] = namespace - params["format"] = format_ - project_path = project if project else "*" - resp = self.api_call( - "GET", - f"projects/{project_path}/pipelines/{run_id}", - params=params, - timeout=timeout, - ) - except OSError as err: - logger.error(f"error cannot get pipeline: {err_to_str(err)}") - raise OSError(f"error: cannot get pipeline, {err_to_str(err)}") + params = {} + if namespace: + params["namespace"] = namespace + params["format"] = format_ + project_path = project if project else "*" + resp = self.api_call( + "GET", + f"projects/{project_path}/pipelines/{run_id}", + params=params, + timeout=timeout, + ) if not resp.ok: logger.error(f"bad resp!!\n{resp.text}") @@ -1489,7 +1500,7 @@ def _resolve_reference(tag, uid): def create_feature_set( self, - feature_set: Union[dict, schemas.FeatureSet, FeatureSet], + feature_set: Union[dict, mlrun.common.schemas.FeatureSet, FeatureSet], project="", versioned=True, ) -> dict: @@ -1502,7 +1513,7 @@ def create_feature_set( will be kept in the DB and can be retrieved until explicitly deleted. :returns: The :py:class:`~mlrun.feature_store.FeatureSet` object (as dict). 
""" - if isinstance(feature_set, schemas.FeatureSet): + if isinstance(feature_set, mlrun.common.schemas.FeatureSet): feature_set = feature_set.dict() elif isinstance(feature_set, FeatureSet): feature_set = feature_set.to_dict() @@ -1636,10 +1647,14 @@ def list_feature_sets( entities: List[str] = None, features: List[str] = None, labels: List[str] = None, - partition_by: Union[schemas.FeatureStorePartitionByField, str] = None, + partition_by: Union[ + mlrun.common.schemas.FeatureStorePartitionByField, str + ] = None, rows_per_partition: int = 1, - partition_sort_by: Union[schemas.SortField, str] = None, - partition_order: Union[schemas.OrderType, str] = schemas.OrderType.desc, + partition_sort_by: Union[mlrun.common.schemas.SortField, str] = None, + partition_order: Union[ + mlrun.common.schemas.OrderType, str + ] = mlrun.common.schemas.OrderType.desc, ) -> List[FeatureSet]: """Retrieve a list of feature-sets matching the criteria provided. @@ -1673,7 +1688,7 @@ def list_feature_sets( if partition_by: params.update( self._generate_partition_by_params( - schemas.FeatureStorePartitionByField, + mlrun.common.schemas.FeatureStorePartitionByField, partition_by, rows_per_partition, partition_sort_by, @@ -1693,7 +1708,7 @@ def list_feature_sets( def store_feature_set( self, - feature_set: Union[dict, schemas.FeatureSet, FeatureSet], + feature_set: Union[dict, mlrun.common.schemas.FeatureSet, FeatureSet], name=None, project="", tag=None, @@ -1718,7 +1733,7 @@ def store_feature_set( reference = self._resolve_reference(tag, uid) params = {"versioned": versioned} - if isinstance(feature_set, schemas.FeatureSet): + if isinstance(feature_set, mlrun.common.schemas.FeatureSet): feature_set = feature_set.dict() elif isinstance(feature_set, FeatureSet): feature_set = feature_set.to_dict() @@ -1741,7 +1756,9 @@ def patch_feature_set( project="", tag=None, uid=None, - patch_mode: Union[str, schemas.PatchMode] = schemas.PatchMode.replace, + patch_mode: Union[ + str, 
mlrun.common.schemas.PatchMode + ] = mlrun.common.schemas.PatchMode.replace, ): """Modify (patch) an existing :py:class:`~mlrun.feature_store.FeatureSet` object. The object is identified by its name (and project it belongs to), as well as optionally a ``tag`` or its @@ -1764,7 +1781,7 @@ def patch_feature_set( """ project = project or config.default_project reference = self._resolve_reference(tag, uid) - headers = {schemas.HeaderNames.patch_mode: patch_mode} + headers = {mlrun.common.schemas.HeaderNames.patch_mode: patch_mode} path = f"projects/{project}/feature-sets/{name}/references/{reference}" error_message = f"Failed updating feature-set {project}/{name}" self.api_call( @@ -1793,7 +1810,7 @@ def delete_feature_set(self, name, project="", tag=None, uid=None): def create_feature_vector( self, - feature_vector: Union[dict, schemas.FeatureVector, FeatureVector], + feature_vector: Union[dict, mlrun.common.schemas.FeatureVector, FeatureVector], project="", versioned=True, ) -> dict: @@ -1805,7 +1822,7 @@ def create_feature_vector( will be kept in the DB and can be retrieved until explicitly deleted. :returns: The :py:class:`~mlrun.feature_store.FeatureVector` object (as dict). 
""" - if isinstance(feature_vector, schemas.FeatureVector): + if isinstance(feature_vector, mlrun.common.schemas.FeatureVector): feature_vector = feature_vector.dict() elif isinstance(feature_vector, FeatureVector): feature_vector = feature_vector.to_dict() @@ -1849,10 +1866,14 @@ def list_feature_vectors( tag: str = None, state: str = None, labels: List[str] = None, - partition_by: Union[schemas.FeatureStorePartitionByField, str] = None, + partition_by: Union[ + mlrun.common.schemas.FeatureStorePartitionByField, str + ] = None, rows_per_partition: int = 1, - partition_sort_by: Union[schemas.SortField, str] = None, - partition_order: Union[schemas.OrderType, str] = schemas.OrderType.desc, + partition_sort_by: Union[mlrun.common.schemas.SortField, str] = None, + partition_order: Union[ + mlrun.common.schemas.OrderType, str + ] = mlrun.common.schemas.OrderType.desc, ) -> List[FeatureVector]: """Retrieve a list of feature-vectors matching the criteria provided. @@ -1882,7 +1903,7 @@ def list_feature_vectors( if partition_by: params.update( self._generate_partition_by_params( - schemas.FeatureStorePartitionByField, + mlrun.common.schemas.FeatureStorePartitionByField, partition_by, rows_per_partition, partition_sort_by, @@ -1902,7 +1923,7 @@ def list_feature_vectors( def store_feature_vector( self, - feature_vector: Union[dict, schemas.FeatureVector, FeatureVector], + feature_vector: Union[dict, mlrun.common.schemas.FeatureVector, FeatureVector], name=None, project="", tag=None, @@ -1927,7 +1948,7 @@ def store_feature_vector( reference = self._resolve_reference(tag, uid) params = {"versioned": versioned} - if isinstance(feature_vector, schemas.FeatureVector): + if isinstance(feature_vector, mlrun.common.schemas.FeatureVector): feature_vector = feature_vector.dict() elif isinstance(feature_vector, FeatureVector): feature_vector = feature_vector.to_dict() @@ -1952,7 +1973,9 @@ def patch_feature_vector( project="", tag=None, uid=None, - patch_mode: Union[str, 
schemas.PatchMode] = schemas.PatchMode.replace, + patch_mode: Union[ + str, mlrun.common.schemas.PatchMode + ] = mlrun.common.schemas.PatchMode.replace, ): """Modify (patch) an existing :py:class:`~mlrun.feature_store.FeatureVector` object. The object is identified by its name (and project it belongs to), as well as optionally a ``tag`` or its @@ -1970,7 +1993,7 @@ def patch_feature_vector( """ reference = self._resolve_reference(tag, uid) project = project or config.default_project - headers = {schemas.HeaderNames.patch_mode: patch_mode} + headers = {mlrun.common.schemas.HeaderNames.patch_mode: patch_mode} path = f"projects/{project}/feature-vectors/{name}/references/{reference}" error_message = f"Failed updating feature-vector {project}/{name}" self.api_call( @@ -2000,7 +2023,7 @@ def tag_objects( self, project: str, tag_name: str, - objects: Union[mlrun.api.schemas.TagObjects, dict], + objects: Union[mlrun.common.schemas.TagObjects, dict], replace: bool = False, ): """Tag a list of objects. @@ -2020,7 +2043,7 @@ def tag_objects( error_message, body=dict_to_json( objects.dict() - if isinstance(objects, mlrun.api.schemas.TagObjects) + if isinstance(objects, mlrun.common.schemas.TagObjects) else objects ), ) @@ -2029,7 +2052,7 @@ def delete_objects_tag( self, project: str, tag_name: str, - tag_objects: Union[mlrun.api.schemas.TagObjects, dict], + tag_objects: Union[mlrun.common.schemas.TagObjects, dict], ): """Delete a tag from a list of objects. 
@@ -2046,7 +2069,7 @@ def delete_objects_tag( error_message, body=dict_to_json( tag_objects.dict() - if isinstance(tag_objects, mlrun.api.schemas.TagObjects) + if isinstance(tag_objects, mlrun.common.schemas.TagObjects) else tag_objects ), ) @@ -2089,10 +2112,10 @@ def list_projects( self, owner: str = None, format_: Union[ - str, mlrun.api.schemas.ProjectsFormat - ] = mlrun.api.schemas.ProjectsFormat.full, + str, mlrun.common.schemas.ProjectsFormat + ] = mlrun.common.schemas.ProjectsFormat.full, labels: List[str] = None, - state: Union[str, mlrun.api.schemas.ProjectState] = None, + state: Union[str, mlrun.common.schemas.ProjectState] = None, ) -> List[Union[mlrun.projects.MlrunProject, str]]: """Return a list of the existing projects, potentially filtered by specific criteria. @@ -2100,6 +2123,7 @@ def list_projects( :param format_: Format of the results. Possible values are: - ``full`` (default value) - Return full project objects. + - ``minimal`` - Return minimal project objects (minimization happens in the BE). - ``name_only`` - Return just the names of the projects. :param labels: Filter by labels attached to the project. @@ -2115,17 +2139,17 @@ def list_projects( error_message = f"Failed listing projects, query: {params}" response = self.api_call("GET", "projects", error_message, params=params) - if format_ == mlrun.api.schemas.ProjectsFormat.name_only: + if format_ == mlrun.common.schemas.ProjectsFormat.name_only: + + # projects is just a list of strings return response.json()["projects"] - elif format_ == mlrun.api.schemas.ProjectsFormat.full: - return [ - mlrun.projects.MlrunProject.from_dict(project_dict) - for project_dict in response.json()["projects"] - ] - else: - raise NotImplementedError( - f"Provided format is not supported. 
format={format_}" - ) + + # forwards compatibility - we want to be able to handle new formats that might be added in the future + # if format is not known to the api, it is up to the server to return either an error or a default format + return [ + mlrun.projects.MlrunProject.from_dict(project_dict) + for project_dict in response.json()["projects"] + ] def get_project(self, name: str) -> mlrun.projects.MlrunProject: """Get details for a specific project.""" @@ -2142,8 +2166,8 @@ def delete_project( self, name: str, deletion_strategy: Union[ - str, mlrun.api.schemas.DeletionStrategy - ] = mlrun.api.schemas.DeletionStrategy.default(), + str, mlrun.common.schemas.DeletionStrategy + ] = mlrun.common.schemas.DeletionStrategy.default(), ): """Delete a project. @@ -2156,7 +2180,9 @@ def delete_project( """ path = f"projects/{name}" - headers = {schemas.HeaderNames.deletion_strategy: deletion_strategy} + headers = { + mlrun.common.schemas.HeaderNames.deletion_strategy: deletion_strategy + } error_message = f"Failed deleting project {name}" response = self.api_call("DELETE", path, error_message, headers=headers) if response.status_code == http.HTTPStatus.ACCEPTED: @@ -2165,13 +2191,13 @@ def delete_project( def store_project( self, name: str, - project: Union[dict, mlrun.projects.MlrunProject, mlrun.api.schemas.Project], + project: Union[dict, mlrun.projects.MlrunProject, mlrun.common.schemas.Project], ) -> mlrun.projects.MlrunProject: """Store a project in the DB. 
This operation will overwrite existing project of the same name if exists.""" path = f"projects/{name}" error_message = f"Failed storing project {name}" - if isinstance(project, mlrun.api.schemas.Project): + if isinstance(project, mlrun.common.schemas.Project): project = project.dict() elif isinstance(project, mlrun.projects.MlrunProject): project = project.to_dict() @@ -2189,7 +2215,9 @@ def patch_project( self, name: str, project: dict, - patch_mode: Union[str, schemas.PatchMode] = schemas.PatchMode.replace, + patch_mode: Union[ + str, mlrun.common.schemas.PatchMode + ] = mlrun.common.schemas.PatchMode.replace, ) -> mlrun.projects.MlrunProject: """Patch an existing project object. @@ -2200,7 +2228,7 @@ def patch_project( """ path = f"projects/{name}" - headers = {schemas.HeaderNames.patch_mode: patch_mode} + headers = {mlrun.common.schemas.HeaderNames.patch_mode: patch_mode} error_message = f"Failed patching project {name}" response = self.api_call( "PATCH", path, error_message, body=dict_to_json(project), headers=headers @@ -2209,11 +2237,11 @@ def patch_project( def create_project( self, - project: Union[dict, mlrun.projects.MlrunProject, mlrun.api.schemas.Project], + project: Union[dict, mlrun.projects.MlrunProject, mlrun.common.schemas.Project], ) -> mlrun.projects.MlrunProject: """Create a new project. A project with the same name must not exist prior to creation.""" - if isinstance(project, mlrun.api.schemas.Project): + if isinstance(project, mlrun.common.schemas.Project): project = project.dict() elif isinstance(project, mlrun.projects.MlrunProject): project = project.to_dict() @@ -2236,7 +2264,7 @@ def _verify_project_in_terminal_state(): project = self.get_project(project_name) if ( project.status.state - not in mlrun.api.schemas.ProjectState.terminal_states() + not in mlrun.common.schemas.ProjectState.terminal_states() ): raise Exception( f"Project not in terminal state. 
State: {project.status.state}" @@ -2253,11 +2281,11 @@ def _verify_project_in_terminal_state(): def _wait_for_background_task_to_reach_terminal_state( self, name: str - ) -> schemas.BackgroundTask: + ) -> mlrun.common.schemas.BackgroundTask: def _verify_background_task_in_terminal_state(): background_task = self.get_background_task(name) state = background_task.status.state - if state not in mlrun.api.schemas.BackgroundTaskState.terminal_states(): + if state not in mlrun.common.schemas.BackgroundTaskState.terminal_states(): raise Exception( f"Background task not in terminal state. name={name}, state={state}" ) @@ -2274,10 +2302,10 @@ def _verify_background_task_in_terminal_state(): def _wait_for_project_to_be_deleted(self, project_name: str): def _verify_project_deleted(): projects = self.list_projects( - format_=mlrun.api.schemas.ProjectsFormat.name_only + format_=mlrun.common.schemas.ProjectsFormat.name_only ) if project_name in projects: - raise Exception("Project still exists") + raise Exception(f"Project {project_name} still exists") return mlrun.utils.helpers.retry_until_successful( self._wait_for_project_deletion_interval, @@ -2291,8 +2319,8 @@ def create_project_secrets( self, project: str, provider: Union[ - str, schemas.SecretProviderName - ] = schemas.SecretProviderName.kubernetes, + str, mlrun.common.schemas.SecretProviderName + ] = mlrun.common.schemas.SecretProviderName.kubernetes, secrets: dict = None, ): """Create project-context secrets using either ``vault`` or ``kubernetes`` provider. @@ -2310,19 +2338,21 @@ def create_project_secrets( :param project: The project context for which to generate the infra and store secrets. :param provider: The name of the secrets-provider to work with. Accepts a - :py:class:`~mlrun.api.schemas.secret.SecretProviderName` enum. + :py:class:`~mlrun.common.schemas.secret.SecretProviderName` enum. :param secrets: A set of secret values to store. 
Example:: secrets = {'password': 'myPassw0rd', 'aws_key': '111222333'} db.create_project_secrets( "project1", - provider=mlrun.api.schemas.SecretProviderName.kubernetes, + provider=mlrun.common.schemas.SecretProviderName.kubernetes, secrets=secrets ) """ path = f"projects/{project}/secrets" - secrets_input = schemas.SecretsData(secrets=secrets, provider=provider) + secrets_input = mlrun.common.schemas.SecretsData( + secrets=secrets, provider=provider + ) body = secrets_input.dict() error_message = f"Failed creating secret provider {project}/{provider}" self.api_call( @@ -2337,10 +2367,10 @@ def list_project_secrets( project: str, token: str = None, provider: Union[ - str, schemas.SecretProviderName - ] = schemas.SecretProviderName.kubernetes, + str, mlrun.common.schemas.SecretProviderName + ] = mlrun.common.schemas.SecretProviderName.kubernetes, secrets: List[str] = None, - ) -> schemas.SecretsData: + ) -> mlrun.common.schemas.SecretsData: """Retrieve project-context secrets from Vault. Note: @@ -2355,14 +2385,17 @@ def list_project_secrets( to this specific project. ``kubernetes`` provider only supports an empty list. 
""" - if provider == schemas.SecretProviderName.vault.value and not token: + if ( + provider == mlrun.common.schemas.SecretProviderName.vault.value + and not token + ): raise MLRunInvalidArgumentError( "A vault token must be provided when accessing vault secrets" ) path = f"projects/{project}/secrets" params = {"provider": provider, "secret": secrets} - headers = {schemas.HeaderNames.secret_store_token: token} + headers = {mlrun.common.schemas.HeaderNames.secret_store_token: token} error_message = f"Failed retrieving secrets {project}/{provider}" result = self.api_call( "GET", @@ -2371,16 +2404,16 @@ def list_project_secrets( params=params, headers=headers, ) - return schemas.SecretsData(**result.json()) + return mlrun.common.schemas.SecretsData(**result.json()) def list_project_secret_keys( self, project: str, provider: Union[ - str, schemas.SecretProviderName - ] = schemas.SecretProviderName.kubernetes, + str, mlrun.common.schemas.SecretProviderName + ] = mlrun.common.schemas.SecretProviderName.kubernetes, token: str = None, - ) -> schemas.SecretKeysData: + ) -> mlrun.common.schemas.SecretKeysData: """Retrieve project-context secret keys from Vault or Kubernetes. Note: @@ -2389,12 +2422,15 @@ def list_project_secret_keys( :param project: The project name. :param provider: The name of the secrets-provider to work with. Accepts a - :py:class:`~mlrun.api.schemas.secret.SecretProviderName` enum. + :py:class:`~mlrun.common.schemas.secret.SecretProviderName` enum. :param token: Vault token to use for retrieving secrets. Only in use if ``provider`` is ``vault``. Must be a valid Vault token, with permissions to retrieve secrets of the project in question. 
""" - if provider == schemas.SecretProviderName.vault.value and not token: + if ( + provider == mlrun.common.schemas.SecretProviderName.vault.value + and not token + ): raise MLRunInvalidArgumentError( "A vault token must be provided when accessing vault secrets" ) @@ -2402,8 +2438,8 @@ def list_project_secret_keys( path = f"projects/{project}/secret-keys" params = {"provider": provider} headers = ( - {schemas.HeaderNames.secret_store_token: token} - if provider == schemas.SecretProviderName.vault.value + {mlrun.common.schemas.HeaderNames.secret_store_token: token} + if provider == mlrun.common.schemas.SecretProviderName.vault.value else None ) error_message = f"Failed retrieving secret keys {project}/{provider}" @@ -2414,14 +2450,14 @@ def list_project_secret_keys( params=params, headers=headers, ) - return schemas.SecretKeysData(**result.json()) + return mlrun.common.schemas.SecretKeysData(**result.json()) def delete_project_secrets( self, project: str, provider: Union[ - str, schemas.SecretProviderName - ] = schemas.SecretProviderName.kubernetes, + str, mlrun.common.schemas.SecretProviderName + ] = mlrun.common.schemas.SecretProviderName.kubernetes, secrets: List[str] = None, ): """Delete project-context secrets from Kubernetes. @@ -2446,8 +2482,8 @@ def create_user_secrets( self, user: str, provider: Union[ - str, schemas.SecretProviderName - ] = schemas.SecretProviderName.vault, + str, mlrun.common.schemas.SecretProviderName + ] = mlrun.common.schemas.SecretProviderName.vault, secrets: dict = None, ): """Create user-context secret in Vault. Please refer to :py:func:`create_project_secrets` for more details @@ -2462,7 +2498,7 @@ def create_user_secrets( :param secrets: A set of secret values to store within the Vault. 
""" path = "user-secrets" - secrets_creation_request = schemas.UserSecretCreationRequest( + secrets_creation_request = mlrun.common.schemas.UserSecretCreationRequest( user=user, provider=provider, secrets=secrets, @@ -2532,7 +2568,9 @@ def create_model_endpoint( self, project: str, endpoint_id: str, - model_endpoint: ModelEndpoint, + model_endpoint: Union[ + mlrun.model_monitoring.model_endpoint.ModelEndpoint, dict + ], ): """ Creates a DB record with the given model_endpoint record. @@ -2542,11 +2580,16 @@ def create_model_endpoint( :param model_endpoint: An object representing the model endpoint. """ + if isinstance( + model_endpoint, mlrun.model_monitoring.model_endpoint.ModelEndpoint + ): + model_endpoint = model_endpoint.to_dict() + path = f"projects/{project}/model-endpoints/{endpoint_id}" self.api_call( method="POST", path=path, - body=model_endpoint.json(), + body=dict_to_json(model_endpoint), ) def delete_model_endpoint( @@ -2555,7 +2598,7 @@ def delete_model_endpoint( endpoint_id: str, ): """ - Deletes the KV record of a given model endpoint, project and endpoint_id are used for lookup + Deletes the DB record of a given model endpoint, project and endpoint_id are used for lookup :param project: The name of the project :param endpoint_id: The id of the endpoint @@ -2578,13 +2621,15 @@ def list_model_endpoints( metrics: Optional[List[str]] = None, top_level: bool = False, uids: Optional[List[str]] = None, - ) -> schemas.ModelEndpointList: + ) -> List[mlrun.model_monitoring.model_endpoint.ModelEndpoint]: """ - Returns a list of ModelEndpointState objects. Each object represents the current state of a model endpoint. - This functions supports filtering by the following parameters: + Returns a list of `ModelEndpoint` objects. Each `ModelEndpoint` object represents the current state of a + model endpoint. 
This functions supports filtering by the following parameters: 1) model 2) function 3) labels + 4) top level + 5) uids By default, when no filters are applied, all available endpoints for the given project will be listed. In addition, this functions provides a facade for listing endpoint related metrics. This facade is time-based @@ -2594,8 +2639,8 @@ def list_model_endpoints( :param project: The name of the project :param model: The name of the model to filter by :param function: The name of the function to filter by - :param labels: A list of labels to filter by. Label filters work by either filtering a specific value of a label - (i.e. list("key==value")) or by looking for the existence of a given key (i.e. "key") + :param labels: A list of labels to filter by. Label filters work by either filtering a specific value of a + label (i.e. list("key=value")) or by looking for the existence of a given key (i.e. "key") :param metrics: A list of metrics to return for each endpoint, read more in 'TimeMetric' :param start: The start time of the metrics. Can be represented by a string containing an RFC 3339 time, a Unix timestamp in milliseconds, a relative time (`'now'` or @@ -2606,10 +2651,14 @@ def list_model_endpoints( `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the earliest time. 
:param top_level: if true will return only routers and endpoint that are NOT children of any router - :param uids: if passed will return ModelEndpointList of endpoints with uid in uids + :param uids: if passed will return a list `ModelEndpoint` object with uid in uids """ path = f"projects/{project}/model-endpoints" + + if labels and isinstance(labels, dict): + labels = [f"{key}={value}" for key, value in labels.items()] + response = self.api_call( method="GET", path=path, @@ -2624,7 +2673,15 @@ def list_model_endpoints( "uid": uids, }, ) - return schemas.ModelEndpointList(**response.json()) + + # Generate a list of a model endpoint dictionaries + model_endpoints = response.json()["endpoints"] + if model_endpoints: + return [ + mlrun.model_monitoring.model_endpoint.ModelEndpoint.from_dict(obj) + for obj in model_endpoints + ] + return [] def get_model_endpoint( self, @@ -2634,21 +2691,29 @@ def get_model_endpoint( end: Optional[str] = None, metrics: Optional[List[str]] = None, feature_analysis: bool = False, - ) -> schemas.ModelEndpoint: - """ - Returns a ModelEndpoint object with additional metrics and feature related data. - - :param project: The name of the project - :param endpoint_id: The id of the model endpoint - :param metrics: A list of metrics to return for each endpoint, read more in 'TimeMetric' - :param start: The start time of the metrics. Can be represented by a string containing an RFC 3339 - time, a Unix timestamp in milliseconds, a relative time (`'now'` or `'now-[0-9]+[mhd]'`, - where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the earliest time. - :param end: The end time of the metrics. Can be represented by a string containing an RFC 3339 - time, a Unix timestamp in milliseconds, a relative time (`'now'` or `'now-[0-9]+[mhd]'`, - where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the earliest time. 
- :param feature_analysis: When True, the base feature statistics and current feature statistics will be added to - the output of the resulting object + ) -> mlrun.model_monitoring.model_endpoint.ModelEndpoint: + """ + Returns a single `ModelEndpoint` object with additional metrics and feature related data. + + :param project: The name of the project + :param endpoint_id: The unique id of the model endpoint. + :param start: The start time of the metrics. Can be represented by a string containing an + RFC 3339 time, a Unix timestamp in milliseconds, a relative time (`'now'` or + `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or + 0 for the earliest time. + :param end: The end time of the metrics. Can be represented by a string containing an + RFC 3339 time, a Unix timestamp in milliseconds, a relative time (`'now'` or + `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or + 0 for the earliest time. + :param metrics: A list of metrics to return for the model endpoint. There are pre-defined + metrics for model endpoints such as predictions_per_second and + latency_avg_5m but also custom metrics defined by the user. Please note that + these metrics are stored in the time series DB and the results will be + appeared under model_endpoint.spec.metrics. + :param feature_analysis: When True, the base feature statistics and current feature statistics will + be added to the output of the resulting object. + + :return: A `ModelEndpoint` object. """ path = f"projects/{project}/model-endpoints/{endpoint_id}" @@ -2662,7 +2727,10 @@ def get_model_endpoint( "feature_analysis": feature_analysis, }, ) - return schemas.ModelEndpoint(**response.json()) + + return mlrun.model_monitoring.model_endpoint.ModelEndpoint.from_dict( + response.json() + ) def patch_model_endpoint( self, @@ -2676,10 +2744,10 @@ def patch_model_endpoint( :param project: The name of the project. :param endpoint_id: The id of the endpoint. 
:param attributes: Dictionary of attributes that will be used for update the model endpoint. The keys - of this dictionary should exist in the target table. The values should be - from type string or from a valid numerical type such as int or float. More details - about the model endpoint available attributes can be found under - :py:class:`~mlrun.api.schemas.ModelEndpoint`. + of this dictionary should exist in the target table. Note that the values should be + from type string or from a valid numerical type such as int or float. + More details about the model endpoint available attributes can be found under + :py:class:`~mlrun.common.schemas.ModelEndpoint`. Example:: @@ -2707,18 +2775,18 @@ def patch_model_endpoint( params=attributes, ) - def create_marketplace_source( - self, source: Union[dict, schemas.IndexedMarketplaceSource] + def create_hub_source( + self, source: Union[dict, mlrun.common.schemas.IndexedHubSource] ): """ - Add a new marketplace source. + Add a new hub source. - MLRun maintains an ordered list of marketplace sources (“sources”) Each source has + MLRun maintains an ordered list of hub sources (“sources”) Each source has its details registered and its order within the list. When creating a new source, the special order ``-1`` can be used to mark this source as last in the list. However, once the source is in the MLRun list, its order will always be ``>0``. - The global marketplace source always exists in the list, and is always the last source + The global hub source always exists in the list, and is always the last source (``order = -1``). It cannot be modified nor can it be moved to another order in the list. The source object may contain credentials which are needed to access the datastore where the source is stored. 
@@ -2727,49 +2795,55 @@ def create_marketplace_source( Example:: - import mlrun.api.schemas + import mlrun.common.schemas # Add a private source as the last one (will be #1 in the list) - private_source = mlrun.api.schemas.IndexedMarketplaceSource( + private_source = mlrun.common.schemas.IndexedHubSource( order=-1, - source=mlrun.api.schemas.MarketplaceSource( - metadata=mlrun.api.schemas.MarketplaceObjectMetadata(name="priv", description="a private source"), - spec=mlrun.api.schemas.MarketplaceSourceSpec(path="/local/path/to/source", channel="development") + source=mlrun.common.schemas.HubSource( + metadata=mlrun.common.schemas.HubObjectMetadata( + name="priv", description="a private source" + ), + spec=mlrun.common.schemas.HubSourceSpec(path="/local/path/to/source", channel="development") ) ) - db.create_marketplace_source(private_source) + db.create_hub_source(private_source) # Add another source as 1st in the list - will push previous one to be #2 - another_source = mlrun.api.schemas.IndexedMarketplaceSource( + another_source = mlrun.common.schemas.IndexedHubSource( order=1, - source=mlrun.api.schemas.MarketplaceSource( - metadata=mlrun.api.schemas.MarketplaceObjectMetadata(name="priv-2", description="another source"), - spec=mlrun.api.schemas.MarketplaceSourceSpec( + source=mlrun.common.schemas.HubSource( + metadata=mlrun.common.schemas.HubObjectMetadata( + name="priv-2", description="another source" + ), + spec=mlrun.common.schemas.HubSourceSpec( path="/local/path/to/source/2", channel="development", credentials={...} ) ) ) - db.create_marketplace_source(another_source) + db.create_hub_source(another_source) :param source: The source and its order, of type - :py:class:`~mlrun.api.schemas.marketplace.IndexedMarketplaceSource`, or in dictionary form. + :py:class:`~mlrun.common.schemas.hub.IndexedHubSource`, or in dictionary form. :returns: The source object as inserted into the database, with credentials stripped. 
""" - path = "marketplace/sources" - if isinstance(source, schemas.IndexedMarketplaceSource): + path = "hub/sources" + if isinstance(source, mlrun.common.schemas.IndexedHubSource): source = source.dict() response = self.api_call(method="POST", path=path, json=source) - return schemas.IndexedMarketplaceSource(**response.json()) + return mlrun.common.schemas.IndexedHubSource(**response.json()) - def store_marketplace_source( - self, source_name: str, source: Union[dict, schemas.IndexedMarketplaceSource] + def store_hub_source( + self, + source_name: str, + source: Union[dict, mlrun.common.schemas.IndexedHubSource], ): """ - Create or replace a marketplace source. + Create or replace a hub source. For an example of the source format and explanation of the source order logic, - please see :py:func:`~create_marketplace_source`. This method can be used to modify the source itself or its + please see :py:func:`~create_hub_source`. This method can be used to modify the source itself or its order in the list of sources. :param source_name: Name of the source object to modify/create. It must match the ``source.metadata.name`` @@ -2777,47 +2851,47 @@ def store_marketplace_source( :param source: Source object to store in the database. :returns: The source object as stored in the DB. """ - path = f"marketplace/sources/{source_name}" - if isinstance(source, schemas.IndexedMarketplaceSource): + path = f"hub/sources/{source_name}" + if isinstance(source, mlrun.common.schemas.IndexedHubSource): source = source.dict() response = self.api_call(method="PUT", path=path, json=source) - return schemas.IndexedMarketplaceSource(**response.json()) + return mlrun.common.schemas.IndexedHubSource(**response.json()) - def list_marketplace_sources(self): + def list_hub_sources(self): """ - List marketplace sources in the MLRun DB. + List hub sources in the MLRun DB. 
""" - path = "marketplace/sources" + path = "hub/sources" response = self.api_call(method="GET", path=path).json() results = [] for item in response: - results.append(schemas.IndexedMarketplaceSource(**item)) + results.append(mlrun.common.schemas.IndexedHubSource(**item)) return results - def get_marketplace_source(self, source_name: str): + def get_hub_source(self, source_name: str): """ - Retrieve a marketplace source from the DB. + Retrieve a hub source from the DB. - :param source_name: Name of the marketplace source to retrieve. + :param source_name: Name of the hub source to retrieve. """ - path = f"marketplace/sources/{source_name}" + path = f"hub/sources/{source_name}" response = self.api_call(method="GET", path=path) - return schemas.IndexedMarketplaceSource(**response.json()) + return mlrun.common.schemas.IndexedHubSource(**response.json()) - def delete_marketplace_source(self, source_name: str): + def delete_hub_source(self, source_name: str): """ - Delete a marketplace source from the DB. + Delete a hub source from the DB. The source will be deleted from the list, and any following sources will be promoted - for example, if the 1st source is deleted, the 2nd source will become #1 in the list. - The global marketplace source cannot be deleted. + The global hub source cannot be deleted. - :param source_name: Name of the marketplace source to delete. + :param source_name: Name of the hub source to delete. """ - path = f"marketplace/sources/{source_name}" + path = f"hub/sources/{source_name}" self.api_call(method="DELETE", path=path) - def get_marketplace_catalog( + def get_hub_catalog( self, source_name: str, version: str = None, @@ -2825,29 +2899,29 @@ def get_marketplace_catalog( force_refresh: bool = False, ): """ - Retrieve the item catalog for a specified marketplace source. + Retrieve the item catalog for a specified hub source. The list of items can be filtered according to various filters, using item's metadata to filter. 
:param source_name: Name of the source. :param version: Filter items according to their version. :param tag: Filter items based on tag. - :param force_refresh: Make the server fetch the catalog from the actual marketplace source, + :param force_refresh: Make the server fetch the catalog from the actual hub source, rather than rely on cached information which may exist from previous get requests. For example, if the source was re-built, this will make the server get the updated information. Default is ``False``. - :returns: :py:class:`~mlrun.api.schemas.marketplace.MarketplaceCatalog` object, which is essentially a list - of :py:class:`~mlrun.api.schemas.marketplace.MarketplaceItem` entries. + :returns: :py:class:`~mlrun.common.schemas.hub.HubCatalog` object, which is essentially a list + of :py:class:`~mlrun.common.schemas.hub.HubItem` entries. """ - path = (f"marketplace/sources/{source_name}/items",) + path = (f"hub/sources/{source_name}/items",) params = { "version": version, "tag": tag, "force-refresh": force_refresh, } response = self.api_call(method="GET", path=path, params=params) - return schemas.MarketplaceCatalog(**response.json()) + return mlrun.common.schemas.HubCatalog(**response.json()) - def get_marketplace_item( + def get_hub_item( self, source_name: str, item_name: str, @@ -2856,33 +2930,61 @@ def get_marketplace_item( force_refresh: bool = False, ): """ - Retrieve a specific marketplace item. + Retrieve a specific hub item. :param source_name: Name of source. :param item_name: Name of the item to retrieve, as it appears in the catalog. :param version: Get a specific version of the item. Default is ``None``. :param tag: Get a specific version of the item identified by tag. Default is ``latest``. - :param force_refresh: Make the server fetch the information from the actual marketplace + :param force_refresh: Make the server fetch the information from the actual hub source, rather than rely on cached information. Default is ``False``. 
- :returns: :py:class:`~mlrun.api.schemas.marketplace.MarketplaceItem`. + :returns: :py:class:`~mlrun.common.schemas.hub.HubItem`. """ - path = (f"marketplace/sources/{source_name}/items/{item_name}",) + path = (f"hub/sources/{source_name}/items/{item_name}",) params = { "version": version, "tag": tag, "force-refresh": force_refresh, } response = self.api_call(method="GET", path=path, params=params) - return schemas.MarketplaceItem(**response.json()) + return mlrun.common.schemas.HubItem(**response.json()) + + def get_hub_asset( + self, + source_name: str, + item_name: str, + asset_name: str, + version: str = None, + tag: str = "latest", + ): + """ + Get hub asset from item. + + :param source_name: Name of source. + :param item_name: Name of the item which holds the asset. + :param asset_name: Name of the asset to retrieve. + :param version: Get a specific version of the item. Default is ``None``. + :param tag: Get a specific version of the item identified by tag. Default is ``latest``. + + :return: http response with the asset in the content attribute + """ + path = (f"hub/sources/{source_name}/items/{item_name}/assets/{asset_name}",) + params = { + "version": version, + "tag": tag, + } + response = self.api_call(method="GET", path=path, params=params) + return response def verify_authorization( - self, authorization_verification_input: schemas.AuthorizationVerificationInput + self, + authorization_verification_input: mlrun.common.schemas.AuthorizationVerificationInput, ): """Verifies authorization for the provided action on the provided resource. 
:param authorization_verification_input: Instance of - :py:class:`~mlrun.api.schemas.AuthorizationVerificationInput` that includes all the needed parameters for + :py:class:`~mlrun.common.schemas.AuthorizationVerificationInput` that includes all the needed parameters for the auth verification """ error_message = "Authorization check failed" @@ -2893,10 +2995,10 @@ def verify_authorization( body=dict_to_json(authorization_verification_input.dict()), ) - def trigger_migrations(self) -> Optional[schemas.BackgroundTask]: + def trigger_migrations(self) -> Optional[mlrun.common.schemas.BackgroundTask]: """Trigger migrations (will do nothing if no migrations are needed) and wait for them to finish if actually triggered - :returns: :py:class:`~mlrun.api.schemas.BackgroundTask`. + :returns: :py:class:`~mlrun.common.schemas.BackgroundTask`. """ response = self.api_call( "POST", @@ -2904,12 +3006,62 @@ def trigger_migrations(self) -> Optional[schemas.BackgroundTask]: "Failed triggering migrations", ) if response.status_code == http.HTTPStatus.ACCEPTED: - background_task = schemas.BackgroundTask(**response.json()) + background_task = mlrun.common.schemas.BackgroundTask(**response.json()) return self._wait_for_background_task_to_reach_terminal_state( background_task.metadata.name ) return None + def set_run_notifications( + self, + project: str, + run_uid: str, + notifications: typing.List[mlrun.model.Notification] = None, + ): + """ + Set notifications on a run. This will override any existing notifications on the run. + :param project: Project containing the run. + :param run_uid: UID of the run. + :param notifications: List of notifications to set on the run. Default is an empty list. + """ + notifications = notifications or [] + + self.api_call( + "PUT", + f"projects/{project}/runs/{run_uid}/notifications", + f"Failed to set notifications on run. 
uid={run_uid}, project={project}", + json={ + "notifications": [ + notification.to_dict() for notification in notifications + ], + }, + ) + + def set_schedule_notifications( + self, + project: str, + schedule_name: str, + notifications: typing.List[mlrun.model.Notification] = None, + ): + """ + Set notifications on a schedule. This will override any existing notifications on the schedule. + :param project: Project containing the schedule. + :param schedule_name: Name of the schedule. + :param notifications: List of notifications to set on the schedule. Default is an empty list. + """ + notifications = notifications or [] + + self.api_call( + "PUT", + f"projects/{project}/schedules/{schedule_name}/notifications", + f"Failed to set notifications on schedule. schedule={schedule_name}, project={project}", + json={ + "notifications": [ + notification.to_dict() for notification in notifications + ], + }, + ) + def _as_json(obj): fn = getattr(obj, "to_json", None) diff --git a/mlrun/db/nopdb.py b/mlrun/db/nopdb.py new file mode 100644 index 000000000000..166c58758a26 --- /dev/null +++ b/mlrun/db/nopdb.py @@ -0,0 +1,491 @@ +# Copyright 2022 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import datetime +from typing import List, Optional, Union + +import mlrun.common.schemas +import mlrun.errors + +from ..config import config +from ..utils import logger +from .base import RunDBInterface + + +class NopDB(RunDBInterface): + def __init__(self, url=None, *args, **kwargs): + self.url = url + + def __getattribute__(self, attr): + def nop(*args, **kwargs): + env_var_message = ( + "MLRUN_DBPATH is not set. Set this environment variable to the URL of the API " + "server in order to connect" + ) + if config.httpdb.nop_db.raise_error: + raise mlrun.errors.MLRunBadRequestError(env_var_message) + + if config.httpdb.nop_db.verbose: + logger.warning( + "Could not detect path to API server, not connected to API server!" + ) + logger.warning(env_var_message) + + return + + if attr == "connect": + return super().__getattribute__(attr) + else: + nop() + return super().__getattribute__(attr) + + def connect(self, secrets=None): + pass + + def store_log(self, uid, project="", body=None, append=False): + pass + + def get_log(self, uid, project="", offset=0, size=0): + pass + + def store_run(self, struct, uid, project="", iter=0): + pass + + def update_run(self, updates: dict, uid, project="", iter=0): + pass + + def abort_run(self, uid, project="", iter=0, timeout=45): + pass + + def read_run(self, uid, project="", iter=0): + pass + + def list_runs( + self, + name="", + uid: Optional[Union[str, List[str]]] = None, + project="", + labels=None, + state="", + sort=True, + last=0, + iter=False, + start_time_from: datetime.datetime = None, + start_time_to: datetime.datetime = None, + last_update_time_from: datetime.datetime = None, + last_update_time_to: datetime.datetime = None, + partition_by: Union[mlrun.common.schemas.RunPartitionByField, str] = None, + rows_per_partition: int = 1, + partition_sort_by: Union[mlrun.common.schemas.SortField, str] = None, + partition_order: Union[ + mlrun.common.schemas.OrderType, str + ] = mlrun.common.schemas.OrderType.desc, + 
max_partitions: int = 0, + ): + pass + + def del_run(self, uid, project="", iter=0): + pass + + def del_runs(self, name="", project="", labels=None, state="", days_ago=0): + pass + + def store_artifact(self, key, artifact, uid, iter=None, tag="", project=""): + pass + + def read_artifact(self, key, tag="", iter=None, project=""): + pass + + def list_artifacts( + self, + name="", + project="", + tag="", + labels=None, + since=None, + until=None, + iter: int = None, + best_iteration: bool = False, + kind: str = None, + category: Union[str, mlrun.common.schemas.ArtifactCategories] = None, + ): + pass + + def del_artifact(self, key, tag="", project=""): + pass + + def del_artifacts(self, name="", project="", tag="", labels=None): + pass + + def store_function(self, function, name, project="", tag="", versioned=False): + pass + + def get_function(self, name, project="", tag="", hash_key=""): + pass + + def delete_function(self, name: str, project: str = ""): + pass + + def list_functions(self, name=None, project="", tag="", labels=None): + pass + + def tag_objects( + self, + project: str, + tag_name: str, + tag_objects: mlrun.common.schemas.TagObjects, + replace: bool = False, + ): + pass + + def delete_objects_tag( + self, project: str, tag_name: str, tag_objects: mlrun.common.schemas.TagObjects + ): + pass + + def tag_artifacts( + self, artifacts, project: str, tag_name: str, replace: bool = False + ): + pass + + def delete_artifacts_tags(self, artifacts, project: str, tag_name: str): + pass + + def delete_project( + self, + name: str, + deletion_strategy: mlrun.common.schemas.DeletionStrategy = mlrun.common.schemas.DeletionStrategy.default(), + ): + pass + + def store_project( + self, name: str, project: mlrun.common.schemas.Project + ) -> mlrun.common.schemas.Project: + pass + + def patch_project( + self, + name: str, + project: dict, + patch_mode: mlrun.common.schemas.PatchMode = mlrun.common.schemas.PatchMode.replace, + ) -> mlrun.common.schemas.Project: + pass + 
+ def create_project( + self, project: mlrun.common.schemas.Project + ) -> mlrun.common.schemas.Project: + pass + + def list_projects( + self, + owner: str = None, + format_: mlrun.common.schemas.ProjectsFormat = mlrun.common.schemas.ProjectsFormat.full, + labels: List[str] = None, + state: mlrun.common.schemas.ProjectState = None, + ) -> mlrun.common.schemas.ProjectsOutput: + pass + + def get_project(self, name: str) -> mlrun.common.schemas.Project: + pass + + def list_artifact_tags( + self, + project=None, + category: Union[str, mlrun.common.schemas.ArtifactCategories] = None, + ): + pass + + def create_feature_set( + self, + feature_set: Union[dict, mlrun.common.schemas.FeatureSet], + project="", + versioned=True, + ) -> dict: + pass + + def get_feature_set( + self, name: str, project: str = "", tag: str = None, uid: str = None + ) -> dict: + pass + + def list_features( + self, + project: str, + name: str = None, + tag: str = None, + entities: List[str] = None, + labels: List[str] = None, + ) -> mlrun.common.schemas.FeaturesOutput: + pass + + def list_entities( + self, project: str, name: str = None, tag: str = None, labels: List[str] = None + ) -> mlrun.common.schemas.EntitiesOutput: + pass + + def list_feature_sets( + self, + project: str = "", + name: str = None, + tag: str = None, + state: str = None, + entities: List[str] = None, + features: List[str] = None, + labels: List[str] = None, + partition_by: Union[ + mlrun.common.schemas.FeatureStorePartitionByField, str + ] = None, + rows_per_partition: int = 1, + partition_sort_by: Union[mlrun.common.schemas.SortField, str] = None, + partition_order: Union[ + mlrun.common.schemas.OrderType, str + ] = mlrun.common.schemas.OrderType.desc, + ) -> List[dict]: + pass + + def store_feature_set( + self, + feature_set: Union[dict, mlrun.common.schemas.FeatureSet], + name=None, + project="", + tag=None, + uid=None, + versioned=True, + ): + pass + + def patch_feature_set( + self, + name, + feature_set: dict, + 
project="", + tag=None, + uid=None, + patch_mode: Union[ + str, mlrun.common.schemas.PatchMode + ] = mlrun.common.schemas.PatchMode.replace, + ): + pass + + def delete_feature_set(self, name, project="", tag=None, uid=None): + pass + + def create_feature_vector( + self, + feature_vector: Union[dict, mlrun.common.schemas.FeatureVector], + project="", + versioned=True, + ) -> dict: + pass + + def get_feature_vector( + self, name: str, project: str = "", tag: str = None, uid: str = None + ) -> dict: + pass + + def list_feature_vectors( + self, + project: str = "", + name: str = None, + tag: str = None, + state: str = None, + labels: List[str] = None, + partition_by: Union[ + mlrun.common.schemas.FeatureStorePartitionByField, str + ] = None, + rows_per_partition: int = 1, + partition_sort_by: Union[mlrun.common.schemas.SortField, str] = None, + partition_order: Union[ + mlrun.common.schemas.OrderType, str + ] = mlrun.common.schemas.OrderType.desc, + ) -> List[dict]: + pass + + def store_feature_vector( + self, + feature_vector: Union[dict, mlrun.common.schemas.FeatureVector], + name=None, + project="", + tag=None, + uid=None, + versioned=True, + ): + pass + + def patch_feature_vector( + self, + name, + feature_vector_update: dict, + project="", + tag=None, + uid=None, + patch_mode: Union[ + str, mlrun.common.schemas.PatchMode + ] = mlrun.common.schemas.PatchMode.replace, + ): + pass + + def delete_feature_vector(self, name, project="", tag=None, uid=None): + pass + + def list_pipelines( + self, + project: str, + namespace: str = None, + sort_by: str = "", + page_token: str = "", + filter_: str = "", + format_: Union[ + str, mlrun.common.schemas.PipelinesFormat + ] = mlrun.common.schemas.PipelinesFormat.metadata_only, + page_size: int = None, + ) -> mlrun.common.schemas.PipelinesOutput: + pass + + def create_project_secrets( + self, + project: str, + provider: Union[ + str, mlrun.common.schemas.SecretProviderName + ] = mlrun.common.schemas.SecretProviderName.kubernetes, 
+ secrets: dict = None, + ): + pass + + def list_project_secrets( + self, + project: str, + token: str, + provider: Union[ + str, mlrun.common.schemas.SecretProviderName + ] = mlrun.common.schemas.SecretProviderName.kubernetes, + secrets: List[str] = None, + ) -> mlrun.common.schemas.SecretsData: + pass + + def list_project_secret_keys( + self, + project: str, + provider: Union[ + str, mlrun.common.schemas.SecretProviderName + ] = mlrun.common.schemas.SecretProviderName.kubernetes, + token: str = None, + ) -> mlrun.common.schemas.SecretKeysData: + pass + + def delete_project_secrets( + self, + project: str, + provider: Union[ + str, mlrun.common.schemas.SecretProviderName + ] = mlrun.common.schemas.SecretProviderName.kubernetes, + secrets: List[str] = None, + ): + pass + + def create_user_secrets( + self, + user: str, + provider: Union[ + str, mlrun.common.schemas.SecretProviderName + ] = mlrun.common.schemas.SecretProviderName.vault, + secrets: dict = None, + ): + pass + + def create_model_endpoint( + self, + project: str, + endpoint_id: str, + model_endpoint: mlrun.common.schemas.ModelEndpoint, + ): + pass + + def delete_model_endpoint(self, project: str, endpoint_id: str): + pass + + def list_model_endpoints( + self, + project: str, + model: Optional[str] = None, + function: Optional[str] = None, + labels: List[str] = None, + start: str = "now-1h", + end: str = "now", + metrics: Optional[List[str]] = None, + ): + pass + + def get_model_endpoint( + self, + project: str, + endpoint_id: str, + start: Optional[str] = None, + end: Optional[str] = None, + metrics: Optional[List[str]] = None, + features: bool = False, + ): + pass + + def patch_model_endpoint(self, project: str, endpoint_id: str, attributes: dict): + pass + + def create_hub_source( + self, source: Union[dict, mlrun.common.schemas.IndexedHubSource] + ): + pass + + def store_hub_source( + self, + source_name: str, + source: Union[dict, mlrun.common.schemas.IndexedHubSource], + ): + pass + + def 
list_hub_sources(self): + pass + + def get_hub_source(self, source_name: str): + pass + + def delete_hub_source(self, source_name: str): + pass + + def get_hub_catalog( + self, + source_name: str, + channel: str = None, + version: str = None, + tag: str = None, + force_refresh: bool = False, + ): + pass + + def get_hub_item( + self, + source_name: str, + item_name: str, + channel: str = "development", + version: str = None, + tag: str = "latest", + force_refresh: bool = False, + ): + pass + + def verify_authorization( + self, + authorization_verification_input: mlrun.common.schemas.AuthorizationVerificationInput, + ): + pass diff --git a/mlrun/db/sqldb.py b/mlrun/db/sqldb.py index 7b6249442a78..a8519d70484e 100644 --- a/mlrun/db/sqldb.py +++ b/mlrun/db/sqldb.py @@ -15,7 +15,8 @@ import datetime from typing import List, Optional, Union -import mlrun.api.schemas +import mlrun.common.schemas +import mlrun.model_monitoring.model_endpoint from mlrun.api.db.base import DBError from mlrun.api.db.sqldb.db import SQLDB as SQLAPIDB from mlrun.api.db.sqldb.session import create_session @@ -28,8 +29,6 @@ # service, in order to prevent the api from calling itself several times for each submission request (since the runDB # will be httpdb to that same api service) we have this class which is kind of a proxy between the RunDB interface to # the api service's DB interface -from ..api import schemas -from ..api.schemas import ModelEndpoint from .base import RunDBError, RunDBInterface @@ -95,7 +94,7 @@ def update_run(self, updates: dict, uid, project="", iter=0): updates, ) - def abort_run(self, uid, project="", iter=0): + def abort_run(self, uid, project="", iter=0, timeout=45): raise NotImplementedError() def read_run(self, uid, project=None, iter=None): @@ -123,10 +122,12 @@ def list_runs( start_time_to: datetime.datetime = None, last_update_time_from: datetime.datetime = None, last_update_time_to: datetime.datetime = None, - partition_by: Union[schemas.RunPartitionByField, str] 
= None, + partition_by: Union[mlrun.common.schemas.RunPartitionByField, str] = None, rows_per_partition: int = 1, - partition_sort_by: Union[schemas.SortField, str] = None, - partition_order: Union[schemas.OrderType, str] = schemas.OrderType.desc, + partition_sort_by: Union[mlrun.common.schemas.SortField, str] = None, + partition_order: Union[ + mlrun.common.schemas.OrderType, str + ] = mlrun.common.schemas.OrderType.desc, max_partitions: int = 0, with_notifications: bool = False, ): @@ -216,12 +217,12 @@ def list_artifacts( iter: int = None, best_iteration: bool = False, kind: str = None, - category: Union[str, schemas.ArtifactCategories] = None, + category: Union[str, mlrun.common.schemas.ArtifactCategories] = None, ): import mlrun.api.crud if category and isinstance(category, str): - category = schemas.ArtifactCategories(category) + category = mlrun.common.schemas.ArtifactCategories(category) return self._transform_db_error( mlrun.api.crud.Artifacts().list_artifacts, @@ -309,7 +310,9 @@ def list_functions(self, name=None, project=None, tag=None, labels=None): ) def list_artifact_tags( - self, project=None, category: Union[str, schemas.ArtifactCategories] = None + self, + project=None, + category: Union[str, mlrun.common.schemas.ArtifactCategories] = None, ): return self._transform_db_error( self.db.list_artifact_tags, self.session, project @@ -319,7 +322,7 @@ def tag_objects( self, project: str, tag_name: str, - tag_objects: mlrun.api.schemas.TagObjects, + tag_objects: mlrun.common.schemas.TagObjects, replace: bool = False, ): import mlrun.api.crud @@ -345,7 +348,7 @@ def delete_objects_tag( self, project: str, tag_name: str, - tag_objects: mlrun.api.schemas.TagObjects, + tag_objects: mlrun.common.schemas.TagObjects, ): import mlrun.api.crud @@ -394,34 +397,65 @@ def list_schedules(self): def store_project( self, name: str, - project: mlrun.api.schemas.Project, - ) -> mlrun.api.schemas.Project: - raise NotImplementedError() + project: 
mlrun.common.schemas.Project, + ) -> mlrun.common.schemas.Project: + import mlrun.api.crud + + if isinstance(project, dict): + project = mlrun.common.schemas.Project(**project) + + return self._transform_db_error( + mlrun.api.crud.Projects().store_project, + self.session, + name=name, + project=project, + ) def patch_project( self, name: str, project: dict, - patch_mode: mlrun.api.schemas.PatchMode = mlrun.api.schemas.PatchMode.replace, - ) -> mlrun.api.schemas.Project: - raise NotImplementedError() + patch_mode: mlrun.common.schemas.PatchMode = mlrun.common.schemas.PatchMode.replace, + ) -> mlrun.common.schemas.Project: + import mlrun.api.crud + + return self._transform_db_error( + mlrun.api.crud.Projects().patch_project, + self.session, + name=name, + project=project, + patch_mode=patch_mode, + ) def create_project( self, - project: mlrun.api.schemas.Project, - ) -> mlrun.api.schemas.Project: - raise NotImplementedError() + project: mlrun.common.schemas.Project, + ) -> mlrun.common.schemas.Project: + import mlrun.api.crud + + return self._transform_db_error( + mlrun.api.crud.Projects().create_project, + self.session, + project=project, + ) def delete_project( self, name: str, - deletion_strategy: mlrun.api.schemas.DeletionStrategy = mlrun.api.schemas.DeletionStrategy.default(), + deletion_strategy: mlrun.common.schemas.DeletionStrategy = mlrun.common.schemas.DeletionStrategy.default(), ): - raise NotImplementedError() + import mlrun.api.crud + + return self._transform_db_error( + mlrun.api.crud.Projects().delete_project, + self.session, + name=name, + deletion_strategy=deletion_strategy, + ) def get_project( self, name: str = None, project_id: int = None - ) -> mlrun.api.schemas.Project: + ) -> mlrun.common.schemas.Project: import mlrun.api.crud return self._transform_db_error( @@ -433,11 +467,20 @@ def get_project( def list_projects( self, owner: str = None, - format_: mlrun.api.schemas.ProjectsFormat = mlrun.api.schemas.ProjectsFormat.full, + format_: 
mlrun.common.schemas.ProjectsFormat = mlrun.common.schemas.ProjectsFormat.full, labels: List[str] = None, - state: mlrun.api.schemas.ProjectState = None, - ) -> mlrun.api.schemas.ProjectsOutput: - raise NotImplementedError() + state: mlrun.common.schemas.ProjectState = None, + ) -> mlrun.common.schemas.ProjectsOutput: + import mlrun.api.crud + + return self._transform_db_error( + mlrun.api.crud.Projects().list_projects, + self.session, + owner=owner, + format_=format_, + labels=labels, + state=state, + ) @staticmethod def _transform_db_error(func, *args, **kwargs): @@ -519,10 +562,10 @@ def list_feature_sets( entities: List[str] = None, features: List[str] = None, labels: List[str] = None, - partition_by: mlrun.api.schemas.FeatureStorePartitionByField = None, + partition_by: mlrun.common.schemas.FeatureStorePartitionByField = None, rows_per_partition: int = 1, - partition_sort_by: mlrun.api.schemas.SortField = None, - partition_order: mlrun.api.schemas.OrderType = mlrun.api.schemas.OrderType.desc, + partition_sort_by: mlrun.common.schemas.SortField = None, + partition_order: mlrun.common.schemas.OrderType = mlrun.common.schemas.OrderType.desc, ): import mlrun.api.crud @@ -544,7 +587,7 @@ def list_feature_sets( def store_feature_set( self, - feature_set: Union[dict, mlrun.api.schemas.FeatureSet], + feature_set: Union[dict, mlrun.common.schemas.FeatureSet], name=None, project="", tag=None, @@ -554,7 +597,7 @@ def store_feature_set( import mlrun.api.crud if isinstance(feature_set, dict): - feature_set = mlrun.api.schemas.FeatureSet(**feature_set) + feature_set = mlrun.common.schemas.FeatureSet(**feature_set) name = name or feature_set.metadata.name project = project or feature_set.metadata.project @@ -629,10 +672,10 @@ def list_feature_vectors( tag: str = None, state: str = None, labels: List[str] = None, - partition_by: mlrun.api.schemas.FeatureStorePartitionByField = None, + partition_by: mlrun.common.schemas.FeatureStorePartitionByField = None, rows_per_partition: 
int = 1, - partition_sort_by: mlrun.api.schemas.SortField = None, - partition_order: mlrun.api.schemas.OrderType = mlrun.api.schemas.OrderType.desc, + partition_sort_by: mlrun.common.schemas.SortField = None, + partition_order: mlrun.common.schemas.OrderType = mlrun.common.schemas.OrderType.desc, ): import mlrun.api.crud @@ -714,18 +757,18 @@ def list_pipelines( page_token: str = "", filter_: str = "", format_: Union[ - str, mlrun.api.schemas.PipelinesFormat - ] = mlrun.api.schemas.PipelinesFormat.metadata_only, + str, mlrun.common.schemas.PipelinesFormat + ] = mlrun.common.schemas.PipelinesFormat.metadata_only, page_size: int = None, - ) -> mlrun.api.schemas.PipelinesOutput: + ) -> mlrun.common.schemas.PipelinesOutput: raise NotImplementedError() def create_project_secrets( self, project: str, provider: Union[ - str, mlrun.api.schemas.SecretProviderName - ] = mlrun.api.schemas.SecretProviderName.kubernetes, + str, mlrun.common.schemas.SecretProviderName + ] = mlrun.common.schemas.SecretProviderName.kubernetes, secrets: dict = None, ): raise NotImplementedError() @@ -735,28 +778,28 @@ def list_project_secrets( project: str, token: str, provider: Union[ - str, mlrun.api.schemas.SecretProviderName - ] = mlrun.api.schemas.SecretProviderName.kubernetes, + str, mlrun.common.schemas.SecretProviderName + ] = mlrun.common.schemas.SecretProviderName.kubernetes, secrets: List[str] = None, - ) -> mlrun.api.schemas.SecretsData: + ) -> mlrun.common.schemas.SecretsData: raise NotImplementedError() def list_project_secret_keys( self, project: str, provider: Union[ - str, mlrun.api.schemas.SecretProviderName - ] = mlrun.api.schemas.SecretProviderName.kubernetes, + str, mlrun.common.schemas.SecretProviderName + ] = mlrun.common.schemas.SecretProviderName.kubernetes, token: str = None, - ) -> mlrun.api.schemas.SecretKeysData: + ) -> mlrun.common.schemas.SecretKeysData: raise NotImplementedError() def delete_project_secrets( self, project: str, provider: Union[ - str, 
mlrun.api.schemas.SecretProviderName - ] = mlrun.api.schemas.SecretProviderName.kubernetes, + str, mlrun.common.schemas.SecretProviderName + ] = mlrun.common.schemas.SecretProviderName.kubernetes, secrets: List[str] = None, ): raise NotImplementedError() @@ -765,8 +808,8 @@ def create_user_secrets( self, user: str, provider: Union[ - str, mlrun.api.schemas.SecretProviderName - ] = mlrun.api.schemas.SecretProviderName.vault, + str, mlrun.common.schemas.SecretProviderName + ] = mlrun.common.schemas.SecretProviderName.vault, secrets: dict = None, ): raise NotImplementedError() @@ -775,7 +818,9 @@ def create_model_endpoint( self, project: str, endpoint_id: str, - model_endpoint: ModelEndpoint, + model_endpoint: Union[ + mlrun.model_monitoring.model_endpoint.ModelEndpoint, dict + ], ): raise NotImplementedError() @@ -817,26 +862,28 @@ def patch_model_endpoint( ): raise NotImplementedError() - def create_marketplace_source( - self, source: Union[dict, schemas.IndexedMarketplaceSource] + def create_hub_source( + self, source: Union[dict, mlrun.common.schemas.IndexedHubSource] ): raise NotImplementedError() - def store_marketplace_source( - self, source_name: str, source: Union[dict, schemas.IndexedMarketplaceSource] + def store_hub_source( + self, + source_name: str, + source: Union[dict, mlrun.common.schemas.IndexedHubSource], ): raise NotImplementedError() - def list_marketplace_sources(self): + def list_hub_sources(self): raise NotImplementedError() - def get_marketplace_source(self, source_name: str): + def get_hub_source(self, source_name: str): raise NotImplementedError() - def delete_marketplace_source(self, source_name: str): + def delete_hub_source(self, source_name: str): raise NotImplementedError() - def get_marketplace_catalog( + def get_hub_catalog( self, source_name: str, version: str = None, @@ -845,7 +892,7 @@ def get_marketplace_catalog( ): raise NotImplementedError() - def get_marketplace_item( + def get_hub_item( self, source_name: str, item_name: str, 
@@ -857,7 +904,7 @@ def get_marketplace_item( def verify_authorization( self, - authorization_verification_input: mlrun.api.schemas.AuthorizationVerificationInput, + authorization_verification_input: mlrun.common.schemas.AuthorizationVerificationInput, ): # on server side authorization is done in endpoint anyway, so for server side we can "pass" on check # done from ingest() diff --git a/mlrun/errors.py b/mlrun/errors.py index e5ea58635424..c224720213f0 100644 --- a/mlrun/errors.py +++ b/mlrun/errors.py @@ -179,6 +179,10 @@ class MLRunInternalServerError(MLRunHTTPStatusError): error_status_code = HTTPStatus.INTERNAL_SERVER_ERROR.value +class MLRunServiceUnavailableError(MLRunHTTPStatusError): + error_status_code = HTTPStatus.SERVICE_UNAVAILABLE.value + + class MLRunRuntimeError(MLRunHTTPStatusError, RuntimeError): error_status_code = HTTPStatus.INTERNAL_SERVER_ERROR.value @@ -213,4 +217,5 @@ def __init__( HTTPStatus.CONFLICT.value: MLRunConflictError, HTTPStatus.PRECONDITION_FAILED.value: MLRunPreconditionFailedError, HTTPStatus.INTERNAL_SERVER_ERROR.value: MLRunInternalServerError, + HTTPStatus.SERVICE_UNAVAILABLE.value: MLRunServiceUnavailableError, } diff --git a/mlrun/execution.py b/mlrun/execution.py index 98eaf6722edd..95129db1cafb 100644 --- a/mlrun/execution.py +++ b/mlrun/execution.py @@ -304,8 +304,8 @@ def from_dict( self._init_dbs(rundb) - if spec and not is_api: - # init data related objects (require DB & Secrets to be set first), skip when running in the api service + if spec: + # init data related objects (require DB & Secrets to be set first) self._data_stores.from_dict(spec) if inputs and isinstance(inputs, dict): for k, v in inputs.items(): @@ -380,7 +380,7 @@ def parameters(self): @property def inputs(self): - """dictionary of input data items (read-only)""" + """dictionary of input data item urls (read-only)""" return self._inputs @property @@ -463,16 +463,26 @@ def get_param(self, key: str, default=None): def _load_project_object(self): if not 
self._project_object: if not self._project: - self.logger.warning("get_project_param called without a project name") + self.logger.warning( + "Project cannot be loaded without a project name set in the context" + ) return None if not self._rundb: self.logger.warning( - "cannot retrieve project parameters - MLRun DB is not accessible" + "Cannot retrieve project data - MLRun DB is not accessible" ) return None self._project_object = self._rundb.get_project(self._project) return self._project_object + def get_project_object(self): + """ + Get the MLRun project object by the project name set in the context. + + :return: The project object or None if it couldn't be retrieved. + """ + return self._load_project_object() + def get_project_param(self, key: str, default=None): """get a parameter from the run's project's parameters""" if not self._load_project_object(): @@ -497,27 +507,34 @@ def _set_input(self, key, url=""): url = key if self.in_path and is_relative_path(url): url = os.path.join(self._in_path, url) - obj = self._data_stores.object( - url, - key, - project=self._project, - allow_empty_resources=self._allow_empty_resources, - ) - self._inputs[key] = obj - return obj + self._inputs[key] = url def get_input(self, key: str, url: str = ""): - """get an input :py:class:`~mlrun.DataItem` object, data objects have methods such as - .get(), .download(), .url, .. to access the actual data + """ + Get an input :py:class:`~mlrun.DataItem` object, + data objects have methods such as .get(), .download(), .url, .. to access the actual data. + Requires access to the data store secrets if configured. - example:: + Example:: data = context.get_input("my_data").get() + + :param key: The key name for the input url entry. + :param url: The url of the input data (file, stream, ..) - optional, saved in the inputs dictionary + if the key is not already present. 
+ + :return: :py:class:`~mlrun.datastore.base.DataItem` object """ if key not in self._inputs: - return self._set_input(key, url) - else: - return self._inputs[key] + self._set_input(key, url) + + url = self._inputs[key] + return self._data_stores.object( + url, + key, + project=self._project, + allow_empty_resources=self._allow_empty_resources, + ) def log_result(self, key: str, value, commit=False): """log a scalar result value @@ -945,7 +962,7 @@ def set_if_not_none(_struct, key, val): "handler": self._handler, "outputs": self._outputs, run_keys.output_path: self.artifact_path, - run_keys.inputs: {k: v.artifact_url for k, v in self._inputs.items()}, + run_keys.inputs: self._inputs, "notifications": self._notifications, }, "status": { @@ -982,7 +999,7 @@ def set_if_not_none(_struct, key, val): "metadata.annotations": self._annotations, "spec.parameters": self._parameters, "spec.outputs": self._outputs, - "spec.inputs": {k: v.artifact_url for k, v in self._inputs.items()}, + "spec.inputs": self._inputs, "status.results": self._results, "status.start_time": to_date_str(self._start_time), "status.last_update": to_date_str(self._last_update), diff --git a/mlrun/feature_store/api.py b/mlrun/feature_store/api.py index f3427e886c8d..9ead9570e3c1 100644 --- a/mlrun/feature_store/api.py +++ b/mlrun/feature_store/api.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import copy +import importlib.util +import pathlib +import sys import warnings from datetime import datetime -from typing import List, Optional, Union -from urllib.parse import urlparse +from typing import Any, Dict, List, Optional, Union import pandas as pd @@ -28,7 +30,6 @@ from ..datastore.targets import ( BaseStoreTarget, get_default_prefix_for_source, - get_default_targets, get_target_driver, kind_to_driver, validate_target_list, @@ -39,7 +40,7 @@ from ..runtimes import RuntimeKinds from ..runtimes.function_reference import FunctionReference from ..serving.server import Response -from ..utils import get_caller_globals, logger, normalize_name, str_to_timestamp +from ..utils import get_caller_globals, logger, normalize_name from .common import ( RunConfig, get_feature_set_by_uri, @@ -77,7 +78,7 @@ def _features_to_vector_and_check_permissions(features, update_stats): "feature vector name must be specified" ) verify_feature_vector_permissions( - vector, mlrun.api.schemas.AuthorizationAction.update + vector, mlrun.common.schemas.AuthorizationAction.update ) vector.save() @@ -102,9 +103,9 @@ def get_offline_features( engine: str = None, engine_args: dict = None, query: str = None, - join_type: str = "inner", order_by: Union[str, List[str]] = None, spark_service: str = None, + timestamp_for_filtering: Union[str, Dict[str, str]] = None, ) -> OfflineVectorResponse: """retrieve offline feature vector results @@ -134,37 +135,44 @@ def get_offline_features( print(vector.get_stats_table()) resp.to_parquet("./out.parquet") - :param feature_vector: feature vector uri or FeatureVector object. 
passing feature vector obj requires update - permissions - :param entity_rows: dataframe with entity rows to join with - :param target: where to write the results to - :param drop_columns: list of columns to drop from the final result - :param entity_timestamp_column: timestamp column name in the entity rows dataframe - :param run_config: function and/or run configuration - see :py:class:`~mlrun.feature_store.RunConfig` - :param start_time: datetime, low limit of time needed to be filtered. Optional. - entity_timestamp_column must be passed when using time filtering. - :param end_time: datetime, high limit of time needed to be filtered. Optional. - entity_timestamp_column must be passed when using time filtering. - :param with_indexes: return vector with index columns and timestamp_key from the feature sets (default False) - :param update_stats: update features statistics from the requested feature sets on the vector. Default is False. - :param engine: processing engine kind ("local", "dask", or "spark") - :param engine_args: kwargs for the processing engine - :param query: The query string used to filter rows - :param spark_service: Name of the spark service to be used (when using a remote-spark runtime) - :param join_type: {'left', 'right', 'outer', 'inner'}, default 'inner' - Supported retrieval engines: "dask", "local" - This parameter is in use when entity_timestamp_column and - feature_vector.spec.timestamp_field are None, if one of them - isn't none we're preforming as_of join. - Possible values : - * left: use only keys from left frame (SQL: left outer join) - * right: use only keys from right frame (SQL: right outer join) - * outer: use union of keys from both frames (SQL: full outer join) - * inner: use intersection of keys from both frames (SQL: inner join). - :param order_by: Name or list of names to order by. The name or the names in the list can be the feature name - or the alias of the feature you pass in the feature list. 
+ :param feature_vector: feature vector uri or FeatureVector object. passing feature vector obj requires + update permissions + :param entity_rows: dataframe with entity rows to join with + :param target: where to write the results to + :param drop_columns: list of columns to drop from the final result + :param entity_timestamp_column: timestamp column name in the entity rows dataframe. can be specified + only if param entity_rows was specified. + :param run_config: function and/or run configuration + see :py:class:`~mlrun.feature_store.RunConfig` + :param start_time: datetime, low limit of time needed to be filtered. Optional. + :param end_time: datetime, high limit of time needed to be filtered. Optional. + :param with_indexes: Return vector with/without the entities and the timestamp_key of the feature sets + and with/without entity_timestamp_column and timestamp_for_filtering columns. + This property can be specified also in the feature vector spec + (feature_vector.spec.with_indexes) + (default False) + :param update_stats: update features statistics from the requested feature sets on the vector. + (default False). + :param engine: processing engine kind ("local", "dask", or "spark") + :param engine_args: kwargs for the processing engine + :param query: The query string used to filter rows on the output + :param spark_service: Name of the spark service to be used (when using a remote-spark runtime) + :param order_by: Name or list of names to order by. The name or the names in the list can be the + feature name or the alias of the feature you pass in the feature list. + :param timestamp_for_filtering: name of the column to filter by, can be str for all the feature sets or a + dictionary ({: , ...}) + that indicates the timestamp column name for each feature set. Optional. + By default, the filter executes on the timestamp_key of each feature set. 
+ Note: the time filtering is performed on each feature set before the + merge process using start_time and end_time params. + """ + if entity_rows is None and entity_timestamp_column is not None: + raise mlrun.errors.MLRunInvalidArgumentError( + "entity_timestamp_column param " + "can not be specified without entity_rows param" + ) + if isinstance(feature_vector, FeatureVector): update_stats = True @@ -187,24 +195,17 @@ def get_offline_features( engine_args, spark_service, entity_rows, - timestamp_column=entity_timestamp_column, + entity_timestamp_column=entity_timestamp_column, run_config=run_config, drop_columns=drop_columns, with_indexes=with_indexes, query=query, - join_type=join_type, order_by=order_by, + start_time=start_time, + end_time=end_time, + timestamp_for_filtering=timestamp_for_filtering, ) - start_time = str_to_timestamp(start_time) - end_time = str_to_timestamp(end_time) - if (start_time or end_time) and not entity_timestamp_column: - raise TypeError( - "entity_timestamp_column or feature_vector.spec.timestamp_field is required when passing start/end time" - ) - if start_time and not end_time: - # if end_time is not specified set it to now() - end_time = pd.Timestamp.now() merger = merger_engine(feature_vector, **(engine_args or {})) return merger.start( entity_rows, @@ -213,10 +214,10 @@ def get_offline_features( drop_columns=drop_columns, start_time=start_time, end_time=end_time, + timestamp_for_filtering=timestamp_for_filtering, with_indexes=with_indexes, update_stats=update_stats, query=query, - join_type=join_type, order_by=order_by, ) @@ -327,6 +328,21 @@ def _rename_source_dataframe_columns(df): return df +def _get_namespace(run_config: RunConfig) -> Dict[str, Any]: + # if running locally, we need to import the file dynamically to get its namespace + if run_config and run_config.local and run_config.function: + filename = run_config.function.spec.filename + if filename: + module_name = pathlib.Path(filename).name.rsplit(".", maxsplit=1)[0] 
+ spec = importlib.util.spec_from_file_location(module_name, filename) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return vars(__import__(module_name)) + else: + return get_caller_globals() + + def ingest( featureset: Union[FeatureSet, str] = None, source=None, @@ -372,7 +388,8 @@ def ingest( :param targets: optional list of data target objects :param namespace: namespace or module containing graph classes :param return_df: indicate if to return a dataframe with the graph results - :param infer_options: schema and stats infer options (:py:class:`~mlrun.feature_store.InferOptions`) + :param infer_options: schema (for discovery of entities, features in featureset), index, stats, + histogram and preview infer options (:py:class:`~mlrun.feature_store.InferOptions`) :param run_config: function and/or run configuration for remote jobs, see :py:class:`~mlrun.feature_store.RunConfig` :param mlrun_context: mlrun context (when running as a job), for internal use ! 
@@ -410,6 +427,15 @@ def ingest( raise mlrun.errors.MLRunInvalidArgumentError( "feature set and source must be specified" ) + if ( + not mlrun_context + and not targets + and not (featureset.spec.targets or featureset.spec.with_default_targets) + and (run_config is not None and not run_config.local) + ): + raise mlrun.errors.MLRunInvalidArgumentError( + f"Feature set {featureset.metadata.name} is remote ingested with no targets defined, aborting" + ) if featureset is not None: featureset.validate_steps(namespace=namespace) @@ -421,7 +447,7 @@ def ingest( ) # remote job execution verify_feature_set_permissions( - featureset, mlrun.api.schemas.AuthorizationAction.update + featureset, mlrun.common.schemas.AuthorizationAction.update ) run_config = run_config.copy() if run_config else RunConfig() source, run_config.parameters = set_task_params( @@ -453,7 +479,7 @@ def ingest( featureset.validate_steps(namespace=namespace) verify_feature_set_permissions( - featureset, mlrun.api.schemas.AuthorizationAction.update + featureset, mlrun.common.schemas.AuthorizationAction.update ) if not source: raise mlrun.errors.MLRunInvalidArgumentError( @@ -482,19 +508,21 @@ def ingest( f"Source.end_time is {str(source.end_time)}" ) - if mlrun_context: - mlrun_context.logger.info( - f"starting ingestion task to {featureset.uri}.{filter_time_string}" - ) + if mlrun_context: + mlrun_context.logger.info( + f"starting ingestion task to {featureset.uri}.{filter_time_string}" + ) + return_df = False if featureset.spec.passthrough: featureset.spec.source = source featureset.spec.validate_no_processing_for_passthrough() - namespace = namespace or get_caller_globals() + if not namespace: + namespace = _get_namespace(run_config) - targets_to_ingest = targets or featureset.spec.targets or get_default_targets() + targets_to_ingest = targets or featureset.spec.targets targets_to_ingest = copy.deepcopy(targets_to_ingest) validate_target_paths_for_engine(targets_to_ingest, featureset.spec.engine, source) 
@@ -638,10 +666,14 @@ def preview( :param entity_columns: list of entity (index) column names :param timestamp_key: DEPRECATED. Use FeatureSet parameter. :param namespace: namespace or module containing graph classes - :param options: schema and stats infer options (:py:class:`~mlrun.feature_store.InferOptions`) + :param options: schema (for discovery of entities, features in featureset), index, stats, + histogram and preview infer options (:py:class:`~mlrun.feature_store.InferOptions`) :param verbose: verbose log :param sample_size: num of rows to sample from the dataset (for large datasets) """ + if isinstance(source, pd.DataFrame): + source = _rename_source_dataframe_columns(source) + # preview reads the source as a pandas df, which is not fully compatible with spark if featureset.spec.engine == "spark": raise mlrun.errors.MLRunInvalidArgumentError( @@ -666,7 +698,7 @@ def preview( source = mlrun.store_manager.object(url=source).as_df() verify_feature_set_permissions( - featureset, mlrun.api.schemas.AuthorizationAction.update + featureset, mlrun.common.schemas.AuthorizationAction.update ) featureset.spec.validate_no_processing_for_passthrough() @@ -691,7 +723,9 @@ def preview( ) # reduce the size of the ingestion if we do not infer stats rows_limit = ( - 0 if InferOptions.get_common_options(options, InferOptions.Stats) else 1000 + None + if InferOptions.get_common_options(options, InferOptions.Stats) + else 1000 ) source = init_featureset_graph( source, @@ -762,7 +796,7 @@ def deploy_ingestion_service( featureset = get_feature_set_by_uri(featureset) verify_feature_set_permissions( - featureset, mlrun.api.schemas.AuthorizationAction.update + featureset, mlrun.common.schemas.AuthorizationAction.update ) verify_feature_set_exists(featureset) @@ -775,7 +809,7 @@ def deploy_ingestion_service( name=featureset.metadata.name, ) - targets_to_ingest = targets or featureset.spec.targets or get_default_targets() + targets_to_ingest = targets or featureset.spec.targets 
targets_to_ingest = copy.deepcopy(targets_to_ingest) featureset.update_targets_for_ingest(targets_to_ingest) @@ -837,7 +871,11 @@ def _ingest_with_spark( f"{featureset.metadata.project}-{featureset.metadata.name}" ) - spark = pyspark.sql.SparkSession.builder.appName(session_name).getOrCreate() + spark = ( + pyspark.sql.SparkSession.builder.appName(session_name) + .config("spark.sql.session.timeZone", "UTC") + .getOrCreate() + ) created_spark_context = True timestamp_key = featureset.spec.timestamp_key @@ -848,7 +886,6 @@ def _ingest_with_spark( df = source else: df = source.to_spark_df(spark, time_field=timestamp_key) - df = source.filter_df_start_end_time(df, timestamp_key) if featureset.spec.graph and featureset.spec.graph.steps: df = run_spark_graph(df, featureset, namespace, spark) @@ -868,14 +905,6 @@ def _ingest_with_spark( target.set_resource(featureset) if featureset.spec.passthrough and target.is_offline: continue - if target.path and urlparse(target.path).scheme == "": - if mlrun_context: - mlrun_context.logger.error( - "Paths for spark ingest must contain schema, i.e v3io, s3, az" - ) - raise mlrun.errors.MLRunInvalidArgumentError( - "Paths for spark ingest must contain schema, i.e v3io, s3, az" - ) spark_options = target.get_spark_options( key_columns, timestamp_key, overwrite ) @@ -962,11 +991,15 @@ def _infer_from_static_df( ): """infer feature-set schema & stats from static dataframe (without pipeline)""" if hasattr(df, "to_dataframe"): + if hasattr(df, "time_field"): + time_field = df.time_field or featureset.spec.timestamp_key + else: + time_field = featureset.spec.timestamp_key if df.is_iterator(): # todo: describe over multiple chunks - df = next(df.to_dataframe()) + df = next(df.to_dataframe(time_field=time_field)) else: - df = df.to_dataframe() + df = df.to_dataframe(time_field=time_field) inferer = get_infer_interface(df) if InferOptions.get_common_options(options, InferOptions.schema()): featureset.spec.timestamp_key = inferer.infer_schema( 
diff --git a/mlrun/feature_store/common.py b/mlrun/feature_store/common.py index 8198e217f6bc..77c1e756a72a 100644 --- a/mlrun/feature_store/common.py +++ b/mlrun/feature_store/common.py @@ -16,11 +16,11 @@ import mlrun import mlrun.errors -from mlrun.api.schemas import AuthorizationVerificationInput +from mlrun.common.schemas import AuthorizationVerificationInput from mlrun.runtimes import BaseRuntime from mlrun.runtimes.function_reference import FunctionReference from mlrun.runtimes.utils import enrich_function_from_dict -from mlrun.utils import StorePrefix, logger, mlconf, parse_versioned_object_uri +from mlrun.utils import StorePrefix, logger, parse_versioned_object_uri from ..config import config @@ -86,13 +86,13 @@ def get_feature_set_by_uri(uri, project=None): db = mlrun.get_run_db() project, name, tag, uid = parse_feature_set_uri(uri, project) resource = ( - mlrun.api.schemas.AuthorizationResourceTypes.feature_set.to_resource_string( + mlrun.common.schemas.AuthorizationResourceTypes.feature_set.to_resource_string( project, "feature-set" ) ) auth_input = AuthorizationVerificationInput( - resource=resource, action=mlrun.api.schemas.AuthorizationAction.read + resource=resource, action=mlrun.common.schemas.AuthorizationAction.read ) db.verify_authorization(auth_input) @@ -115,19 +115,17 @@ def get_feature_vector_by_uri(uri, project=None, update=True): project, name, tag, uid = parse_versioned_object_uri(uri, default_project) - resource = ( - mlrun.api.schemas.AuthorizationResourceTypes.feature_vector.to_resource_string( - project, "feature-vector" - ) + resource = mlrun.common.schemas.AuthorizationResourceTypes.feature_vector.to_resource_string( + project, "feature-vector" ) if update: auth_input = AuthorizationVerificationInput( - resource=resource, action=mlrun.api.schemas.AuthorizationAction.update + resource=resource, action=mlrun.common.schemas.AuthorizationAction.update ) else: auth_input = AuthorizationVerificationInput( - resource=resource, 
action=mlrun.api.schemas.AuthorizationAction.read + resource=resource, action=mlrun.common.schemas.AuthorizationAction.read ) db.verify_authorization(auth_input) @@ -136,12 +134,12 @@ def get_feature_vector_by_uri(uri, project=None, update=True): def verify_feature_set_permissions( - feature_set, action: mlrun.api.schemas.AuthorizationAction + feature_set, action: mlrun.common.schemas.AuthorizationAction ): project, _, _, _ = parse_feature_set_uri(feature_set.uri) resource = ( - mlrun.api.schemas.AuthorizationResourceTypes.feature_set.to_resource_string( + mlrun.common.schemas.AuthorizationResourceTypes.feature_set.to_resource_string( project, "feature-set" ) ) @@ -164,14 +162,12 @@ def verify_feature_set_exists(feature_set): def verify_feature_vector_permissions( - feature_vector, action: mlrun.api.schemas.AuthorizationAction + feature_vector, action: mlrun.common.schemas.AuthorizationAction ): - project = feature_vector._metadata.project or mlconf.default_project + project = feature_vector._metadata.project or config.default_project - resource = ( - mlrun.api.schemas.AuthorizationResourceTypes.feature_vector.to_resource_string( - project, "feature-vector" - ) + resource = mlrun.common.schemas.AuthorizationResourceTypes.feature_vector.to_resource_string( + project, "feature-vector" ) db = mlrun.get_run_db() @@ -218,7 +214,7 @@ def __init__( config = RunConfig("mycode.py", image="mlrun/mlrun", requirements=["spacy"]) # config for using function object - function = mlrun.import_function("hub://some_function") + function = mlrun.import_function("hub://some-function") config = RunConfig(function) :param function: this can be function uri or function object or path to function code (.py/.ipynb) diff --git a/mlrun/feature_store/feature_set.py b/mlrun/feature_store/feature_set.py index e19eaecd30c1..dd0a96f20f89 100644 --- a/mlrun/feature_store/feature_set.py +++ b/mlrun/feature_store/feature_set.py @@ -19,7 +19,7 @@ from storey import EmitEveryEvent, EmitPolicy import 
mlrun -import mlrun.api.schemas +import mlrun.common.schemas from ..config import config as mlconf from ..datastore import get_store_uri @@ -131,6 +131,7 @@ def __init__( self.engine = engine self.output_path = output_path or mlconf.artifact_path self.passthrough = passthrough + self.with_default_targets = True @property def entities(self) -> List[Entity]: @@ -185,7 +186,8 @@ def engine(self) -> str: @engine.setter def engine(self, engine: str): engine_list = ["pandas", "spark", "storey"] - if engine and engine not in engine_list: + engine = engine if engine else "storey" + if engine not in engine_list: raise mlrun.errors.MLRunInvalidArgumentError( f"engine must be one of {','.join(engine_list)}" ) @@ -316,7 +318,7 @@ def emit_policy_to_dict(policy: EmitPolicy): class FeatureSet(ModelObj): """Feature set object, defines a set of features and their data pipeline""" - kind = mlrun.api.schemas.ObjectKind.feature_set.value + kind = mlrun.common.schemas.ObjectKind.feature_set.value _dict_fields = ["kind", "metadata", "spec", "status"] def __init__( @@ -375,6 +377,7 @@ def __init__( self.status = None self._last_state = "" self._aggregations = {} + self.set_targets() @property def spec(self) -> FeatureSetSpec: @@ -473,10 +476,25 @@ def set_targets( ) targets = targets or [] if with_defaults: + self.spec.with_default_targets = True targets.extend(get_default_targets()) + else: + self.spec.with_default_targets = False - validate_target_list(targets=targets) + self.spec.targets = [] + self.__set_targets_add_targets_helper(targets) + + if default_final_step: + self.spec.graph.final_step = default_final_step + + def __set_targets_add_targets_helper(self, targets): + """ + Add the desired target list + :param targets: list of target type names ('csv', 'nosql', ..) or target objects + CSVTarget(), ParquetTarget(), NoSqlTarget(), StreamTarget(), .. 
+ """ + validate_target_list(targets=targets) for target in targets: kind = target.kind if hasattr(target, "kind") else target if kind not in TargetTypes.all(): @@ -488,8 +506,6 @@ def set_targets( target, name=str(target), partitioned=(target == "parquet") ) self.spec.targets.update(target) - if default_final_step: - self.spec.graph.final_step = default_final_step def validate_steps(self, namespace): if not self.spec: @@ -525,7 +541,7 @@ def purge_targets(self, target_names: List[str] = None, silent: bool = False): :param silent: Fail silently if target doesn't exist in featureset status""" verify_feature_set_permissions( - self, mlrun.api.schemas.AuthorizationAction.delete + self, mlrun.common.schemas.AuthorizationAction.delete ) purge_targets = self._reload_and_get_status_targets( @@ -926,7 +942,17 @@ def to_dataframe( raise mlrun.errors.MLRunNotFoundError( "passthrough feature set {self.metadata.name} with no source" ) - return self.spec.source.to_dataframe() + df = self.spec.source.to_dataframe( + columns=columns, + start_time=start_time, + end_time=end_time, + time_field=time_column, + **kwargs, + ) + # to_dataframe() can sometimes return an iterator of dataframes instead of one dataframe + if not isinstance(df, pd.DataFrame): + df = pd.concat(df) + return df target = get_offline_target(self, name=target_name) if not target: diff --git a/mlrun/feature_store/feature_vector.py b/mlrun/feature_store/feature_vector.py index bb4b84edda72..c219170d9448 100644 --- a/mlrun/feature_store/feature_vector.py +++ b/mlrun/feature_store/feature_vector.py @@ -154,7 +154,7 @@ def features(self, features: List[Feature]): class FeatureVector(ModelObj): """Feature vector, specify selected features, their metadata and material views""" - kind = mlrun.api.schemas.ObjectKind.feature_vector.value + kind = mlrun.common.schemas.ObjectKind.feature_vector.value _dict_fields = ["kind", "metadata", "spec", "status"] def __init__( diff --git a/mlrun/feature_store/ingestion.py 
b/mlrun/feature_store/ingestion.py index d2649c395103..d07b1691ac75 100644 --- a/mlrun/feature_store/ingestion.py +++ b/mlrun/feature_store/ingestion.py @@ -89,7 +89,7 @@ def init_featureset_graph( key_fields = entity_columns if entity_columns else None sizes = [0] * len(targets) - data_result = None + result_dfs = [] total_rows = 0 targets = [get_target_driver(target, featureset) for target in targets] if featureset.spec.passthrough: @@ -100,11 +100,11 @@ def init_featureset_graph( # set the entities to be the indexes of the df event.body = entities_to_index(featureset, event.body) - data = server.run(event, get_body=True) - if data is not None: + df = server.run(event, get_body=True) + if df is not None: for i, target in enumerate(targets): size = target.write_dataframe( - data, + df, key_column=key_fields, timestamp_key=featureset.spec.timestamp_key, chunk_id=chunk_id, @@ -112,21 +112,18 @@ def init_featureset_graph( if size: sizes[i] += size chunk_id += 1 - if data_result is None: - # in case of multiple chunks only return the first chunk (last may be too small) - data_result = data - total_rows += data.shape[0] + result_dfs.append(df) + total_rows += df.shape[0] if rows_limit and total_rows >= rows_limit: break - # todo: fire termination event if iterator - for i, target in enumerate(targets): target_status = target.update_resource_status("ready", size=sizes[i]) if verbose: logger.info(f"wrote target: {target_status}") - return data_result + result_df = pd.concat(result_dfs) + return result_df.head(rows_limit) def featureset_initializer(server): diff --git a/mlrun/feature_store/retrieval/base.py b/mlrun/feature_store/retrieval/base.py index ec0459f7524f..e0cb12abce96 100644 --- a/mlrun/feature_store/retrieval/base.py +++ b/mlrun/feature_store/retrieval/base.py @@ -16,12 +16,15 @@ import typing from datetime import datetime +import dask.dataframe as dd +import pandas as pd + import mlrun from mlrun.datastore.targets import CSVTarget, ParquetTarget from 
mlrun.feature_store.feature_set import FeatureSet from mlrun.feature_store.feature_vector import Feature -from ...utils import logger +from ...utils import logger, str_to_timestamp from ..feature_vector import OfflineVectorResponse @@ -73,23 +76,19 @@ def start( drop_columns=None, start_time=None, end_time=None, + timestamp_for_filtering=None, with_indexes=None, update_stats=None, query=None, - join_type="inner", order_by=None, ): self._target = target - self._join_type = join_type # calculate the index columns and columns we need to drop self._drop_columns = drop_columns or self._drop_columns if self.vector.spec.with_indexes or with_indexes: self._drop_indexes = False - if entity_timestamp_column and self._drop_indexes: - self._append_drop_column(entity_timestamp_column) - # retrieve the feature set objects/fields needed for the vector feature_set_objects, feature_set_fields = self.vector.parse_features( update_stats=update_stats @@ -103,12 +102,21 @@ def start( # update the feature vector objects with refreshed stats self.vector.save() + if self._drop_indexes and entity_timestamp_column: + self._append_drop_column(entity_timestamp_column) + for feature_set in feature_set_objects.values(): - if not entity_timestamp_column and self._drop_indexes: + if self._drop_indexes: self._append_drop_column(feature_set.spec.timestamp_key) for key in feature_set.spec.entities.keys(): self._append_index(key) + start_time = str_to_timestamp(start_time) + end_time = str_to_timestamp(end_time) + if start_time and not end_time: + # if end_time is not specified set it to now() + end_time = pd.Timestamp.now() + return self._generate_vector( entity_rows, entity_timestamp_column, @@ -116,6 +124,7 @@ def start( feature_set_fields=feature_set_fields, start_time=start_time, end_time=end_time, + timestamp_for_filtering=timestamp_for_filtering, query=query, order_by=order_by, ) @@ -168,6 +177,7 @@ def _generate_vector( feature_set_fields, start_time=None, end_time=None, + 
timestamp_for_filtering=None, query=None, order_by=None, ): @@ -185,6 +195,7 @@ def _generate_vector( feature_set_objects, feature_set_fields ) + filtered = False for node in fs_link_list: name = node.name feature_set = feature_set_objects[name] @@ -198,24 +209,44 @@ def _generate_vector( self._append_drop_column(column) column_names.append(column) + if isinstance(timestamp_for_filtering, dict): + time_column = timestamp_for_filtering.get( + name, feature_set.spec.timestamp_key + ) + elif isinstance(timestamp_for_filtering, str): + time_column = timestamp_for_filtering + else: + time_column = feature_set.spec.timestamp_key + + if time_column != feature_set.spec.timestamp_key and time_column not in [ + feature.name for feature in feature_set.spec.features + ]: + raise mlrun.errors.MLRunInvalidArgumentError( + f"Feature set `{name}` " + f"does not have a column named `{time_column}` to filter on." + ) + + if self._drop_indexes: + self._append_drop_column(time_column) + if (start_time or end_time) and time_column: + filtered = True + df = self._get_engine_df( feature_set, name, column_names, - start_time, - end_time, - entity_timestamp_column, + start_time if time_column else None, + end_time if time_column else None, + time_column, ) column_names += node.data["save_index"] node.data["save_cols"] += node.data["save_index"] + fs_entities_and_timestamp = list(feature_set.spec.entities.keys()) if feature_set.spec.timestamp_key: - entity_timestamp_column_list = [feature_set.spec.timestamp_key] - column_names += entity_timestamp_column_list - node.data["save_cols"] += entity_timestamp_column_list - if not entity_timestamp_column: - # if not entity_timestamp_column the firs `FeatureSet` will define it - entity_timestamp_column = feature_set.spec.timestamp_key + column_names.append(feature_set.spec.timestamp_key) + node.data["save_cols"].append(feature_set.spec.timestamp_key) + fs_entities_and_timestamp.append(feature_set.spec.timestamp_key) # rename columns to be unique for 
each feature set and select if needed rename_col_dict = { @@ -223,9 +254,10 @@ def _generate_vector( for column in column_names if column not in node.data["save_cols"] } - fs_entities = list(feature_set.spec.entities.keys()) df_temp = self._rename_columns_and_select( - df, rename_col_dict, columns=list(set(column_names + fs_entities)) + df, + rename_col_dict, + columns=list(set(column_names + fs_entities_and_timestamp)), ) if df_temp is not None: @@ -240,7 +272,7 @@ def _generate_vector( # update alias according to the unique column name new_columns = [] if not self._drop_indexes: - new_columns.extend([(ind, ind) for ind in fs_entities]) + new_columns.extend([(ind, ind) for ind in fs_entities_and_timestamp]) for column, alias in columns: if column in rename_col_dict: new_columns.append((rename_col_dict[column], alias or column)) @@ -248,6 +280,12 @@ def _generate_vector( new_columns.append((column, alias)) self._update_alias(dictionary={name: alias for name, alias in new_columns}) + # None of the feature sets was filtered as required + if not filtered and (start_time or end_time): + raise mlrun.errors.MLRunRuntimeError( + "start_time and end_time can only be provided in conjunction with " + "a timestamp column, or when the at least one feature_set has a timestamp key" + ) # convert pandas entity_rows to spark DF if needed if ( entity_rows is not None @@ -255,22 +293,30 @@ def _generate_vector( and self.engine == "spark" ): entity_rows = self.spark.createDataFrame(entity_rows) + elif ( + entity_rows is not None + and not hasattr(entity_rows, "dask") + and self.engine == "dask" + ): + entity_rows = dd.from_pandas( + entity_rows, npartitions=len(entity_rows.columns) + ) # join the feature data frames - self.merge( + result_timestamp = self.merge( entity_df=entity_rows, - entity_timestamp_column=entity_timestamp_column, + entity_timestamp_column=entity_timestamp_column + if entity_rows is not None + else None, featuresets=feature_sets, featureset_dfs=dfs, keys=keys, ) 
all_columns = None - if not self._drop_indexes and entity_timestamp_column: - if entity_timestamp_column not in self._alias.values(): - self._update_alias( - key=entity_timestamp_column, val=entity_timestamp_column - ) + if not self._drop_indexes and result_timestamp: + if result_timestamp not in self._alias.values(): + self._update_alias(key=result_timestamp, val=result_timestamp) all_columns = list(self._alias.keys()) df_temp = self._rename_columns_and_select( @@ -343,13 +389,8 @@ def merge( keys[0][0] = keys[0][1] = list(featuresets[0].spec.entities.keys()) for featureset, featureset_df, lr_key in zip(featuresets, featureset_dfs, keys): - if featureset.spec.timestamp_key: + if featureset.spec.timestamp_key and entity_timestamp_column: merge_func = self._asof_join - if self._join_type != "inner": - logger.warn( - "Merge all the features with as_of_join and don't " - "take into account the join_type that was given" - ) else: merge_func = self._join @@ -361,6 +402,9 @@ def merge( lr_key[0], lr_key[1], ) + entity_timestamp_column = ( + entity_timestamp_column or featureset.spec.timestamp_key + ) # unpersist as required by the implementation (e.g. 
spark) and delete references # to dataframe to allow for GC to free up the memory (local, dask) @@ -368,6 +412,7 @@ def merge( del featureset_df self._result_df = merged_df + return entity_timestamp_column @abc.abstractmethod def _asof_join( @@ -653,7 +698,7 @@ def _get_engine_df( column_names: typing.List[str] = None, start_time: typing.Union[str, datetime] = None, end_time: typing.Union[str, datetime] = None, - entity_timestamp_column: str = None, + time_column: typing.Optional[str] = None, ): """ Return the feature_set data frame according to the args @@ -663,7 +708,7 @@ def _get_engine_df( :param column_names: list of columns to select (if not all) :param start_time: filter by start time :param end_time: filter by end time - :param entity_timestamp_column: specify the time column name in the file + :param time_column: specify the time column name to filter on :return: Data frame of the current engine """ diff --git a/mlrun/feature_store/retrieval/dask_merger.py b/mlrun/feature_store/retrieval/dask_merger.py index 585782b20669..951a660665e1 100644 --- a/mlrun/feature_store/retrieval/dask_merger.py +++ b/mlrun/feature_store/retrieval/dask_merger.py @@ -45,12 +45,21 @@ def _asof_join( left_keys: list, right_keys: list, ): + def sort_partition(partition, timestamp): + return partition.sort_values(timestamp) + + entity_df = entity_df.map_partitions( + sort_partition, timestamp=entity_timestamp_column + ) + featureset_df = featureset_df.map_partitions( + sort_partition, timestamp=featureset.spec.timestamp_key + ) merged_df = merge_asof( entity_df, featureset_df, left_on=entity_timestamp_column, - right_on=entity_timestamp_column, + right_on=featureset.spec.timestamp_key, left_by=left_keys or None, right_by=right_keys or None, suffixes=("", f"_{featureset.metadata.name}_"), @@ -117,14 +126,14 @@ def _get_engine_df( column_names=None, start_time=None, end_time=None, - entity_timestamp_column=None, + time_column=None, ): df = feature_set.to_dataframe( 
columns=column_names, df_module=dd, start_time=start_time, end_time=end_time, - time_column=entity_timestamp_column, + time_column=time_column, index=False, ) diff --git a/mlrun/feature_store/retrieval/job.py b/mlrun/feature_store/retrieval/job.py index 63a48e14234c..dfa3f7505547 100644 --- a/mlrun/feature_store/retrieval/job.py +++ b/mlrun/feature_store/retrieval/job.py @@ -33,13 +33,15 @@ def run_merge_job( engine_args: dict, spark_service: str = None, entity_rows=None, - timestamp_column=None, + entity_timestamp_column=None, run_config=None, drop_columns=None, with_indexes=None, query=None, - join_type="inner", order_by=None, + start_time=None, + end_time=None, + timestamp_for_filtering=None, ): name = vector.metadata.name if not target or not hasattr(target, "to_dict"): @@ -93,21 +95,27 @@ def set_default_resources(resources, setter_function): set_default_resources( function.spec.executor_resources, function.with_executor_requests ) + if start_time and not isinstance(start_time, str): + start_time = start_time.isoformat() + if end_time and not isinstance(end_time, str): + end_time = end_time.isoformat() task = new_task( name=name, params={ "vector_uri": vector.uri, "target": target.to_dict(), - "timestamp_column": timestamp_column, + "entity_timestamp_column": entity_timestamp_column, "drop_columns": drop_columns, "with_indexes": with_indexes, "query": query, - "join_type": join_type, "order_by": order_by, + "start_time": start_time, + "end_time": end_time, + "timestamp_for_filtering": timestamp_for_filtering, "engine_args": engine_args, }, - inputs={"entity_rows": entity_rows}, + inputs={"entity_rows": entity_rows} if entity_rows is not None else {}, ) task.spec.secret_sources = run_config.secret_sources task.set_label("job-type", "feature-merge").set_label("feature-vector", vector.uri) @@ -174,18 +182,18 @@ def target_uri(self): import mlrun.feature_store.retrieval from mlrun.datastore.targets import get_target_driver def merge_handler(context, vector_uri, 
target, entity_rows=None, - timestamp_column=None, drop_columns=None, with_indexes=None, query=None, join_type='inner', - engine_args=None, order_by=None): + entity_timestamp_column=None, drop_columns=None, with_indexes=None, query=None, + engine_args=None, order_by=None, start_time=None, end_time=None, timestamp_for_filtering=None): vector = context.get_store_resource(vector_uri) store_target = get_target_driver(target, vector) - entity_timestamp_column = timestamp_column or vector.spec.timestamp_field if entity_rows: entity_rows = entity_rows.as_df() context.logger.info(f"starting vector merge task to {vector.uri}") merger = mlrun.feature_store.retrieval.{{{engine}}}(vector, **(engine_args or {})) merger.start(entity_rows, entity_timestamp_column, store_target, drop_columns, with_indexes=with_indexes, - query=query, join_type=join_type, order_by=order_by) + query=query, order_by=order_by, start_time=start_time, end_time=end_time, + timestamp_for_filtering=timestamp_for_filtering) target = vector.status.targets[store_target.name].to_dict() context.log_result('feature_vector', vector.uri) diff --git a/mlrun/feature_store/retrieval/local_merger.py b/mlrun/feature_store/retrieval/local_merger.py index 347b0c9dd2b3..c98d977121d5 100644 --- a/mlrun/feature_store/retrieval/local_merger.py +++ b/mlrun/feature_store/retrieval/local_merger.py @@ -47,7 +47,7 @@ def _asof_join( featureset_df[featureset.spec.timestamp_key] ) entity_df.sort_values(by=entity_timestamp_column, inplace=True) - featureset_df.sort_values(by=entity_timestamp_column, inplace=True) + featureset_df.sort_values(by=featureset.spec.timestamp_key, inplace=True) merged_df = pd.merge_asof( entity_df, @@ -62,7 +62,6 @@ def _asof_join( for col in merged_df.columns: if re.findall(f"_{featureset.metadata.name}_$", col): self._append_drop_column(col) - # Undo indexing tricks for asof merge # to return the correct indexes and not # overload `index` columns @@ -109,25 +108,14 @@ def _get_engine_df( 
column_names=None, start_time=None, end_time=None, - entity_timestamp_column=None, + time_column=None, ): - # handling case where there are multiple feature sets and user creates vector where entity_timestamp_ - # column is from a specific feature set (can't be entity timestamp) - if ( - entity_timestamp_column in column_names - or feature_set.spec.timestamp_key == entity_timestamp_column - ): - df = feature_set.to_dataframe( - columns=column_names, - start_time=start_time, - end_time=end_time, - time_column=entity_timestamp_column, - ) - else: - df = feature_set.to_dataframe( - columns=column_names, - time_column=entity_timestamp_column, - ) + df = feature_set.to_dataframe( + columns=column_names, + start_time=start_time, + end_time=end_time, + time_column=time_column, + ) if df.index.names[0]: df.reset_index(inplace=True) return df diff --git a/mlrun/feature_store/retrieval/online.py b/mlrun/feature_store/retrieval/online.py index 601361ae8488..d873ec862fb3 100644 --- a/mlrun/feature_store/retrieval/online.py +++ b/mlrun/feature_store/retrieval/online.py @@ -69,6 +69,10 @@ def init_feature_vector_graph(vector, query_options, update_stats=False): feature_set_objects, feature_set_fields = vector.parse_features( offline=False, update_stats=update_stats ) + if not feature_set_fields: + raise mlrun.errors.MLRunRuntimeError( + f"No features found for feature vector '{vector.metadata.name}'" + ) graph = _build_feature_vector_graph( vector, feature_set_fields, feature_set_objects, query_options ) diff --git a/mlrun/feature_store/retrieval/spark_merger.py b/mlrun/feature_store/retrieval/spark_merger.py index 562d4c8bcc0f..d6c8af934c9b 100644 --- a/mlrun/feature_store/retrieval/spark_merger.py +++ b/mlrun/feature_store/retrieval/spark_merger.py @@ -67,7 +67,7 @@ def _asof_join( entity_with_id = entity_df.withColumn("_row_nr", monotonically_increasing_id()) rename_right_keys = {} - for key in right_keys + [entity_timestamp_column]: + for key in right_keys + 
[featureset.spec.timestamp_key]: if key in entity_df.columns: rename_right_keys[key] = f"ft__{key}" # get columns for projection @@ -77,13 +77,14 @@ def _asof_join( ] aliased_featureset_df = featureset_df.select(projection) + right_timestamp = rename_right_keys.get( + featureset.spec.timestamp_key, featureset.spec.timestamp_key + ) # set join conditions join_cond = ( entity_with_id[entity_timestamp_column] - >= aliased_featureset_df[ - rename_right_keys.get(entity_timestamp_column, entity_timestamp_column) - ] + >= aliased_featureset_df[right_timestamp] ) # join based on entities @@ -98,13 +99,13 @@ def _asof_join( ) window = Window.partitionBy("_row_nr").orderBy( - col(f"ft__{entity_timestamp_column}").desc(), + col(right_timestamp).desc(), ) filter_most_recent_feature_timestamp = conditional_join.withColumn( "_rank", row_number().over(window) ).filter(col("_rank") == 1) - for key in right_keys + [entity_timestamp_column]: + for key in right_keys + [featureset.spec.timestamp_key]: if key in entity_df.columns + [entity_timestamp_column]: filter_most_recent_feature_timestamp = ( filter_most_recent_feature_timestamp.drop( @@ -194,7 +195,7 @@ def _get_engine_df( column_names=None, start_time=None, end_time=None, - entity_timestamp_column=None, + time_column=None, ): if feature_set.spec.passthrough: if not feature_set.spec.source: @@ -215,31 +216,27 @@ def _get_engine_df( # handling case where there are multiple feature sets and user creates vector where # entity_timestamp_column is from a specific feature set (can't be entity timestamp) source_driver = mlrun.datastore.sources.source_kind_to_driver[source_kind] + + source = source_driver( + name=self.vector.metadata.name, + path=source_path, + time_field=time_column, + start_time=start_time, + end_time=end_time, + ) + + columns = column_names + [ent.name for ent in feature_set.spec.entities] if ( - entity_timestamp_column in column_names - or feature_set.spec.timestamp_key == entity_timestamp_column + 
feature_set.spec.timestamp_key + and feature_set.spec.timestamp_key not in columns ): - source = source_driver( - name=self.vector.metadata.name, - path=source_path, - time_field=entity_timestamp_column, - start_time=start_time, - end_time=end_time, - ) - else: - source = source_driver( - name=self.vector.metadata.name, - path=source_path, - time_field=entity_timestamp_column, - ) - - if not entity_timestamp_column: - entity_timestamp_column = feature_set.spec.timestamp_key - # add the index/key to selected columns - timestamp_key = feature_set.spec.timestamp_key + columns.append(feature_set.spec.timestamp_key) return source.to_spark_df( - self.spark, named_view=self.named_view, time_field=timestamp_key + self.spark, + named_view=self.named_view, + time_field=time_column, + columns=columns, ) def _rename_columns_and_select( diff --git a/mlrun/feature_store/steps.py b/mlrun/feature_store/steps.py index 4262caf6c7a3..140a84229fd1 100644 --- a/mlrun/feature_store/steps.py +++ b/mlrun/feature_store/steps.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import math import re import uuid import warnings @@ -40,7 +41,9 @@ def get_engine(first_event): class MLRunStep(MapClass): def __init__(self, **kwargs): """Abstract class for mlrun step. - Can be used in pandas/storey/spark feature set ingestion""" + Can be used in pandas/storey/spark feature set ingestion. Extend this class and implement the relevant + `_do_XXX` methods to support the required execution engines. + """ super().__init__(**kwargs) self._engine_to_do_method = { "pandas": self._do_pandas, @@ -51,23 +54,41 @@ def __init__(self, **kwargs): def do(self, event): """ This method defines the do method of this class according to the first event type. + + .. warning:: + When extending this class, do not override this method; only override the `_do_XXX` methods. 
""" engine = get_engine(event) self.do = self._engine_to_do_method.get(engine, None) if self.do is None: - raise mlrun.errors.InvalidArgummentError( + raise mlrun.errors.MLRunInvalidArgumentError( f"Unrecognized engine: {engine}. Available engines are: pandas, spark and storey" ) return self.do(event) def _do_pandas(self, event): + """ + The execution method for pandas engine. + + :param event: Incoming event, a `pandas.DataFrame` object. + """ raise NotImplementedError def _do_storey(self, event): + """ + The execution method for storey engine. + + :param event: Incoming event, a dictionary or `storey.Event` object, depending on the `full_event` value. + """ raise NotImplementedError def _do_spark(self, event): + """ + The execution method for spark engine. + + :param event: Incoming event, a `pyspark.sql.DataFrame` object. + """ raise NotImplementedError @@ -136,7 +157,7 @@ class MapValues(StepToDict, MLRunStep): def __init__( self, - mapping: Dict[str, Dict[str, Any]], + mapping: Dict[str, Dict[Union[str, int, bool], Any]], with_original_features: bool = False, suffix: str = "mapped", **kwargs, @@ -226,34 +247,130 @@ def _do_pandas(self, event): def _do_spark(self, event): from itertools import chain - from pyspark.sql.functions import col, create_map, lit, when + from pyspark.sql.functions import col, create_map, isnan, isnull, lit, when + from pyspark.sql.types import DecimalType, DoubleType, FloatType + from pyspark.sql.utils import AnalysisException + df = event + source_column_names = df.columns for column, column_map in self.mapping.items(): new_column_name = self._get_feature_name(column) - if "ranges" not in column_map: + if not self.get_ranges_key() in column_map: + if column not in source_column_names: + continue mapping_expr = create_map([lit(x) for x in chain(*column_map.items())]) - event = event.withColumn( - new_column_name, mapping_expr.getItem(col(column)) - ) + try: + df = df.withColumn( + new_column_name, + when( + 
col(column).isin(list(column_map.keys())), + mapping_expr.getItem(col(column)), + ).otherwise(col(column)), + ) + # if failed to use otherwise it is probably because the new column has different type + # then the original column. + # we will try to replace the values without using 'otherwise'. + except AnalysisException: + df = df.withColumn( + new_column_name, mapping_expr.getItem(col(column)) + ) + col_type = df.schema[column].dataType + new_col_type = df.schema[new_column_name].dataType + # in order to avoid exception at isna on non-decimal/float columns - + # we need to check their types before filtering. + if isinstance(col_type, (FloatType, DoubleType, DecimalType)): + column_filter = (~isnull(col(column))) & (~isnan(col(column))) + else: + column_filter = ~isnull(col(column)) + if isinstance(new_col_type, (FloatType, DoubleType, DecimalType)): + new_column_filter = isnull(col(new_column_name)) | isnan( + col(new_column_name) + ) + else: + # we need to check that every value replaced if we changed column type - except None or NaN. 
+ new_column_filter = isnull(col(new_column_name)) + mapping_to_null = [ + k + for k, v in column_map.items() + if v is None + or ( + isinstance(v, (float, np.float64, np.float32, np.float16)) + and math.isnan(v) + ) + ] + turned_to_none_values = df.filter( + column_filter & new_column_filter + ).filter(~col(column).isin(mapping_to_null)) + + if len(turned_to_none_values.head(1)) > 0: + raise mlrun.errors.MLRunInvalidArgumentError( + f"MapValues - mapping that changes column type must change all values accordingly," + f" which is not the case for column '{column}'" + ) else: for val, val_range in column_map["ranges"].items(): min_val = val_range[0] if val_range[0] != "-inf" else -np.inf max_val = val_range[1] if val_range[1] != "inf" else np.inf otherwise = "" - if new_column_name in event.columns: - otherwise = event[new_column_name] - event = event.withColumn( + if new_column_name in df.columns: + otherwise = df[new_column_name] + df = df.withColumn( new_column_name, when( - (event[column] < max_val) & (event[column] >= min_val), + (df[column] < max_val) & (df[column] >= min_val), lit(val), ).otherwise(otherwise), ) if not self.with_original_features: - event = event.select(*self.mapping.keys()) + df = df.select(*self.mapping.keys()) - return event + return df + + @classmethod + def validate_args(cls, feature_set, **kwargs): + mapping = kwargs.get("mapping", []) + for column, column_map in mapping.items(): + if not cls.get_ranges_key() in column_map: + types = set( + type(val) + for val in column_map.values() + if type(val) is not None + and not ( + isinstance(val, (float, np.float64, np.float32, np.float16)) + and math.isnan(val) + ) + ) + else: + if len(column_map) > 1: + raise mlrun.errors.MLRunInvalidArgumentError( + f"MapValues - mapping values of the same column can not combine ranges and " + f"single replacement, which is the case for column '{column}'" + ) + ranges_dict = column_map[cls.get_ranges_key()] + types = set() + for ranges_mapping_values in 
ranges_dict.values(): + range_types = set( + type(val) + for val in ranges_mapping_values + if type(val) is not None + and val != "-inf" + and val != "inf" + and not ( + isinstance(val, (float, np.float64, np.float32, np.float16)) + and math.isnan(val) + ) + ) + types = types.union(range_types) + if len(types) > 1: + raise mlrun.errors.MLRunInvalidArgumentError( + f"MapValues - mapping values of the same column must be in the" + f" same type, which was not the case for Column '{column}'" + ) + + @staticmethod + def get_ranges_key(): + return "ranges" class Imputer(StepToDict, MLRunStep): diff --git a/mlrun/features.py b/mlrun/features.py index ba7eb5584a0e..60615cde1163 100644 --- a/mlrun/features.py +++ b/mlrun/features.py @@ -105,11 +105,12 @@ def __init__( :param labels: a set of key/value labels (tags) """ self.name = name or "" - self.value_type = ( - python_type_to_value_type(value_type) - if value_type is not None - else ValueType.STRING - ) + if isinstance(value_type, ValueType): + self.value_type = value_type + elif value_type is not None: + self.value_type = python_type_to_value_type(value_type) + else: + self.value_type = ValueType.STRING self.dims = dims self.description = description self.default = default diff --git a/mlrun/frameworks/_ml_common/pkl_model_server.py b/mlrun/frameworks/_ml_common/pkl_model_server.py index c726948910ab..7ee09a43a4e6 100644 --- a/mlrun/frameworks/_ml_common/pkl_model_server.py +++ b/mlrun/frameworks/_ml_common/pkl_model_server.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +from typing import Any, Dict + import numpy as np import pandas as pd from cloudpickle import load @@ -56,3 +58,13 @@ def predict(self, request: dict) -> list: y_pred: np.ndarray = self.model.predict(x) return y_pred.tolist() + + def explain(self, request: Dict[str, Any]) -> str: + """ + Returns a string listing the model that is being served in this serving function and the function name. + + :param request: A given request. + + :return: Explanation string. + """ + return f"A model server named '{self.name}'" diff --git a/mlrun/k8s_utils.py b/mlrun/k8s_utils.py index 19130716498e..78966a5f8aef 100644 --- a/mlrun/k8s_utils.py +++ b/mlrun/k8s_utils.py @@ -11,782 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import base64 -import hashlib -import time import typing -from datetime import datetime -from sys import stdout import kubernetes.client -from kubernetes import client, config -from kubernetes.client.rest import ApiException -import mlrun.api.schemas +import mlrun.common.schemas import mlrun.errors from .config import config as mlconfig -from .errors import err_to_str -from .platforms.iguazio import v3io_to_vol -from .utils import logger -_k8s = None +_running_inside_kubernetes_cluster = None -def get_k8s_helper(namespace=None, silent=False, log=False) -> "K8sHelper": - """ - :param silent: set to true if you're calling this function from a code that might run from remotely (outside of a - k8s cluster) - :param log: sometimes we want to avoid logging when executing init_k8s_config - """ - global _k8s - if not _k8s: - _k8s = K8sHelper(namespace, silent=silent, log=log) - return _k8s - - -class SecretTypes: - opaque = "Opaque" - v3io_fuse = "v3io/fuse" - - -class K8sHelper: - def __init__(self, namespace=None, config_file=None, silent=False, log=True): - self.namespace = namespace or mlconfig.namespace - self.config_file 
= config_file - self.running_inside_kubernetes_cluster = False - try: - self._init_k8s_config(log) - self.v1api = client.CoreV1Api() - self.crdapi = client.CustomObjectsApi() - except Exception: - if not silent: - raise - - def resolve_namespace(self, namespace=None): - return namespace or self.namespace - - def _init_k8s_config(self, log=True): - try: - config.load_incluster_config() - self.running_inside_kubernetes_cluster = True - if log: - logger.info("using in-cluster config.") - except Exception: - try: - config.load_kube_config(self.config_file) - if log: - logger.info("using local kubernetes config.") - except Exception: - raise RuntimeError( - "cannot find local kubernetes config file," - " place it in ~/.kube/config or specify it in " - "KUBECONFIG env var" - ) - - def is_running_inside_kubernetes_cluster(self): - return self.running_inside_kubernetes_cluster - - def list_pods(self, namespace=None, selector="", states=None): - try: - resp = self.v1api.list_namespaced_pod( - self.resolve_namespace(namespace), label_selector=selector - ) - except ApiException as exc: - logger.error(f"failed to list pods: {err_to_str(exc)}") - raise exc - - items = [] - for i in resp.items: - if not states or i.status.phase in states: - items.append(i) - return items - - def clean_pods(self, namespace=None, selector="", states=None): - if not selector and not states: - raise ValueError("labels selector or states list must be specified") - items = self.list_pods(namespace, selector, states) - for item in items: - self.delete_pod(item.metadata.name, item.metadata.namespace) - - def create_pod(self, pod, max_retry=3, retry_interval=3): - if "pod" in dir(pod): - pod = pod.pod - pod.metadata.namespace = self.resolve_namespace(pod.metadata.namespace) - - retry_count = 0 - while True: - try: - resp = self.v1api.create_namespaced_pod(pod.metadata.namespace, pod) - except ApiException as exc: - - if retry_count > max_retry: - logger.error( - "failed to create pod after max retries", 
- retry_count=retry_count, - exc=err_to_str(exc), - pod=pod, - ) - raise exc - - logger.error("failed to create pod", exc=err_to_str(exc), pod=pod) - - # known k8s issue, see https://github.com/kubernetes/kubernetes/issues/67761 - if "gke-resource-quotas" in err_to_str(exc): - logger.warning( - "failed to create pod due to gke resource error, " - f"sleeping {retry_interval} seconds and retrying" - ) - retry_count += 1 - time.sleep(retry_interval) - continue - - raise exc - else: - logger.info(f"Pod {resp.metadata.name} created") - return resp.metadata.name, resp.metadata.namespace - - def delete_pod(self, name, namespace=None): - try: - api_response = self.v1api.delete_namespaced_pod( - name, - self.resolve_namespace(namespace), - grace_period_seconds=0, - propagation_policy="Background", - ) - return api_response - except ApiException as exc: - # ignore error if pod is already removed - if exc.status != 404: - logger.error(f"failed to delete pod: {err_to_str(exc)}", pod_name=name) - raise exc - - def get_pod(self, name, namespace=None, raise_on_not_found=False): - try: - api_response = self.v1api.read_namespaced_pod( - name=name, namespace=self.resolve_namespace(namespace) - ) - return api_response - except ApiException as exc: - if exc.status != 404: - logger.error(f"failed to get pod: {err_to_str(exc)}") - raise exc - else: - if raise_on_not_found: - raise mlrun.errors.MLRunNotFoundError(f"Pod not found: {name}") - return None - - def get_pod_status(self, name, namespace=None): - return self.get_pod( - name, namespace, raise_on_not_found=True - ).status.phase.lower() - - def delete_crd(self, name, crd_group, crd_version, crd_plural, namespace=None): - try: - namespace = self.resolve_namespace(namespace) - self.crdapi.delete_namespaced_custom_object( - crd_group, - crd_version, - namespace, - crd_plural, - name, - ) - logger.info( - "Deleted crd object", - crd_name=name, - namespace=namespace, - ) - except ApiException as exc: - - # ignore error if crd is already 
removed - if exc.status != 404: - logger.error( - f"failed to delete crd: {err_to_str(exc)}", - crd_name=name, - crd_group=crd_group, - crd_version=crd_version, - crd_plural=crd_plural, - ) - raise exc - - def logs(self, name, namespace=None): - try: - resp = self.v1api.read_namespaced_pod_log( - name=name, namespace=self.resolve_namespace(namespace) - ) - except ApiException as exc: - logger.error(f"failed to get pod logs: {err_to_str(exc)}") - raise exc - - return resp - - def run_job(self, pod, timeout=600): - pod_name, namespace = self.create_pod(pod) - if not pod_name: - logger.error("failed to create pod") - return "error" - return self.watch(pod_name, namespace, timeout) - - def watch(self, pod_name, namespace=None, timeout=600, writer=None): - namespace = self.resolve_namespace(namespace) - start_time = datetime.now() - while True: - try: - pod = self.get_pod(pod_name, namespace) - if not pod: - return "error" - status = pod.status.phase.lower() - if status in ["running", "completed", "succeeded"]: - print("") - break - if status == "failed": - return "failed" - elapsed_time = (datetime.now() - start_time).seconds - if elapsed_time > timeout: - return "timeout" - time.sleep(2) - stdout.write(".") - if status != "pending": - logger.warning(f"pod state in loop is {status}") - except ApiException as exc: - logger.error(f"failed waiting for pod: {err_to_str(exc)}\n") - return "error" - outputs = self.v1api.read_namespaced_pod_log( - name=pod_name, namespace=namespace, follow=True, _preload_content=False - ) - for out in outputs: - print(out.decode("utf-8"), end="") - if writer: - writer.write(out) - - for i in range(5): - pod_state = self.get_pod(pod_name, namespace).status.phase.lower() - if pod_state != "running": - break - logger.warning("pod still running, waiting 2 sec") - time.sleep(2) - - if pod_state == "failed": - logger.error("pod exited with error") - if writer: - writer.flush() - return pod_state - - def create_cfgmap(self, name, data, namespace="", 
labels=None): - body = client.api_client.V1ConfigMap() - namespace = self.resolve_namespace(namespace) - body.data = data - if name.endswith("*"): - body.metadata = client.V1ObjectMeta( - generate_name=name[:-1], namespace=namespace, labels=labels - ) - else: - body.metadata = client.V1ObjectMeta( - name=name, namespace=namespace, labels=labels - ) +def is_running_inside_kubernetes_cluster(): + global _running_inside_kubernetes_cluster + if _running_inside_kubernetes_cluster is None: try: - resp = self.v1api.create_namespaced_config_map(namespace, body) - except ApiException as exc: - logger.error(f"failed to create configmap: {err_to_str(exc)}") - raise exc - - logger.info(f"ConfigMap {resp.metadata.name} created") - return resp.metadata.name - - def del_cfgmap(self, name, namespace=None): - try: - api_response = self.v1api.delete_namespaced_config_map( - name, - self.resolve_namespace(namespace), - grace_period_seconds=0, - propagation_policy="Background", - ) - - return api_response - except ApiException as exc: - # ignore error if ConfigMap is already removed - if exc.status != 404: - logger.error(f"failed to delete ConfigMap: {err_to_str(exc)}") - raise exc - - def list_cfgmap(self, namespace=None, selector=""): - try: - resp = self.v1api.list_namespaced_config_map( - self.resolve_namespace(namespace), watch=False, label_selector=selector - ) - except ApiException as exc: - logger.error(f"failed to list ConfigMaps: {err_to_str(exc)}") - raise exc - - items = [] - for i in resp.items: - items.append(i) - return items - - def get_logger_pods(self, project, uid, run_kind, namespace=""): - - # As this file is imported in mlrun.runtimes, we sadly cannot have this import in the top level imports - # as that will create an import loop. - # TODO: Fix the import loops already! 
- import mlrun.runtimes - - namespace = self.resolve_namespace(namespace) - mpijob_crd_version = mlrun.runtimes.utils.resolve_mpijob_crd_version( - api_context=True - ) - mpijob_role_label = ( - mlrun.runtimes.constants.MPIJobCRDVersions.role_label_by_version( - mpijob_crd_version - ) - ) - extra_selectors = { - "spark": "spark-role=driver", - "mpijob": f"{mpijob_role_label}=launcher", - } - - # TODO: all mlrun labels are sprinkled in a lot of places - they need to all be defined in a central, - # inclusive place. - selectors = [ - "mlrun/class", - f"mlrun/project={project}", - f"mlrun/uid={uid}", - ] - - # In order to make the `list_pods` request return a lighter and quicker result, we narrow the search for - # the relevant pods using the proper label selector according to the run kind - if run_kind in extra_selectors: - selectors.append(extra_selectors[run_kind]) - - selector = ",".join(selectors) - pods = self.list_pods(namespace, selector=selector) - if not pods: - logger.error("no pod matches that uid", uid=uid) - return - - return {p.metadata.name: p.status.phase for p in pods} - - def create_project_service_account(self, project, service_account, namespace=""): - namespace = self.resolve_namespace(namespace) - k8s_service_account = client.V1ServiceAccount() - labels = {"mlrun/project": project} - k8s_service_account.metadata = client.V1ObjectMeta( - name=service_account, namespace=namespace, labels=labels - ) - try: - api_response = self.v1api.create_namespaced_service_account( - namespace, - k8s_service_account, - ) - return api_response - except ApiException as exc: - logger.error(f"failed to create service account: {err_to_str(exc)}") - raise exc - - def get_project_vault_secret_name( - self, project, service_account_name, namespace="" - ): - namespace = self.resolve_namespace(namespace) - - try: - service_account = self.v1api.read_namespaced_service_account( - service_account_name, namespace - ) - except ApiException as exc: - # It's valid for the 
service account to not exist. Simply return None - if exc.status != 404: - logger.error(f"failed to retrieve service accounts: {err_to_str(exc)}") - raise exc - return None - - if len(service_account.secrets) > 1: - raise ValueError( - f"Service account {service_account_name} has more than one secret" - ) - - return service_account.secrets[0].name - - def get_project_secret_name(self, project) -> str: - return mlconfig.secret_stores.kubernetes.project_secret_name.format( - project=project - ) - - def get_auth_secret_name(self, access_key: str) -> str: - hashed_access_key = self._hash_access_key(access_key) - return mlconfig.secret_stores.kubernetes.auth_secret_name.format( - hashed_access_key=hashed_access_key - ) - - @staticmethod - def _hash_access_key(access_key: str): - return hashlib.sha224(access_key.encode()).hexdigest() - - def store_project_secrets(self, project, secrets, namespace=""): - secret_name = self.get_project_secret_name(project) - self.store_secrets(secret_name, secrets, namespace) - - def read_auth_secret(self, secret_name, namespace="", raise_on_not_found=False): - namespace = self.resolve_namespace(namespace) - - try: - secret_data = self.v1api.read_namespaced_secret(secret_name, namespace).data - except ApiException as exc: - logger.error( - "Failed to read secret", - secret_name=secret_name, - namespace=namespace, - exc=err_to_str(exc), - ) - if exc.status != 404: - raise exc - elif raise_on_not_found: - raise mlrun.errors.MLRunNotFoundError( - f"Secret '{secret_name}' was not found in namespace '{namespace}'" - ) from exc - - return None, None - - def _get_secret_value(key): - if secret_data.get(key): - return base64.b64decode(secret_data[key]).decode("utf-8") - else: - return None - - username = _get_secret_value( - mlrun.api.schemas.AuthSecretData.get_field_secret_key("username") - ) - access_key = _get_secret_value( - mlrun.api.schemas.AuthSecretData.get_field_secret_key("access_key") - ) - - return username, access_key - - def 
store_auth_secret(self, username: str, access_key: str, namespace="") -> str: - secret_name = self.get_auth_secret_name(access_key) - secret_data = { - mlrun.api.schemas.AuthSecretData.get_field_secret_key("username"): username, - mlrun.api.schemas.AuthSecretData.get_field_secret_key( - "access_key" - ): access_key, - } - self.store_secrets( - secret_name, - secret_data, - namespace, - type_=SecretTypes.v3io_fuse, - labels={"mlrun/username": username}, - ) - return secret_name - - def store_secrets( - self, - secret_name, - secrets, - namespace="", - type_=SecretTypes.opaque, - labels: typing.Optional[dict] = None, - ): - namespace = self.resolve_namespace(namespace) - try: - k8s_secret = self.v1api.read_namespaced_secret(secret_name, namespace) - except ApiException as exc: - # If secret doesn't exist, we'll simply create it - if exc.status != 404: - logger.error(f"failed to retrieve k8s secret: {err_to_str(exc)}") - raise exc - k8s_secret = client.V1Secret(type=type_) - k8s_secret.metadata = client.V1ObjectMeta( - name=secret_name, namespace=namespace, labels=labels - ) - k8s_secret.string_data = secrets - self.v1api.create_namespaced_secret(namespace, k8s_secret) - return - - secret_data = k8s_secret.data.copy() - for key, value in secrets.items(): - secret_data[key] = base64.b64encode(value.encode()).decode("utf-8") - - k8s_secret.data = secret_data - self.v1api.replace_namespaced_secret(secret_name, namespace, k8s_secret) - - def load_secret(self, secret_name, namespace=""): - namespace = namespace or self.resolve_namespace(namespace) - - try: - k8s_secret = self.v1api.read_namespaced_secret(secret_name, namespace) - except ApiException: - return None - - return k8s_secret.data - - def delete_project_secrets(self, project, secrets, namespace=""): - secret_name = self.get_project_secret_name(project) - self.delete_secrets(secret_name, secrets, namespace) - - def delete_auth_secret(self, secret_ref: str, namespace=""): - self.delete_secrets(secret_ref, {}, 
namespace) - - def delete_secrets(self, secret_name, secrets, namespace=""): - namespace = self.resolve_namespace(namespace) - - try: - k8s_secret = self.v1api.read_namespaced_secret(secret_name, namespace) - except ApiException as exc: - # If secret does not exist, return as if the deletion was successfully - if exc.status == 404: - return - else: - logger.error(f"failed to retrieve k8s secret: {err_to_str(exc)}") - raise exc - - if not secrets: - secret_data = {} - else: - secret_data = k8s_secret.data.copy() - for secret in secrets: - secret_data.pop(secret, None) - - if not secret_data: - self.v1api.delete_namespaced_secret(secret_name, namespace) - else: - k8s_secret.data = secret_data - self.v1api.replace_namespaced_secret(secret_name, namespace, k8s_secret) - - def _get_project_secrets_raw_data(self, project, namespace=""): - secret_name = self.get_project_secret_name(project) - return self._get_secret_raw_data(secret_name, namespace) - - def _get_secret_raw_data(self, secret_name, namespace=""): - namespace = self.resolve_namespace(namespace) - - try: - k8s_secret = self.v1api.read_namespaced_secret(secret_name, namespace) - except ApiException: - return None - - return k8s_secret.data - - def get_project_secret_keys(self, project, namespace="", filter_internal=False): - secrets_data = self._get_project_secrets_raw_data(project, namespace) - if not secrets_data: - return [] - - secret_keys = list(secrets_data.keys()) - if filter_internal: - secret_keys = list( - filter(lambda key: not key.startswith("mlrun."), secret_keys) - ) - return secret_keys - - def get_project_secret_data(self, project, secret_keys=None, namespace=""): - secrets_data = self._get_project_secrets_raw_data(project, namespace) - return self._decode_secret_data(secrets_data, secret_keys) - - def get_secret_data(self, secret_name, namespace=""): - secrets_data = self._get_secret_raw_data(secret_name, namespace) - return self._decode_secret_data(secrets_data) - - def 
_decode_secret_data(self, secrets_data, secret_keys=None): - results = {} - if not secrets_data: - return results - - # If not asking for specific keys, return all - secret_keys = secret_keys or secrets_data.keys() - - for key in secret_keys: - encoded_value = secrets_data.get(key) - if encoded_value: - results[key] = base64.b64decode(secrets_data[key]).decode("utf-8") - return results - - -class BasePod: - def __init__( - self, - task_name="", - image=None, - command=None, - args=None, - namespace="", - kind="job", - project=None, - default_pod_spec_attributes=None, - resources=None, - ): - self.namespace = namespace - self.name = "" - self.task_name = task_name - self.image = image - self.command = command - self.args = args - self._volumes = [] - self._mounts = [] - self.env = None - self.node_selector = None - self.project = project or mlrun.mlconf.default_project - self._labels = { - "mlrun/task-name": task_name, - "mlrun/class": kind, - "mlrun/project": self.project, - } - self._annotations = {} - self._init_containers = [] - # will be applied on the pod spec only when calling .pod(), allows to override spec attributes - self.default_pod_spec_attributes = default_pod_spec_attributes - self.resources = resources - - @property - def pod(self): - return self._get_spec() - - @property - def init_containers(self): - return self._init_containers - - @init_containers.setter - def init_containers(self, containers): - self._init_containers = containers - - def append_init_container( - self, - image, - command=None, - args=None, - env=None, - image_pull_policy="IfNotPresent", - name="init", - ): - if isinstance(env, dict): - env = [client.V1EnvVar(name=k, value=v) for k, v in env.items()] - self._init_containers.append( - client.V1Container( - name=name, - image=image, - env=env, - command=command, - args=args, - image_pull_policy=image_pull_policy, - ) - ) - - def add_label(self, key, value): - self._labels[key] = str(value) - - def add_annotation(self, key, value): - 
self._annotations[key] = str(value) - - def add_volume(self, volume: client.V1Volume, mount_path, name=None, sub_path=None): - self._mounts.append( - client.V1VolumeMount( - name=name or volume.name, mount_path=mount_path, sub_path=sub_path - ) - ) - self._volumes.append(volume) - - def mount_empty(self, name="empty", mount_path="/empty"): - self.add_volume( - client.V1Volume(name=name, empty_dir=client.V1EmptyDirVolumeSource()), - mount_path=mount_path, - ) - - def mount_v3io( - self, name="v3io", remote="~/", mount_path="/User", access_key="", user="" - ): - self.add_volume( - v3io_to_vol(name, remote, access_key, user), - mount_path=mount_path, - name=name, - ) - - def mount_cfgmap(self, name, path="/config"): - self.add_volume( - client.V1Volume( - name=name, config_map=client.V1ConfigMapVolumeSource(name=name) - ), - mount_path=path, - ) - - def mount_secret(self, name, path="/secret", items=None, sub_path=None): - self.add_volume( - client.V1Volume( - name=name, - secret=client.V1SecretVolumeSource( - secret_name=name, - items=items, - ), - ), - mount_path=path, - sub_path=sub_path, - ) - - def set_node_selector(self, node_selector: typing.Optional[typing.Dict[str, str]]): - self.node_selector = node_selector - - def _get_spec(self, template=False): - - pod_obj = client.V1PodTemplate if template else client.V1Pod - - if self.env and isinstance(self.env, dict): - env = [client.V1EnvVar(name=k, value=v) for k, v in self.env.items()] - else: - env = self.env - container = client.V1Container( - name="base", - image=self.image, - env=env, - command=self.command, - args=self.args, - volume_mounts=self._mounts, - resources=self.resources, - ) - - pod_spec = client.V1PodSpec( - containers=[container], - restart_policy="Never", - volumes=self._volumes, - node_selector=self.node_selector, - ) - - # if attribute isn't defined use default pod spec attributes - for key, val in self.default_pod_spec_attributes.items(): - if not getattr(pod_spec, key, None): - 
setattr(pod_spec, key, val) - - for init_containers in self._init_containers: - init_containers.volume_mounts = self._mounts - pod_spec.init_containers = self._init_containers - - pod = pod_obj( - metadata=client.V1ObjectMeta( - generate_name=f"{self.task_name}-", - namespace=self.namespace, - labels=self._labels, - annotations=self._annotations, - ), - spec=pod_spec, - ) - return pod - - -def format_labels(labels): - """Convert a dictionary of labels into a comma separated string""" - if labels: - return ",".join([f"{k}={v}" for k, v in labels.items()]) - else: - return "" - - -def verify_gpu_requests_and_limits(requests_gpu: str = None, limits_gpu: str = None): - # https://kubernetes.io/docs/tasks/manage-gpus/scheduling-gpus/ - if requests_gpu and not limits_gpu: - raise mlrun.errors.MLRunConflictError( - "You cannot specify GPU requests without specifying limits" - ) - if requests_gpu and limits_gpu and requests_gpu != limits_gpu: - raise mlrun.errors.MLRunConflictError( - f"When specifying both GPU requests and limits these two values must be equal, " - f"requests_gpu={requests_gpu}, limits_gpu={limits_gpu}" - ) + kubernetes.config.load_incluster_config() + _running_inside_kubernetes_cluster = True + except kubernetes.config.ConfigException: + _running_inside_kubernetes_cluster = False + return _running_inside_kubernetes_cluster def generate_preemptible_node_selector_requirements( @@ -795,7 +40,7 @@ def generate_preemptible_node_selector_requirements( """ Generate node selector requirements based on the pre-configured node selector of the preemptible nodes. node selector operator represents a key's relationship to a set of values. 
- Valid operators are listed in :py:class:`~mlrun.api.schemas.NodeSelectorOperator` + Valid operators are listed in :py:class:`~mlrun.common.schemas.NodeSelectorOperator` :param node_selector_operator: The operator of V1NodeSelectorRequirement :return: List[V1NodeSelectorRequirement] """ @@ -825,12 +70,9 @@ def generate_preemptible_nodes_anti_affinity_terms() -> typing.List[ https://kubernetes.io/docs/concepts/scheduling-eviction/assign-pod-node/#affinity-and-anti-affinity :return: List contains one nodeSelectorTerm with multiple expressions. """ - # import here to avoid circular imports - from mlrun.api.schemas import NodeSelectorOperator - # compile affinities with operator NotIn to make sure pods are not running on preemptible nodes. node_selector_requirements = generate_preemptible_node_selector_requirements( - NodeSelectorOperator.node_selector_op_not_in.value + mlrun.common.schemas.NodeSelectorOperator.node_selector_op_not_in.value ) return [ kubernetes.client.V1NodeSelectorTerm( @@ -848,14 +90,11 @@ def generate_preemptible_nodes_affinity_terms() -> typing.List[ then the pod can be scheduled onto a node if at least one of the nodeSelectorTerms can be satisfied. :return: List of nodeSelectorTerms associated with the preemptible nodes. """ - # import here to avoid circular imports - from mlrun.api.schemas import NodeSelectorOperator - node_selector_terms = [] # compile affinities with operator In so pods could schedule on at least one of the preemptible nodes. 
node_selector_requirements = generate_preemptible_node_selector_requirements( - NodeSelectorOperator.node_selector_op_in.value + mlrun.common.schemas.NodeSelectorOperator.node_selector_op_in.value ) for expression in node_selector_requirements: node_selector_terms.append( diff --git a/mlrun/kfpops.py b/mlrun/kfpops.py index 2139316fee79..6a820efe9bc3 100644 --- a/mlrun/kfpops.py +++ b/mlrun/kfpops.py @@ -18,6 +18,7 @@ from copy import deepcopy from typing import Dict, List, Union +import inflection from kfp import dsl from kubernetes import client as k8s_client @@ -226,7 +227,7 @@ def mlrun_op( :param labels: labels to tag the job/run with ({key:val, ..}) :param inputs: dictionary of input objects + optional paths (if path is omitted the path will be the in_path/key. - :param outputs: dictionary of input objects + optional paths (if path is + :param outputs: dictionary of output objects + optional paths (if path is omitted the path will be the out_path/key. :param in_path: default input path/url (prefix) for inputs :param out_path: default output path/url (prefix) for artifacts @@ -712,6 +713,14 @@ def generate_kfp_dag_and_resolve_project(run, project=None): record = { k: node[k] for k in ["phase", "startedAt", "finishedAt", "type", "id"] } + + # snake case + # align kfp fields to mlrun snake case convention + # create snake_case for consistency. 
+ # retain the camelCase for compatibility + for key in list(record.keys()): + record[inflection.underscore(key)] = record[key] + record["parent"] = node.get("boundaryID", "") record["name"] = name record["children"] = node.get("children", []) @@ -747,21 +756,23 @@ def format_summary_from_kfp_run(kfp_run, project=None, session=None): if error: dag[step]["error"] = error - short_run = {"graph": dag} - short_run["run"] = { - k: str(v) - for k, v in kfp_run["run"].items() - if k - in [ - "id", - "name", - "status", - "error", - "created_at", - "scheduled_at", - "finished_at", - "description", - ] + short_run = { + "graph": dag, + "run": { + k: str(v) if v is not None else v + for k, v in kfp_run["run"].items() + if k + in [ + "id", + "name", + "status", + "error", + "created_at", + "scheduled_at", + "finished_at", + "description", + ] + }, } short_run["run"]["project"] = project short_run["run"]["message"] = message diff --git a/mlrun/launcher/__init__.py b/mlrun/launcher/__init__.py new file mode 100644 index 000000000000..7f557697af77 --- /dev/null +++ b/mlrun/launcher/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 MLRun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/mlrun/launcher/base.py b/mlrun/launcher/base.py new file mode 100644 index 000000000000..bb491b218c8e --- /dev/null +++ b/mlrun/launcher/base.py @@ -0,0 +1,406 @@ +# Copyright 2023 MLRun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +import ast +import copy +import os +import uuid +from typing import Any, Callable, Dict, List, Optional, Union + +import mlrun.common.schemas +import mlrun.config +import mlrun.errors +import mlrun.kfpops +import mlrun.lists +import mlrun.model +import mlrun.runtimes +from mlrun.utils import logger + +run_modes = ["pass"] + + +class BaseLauncher(abc.ABC): + """ + Abstract class for managing and running functions in different contexts + This class is designed to encapsulate the logic of running a function in different contexts + i.e. 
running a function locally, remotely or in a server + Each context will have its own implementation of the abstract methods while the common logic resides in this class + """ + + def save_function( + self, + runtime: "mlrun.runtimes.BaseRuntime", + tag: str = "", + versioned: bool = False, + refresh: bool = False, + ) -> str: + """ + store the function to the db + :param runtime: runtime object + :param tag: function tag to store + :param versioned: whether we want to version this function object so that it will queryable by its hash key + :param refresh: refresh function metadata + + :return: function uri + """ + db = runtime._get_db() + if not db: + raise mlrun.errors.MLRunPreconditionFailedError( + "Database connection is not configured" + ) + + if refresh: + self._refresh_function_metadata(runtime) + + tag = tag or runtime.metadata.tag + + obj = runtime.to_dict() + logger.debug("Saving function", runtime_name=runtime.metadata.name, tag=tag) + hash_key = db.store_function( + obj, runtime.metadata.name, runtime.metadata.project, tag, versioned + ) + hash_key = hash_key if versioned else None + return "db://" + runtime._function_uri(hash_key=hash_key, tag=tag) + + @abc.abstractmethod + def launch( + self, + runtime: "mlrun.runtimes.BaseRuntime", + task: Optional[ + Union["mlrun.run.RunTemplate", "mlrun.run.RunObject", dict] + ] = None, + handler: Optional[Union[str, Callable]] = None, + name: Optional[str] = "", + project: Optional[str] = "", + params: Optional[dict] = None, + inputs: Optional[Dict[str, str]] = None, + out_path: Optional[str] = "", + workdir: Optional[str] = "", + artifact_path: Optional[str] = "", + watch: Optional[bool] = True, + schedule: Optional[ + Union[str, mlrun.common.schemas.schedule.ScheduleCronTrigger] + ] = None, + hyperparams: Dict[str, list] = None, + hyper_param_options: Optional[mlrun.model.HyperParamOptions] = None, + verbose: Optional[bool] = None, + scrape_metrics: Optional[bool] = None, + local_code_path: Optional[str] = None, 
+ auto_build: Optional[bool] = None, + param_file_secrets: Optional[Dict[str, str]] = None, + notifications: Optional[List[mlrun.model.Notification]] = None, + returns: Optional[List[Union[str, Dict[str, str]]]] = None, + ) -> "mlrun.run.RunObject": + """run the function from the server/client[local/remote]""" + pass + + def _validate_runtime( + self, + runtime: "mlrun.runtimes.BaseRuntime", + run: "mlrun.run.RunObject", + ): + mlrun.utils.helpers.verify_dict_items_type( + "Inputs", run.spec.inputs, [str], [str] + ) + + if runtime.spec.mode and runtime.spec.mode not in run_modes: + raise ValueError(f'run mode can only be {",".join(run_modes)}') + + self._validate_run_params(run.spec.parameters) + self._validate_output_path(runtime, run) + + @staticmethod + def _validate_output_path( + runtime: "mlrun.runtimes.BaseRuntime", + run: "mlrun.run.RunObject", + ): + if not run.spec.output_path or "://" not in run.spec.output_path: + message = "" + if not os.path.isabs(run.spec.output_path): + message = ( + "artifact/output path is not defined or is local and relative," + " artifacts will not be visible in the UI" + ) + if mlrun.runtimes.RuntimeKinds.requires_absolute_artifacts_path( + runtime.kind + ): + raise mlrun.errors.MLRunPreconditionFailedError( + "artifact path (`artifact_path`) must be absolute for remote tasks" + ) + elif ( + hasattr(runtime.spec, "volume_mounts") + and not runtime.spec.volume_mounts + ): + message = ( + "artifact output path is local while no volume mount is specified. " + "artifacts would not be visible via UI." 
+ ) + if message: + logger.warning(message, output_path=run.spec.output_path) + + def _validate_run_params(self, parameters: Dict[str, Any]): + for param_name, param_value in parameters.items(): + + if isinstance(param_value, dict): + # if the parameter is a dict, we might have some nested parameters, + # in this case we need to verify them as well recursively + self._validate_run_params(param_value) + + # verify that integer parameters don't exceed a int64 + if isinstance(param_value, int) and abs(param_value) >= 2**63: + raise mlrun.errors.MLRunInvalidArgumentError( + f"parameter {param_name} value {param_value} exceeds int64" + ) + + @staticmethod + def _create_run_object(task): + valid_task_types = (dict, mlrun.run.RunTemplate, mlrun.run.RunObject) + + if not task: + # if task passed generate default RunObject + return mlrun.run.RunObject.from_dict(task) + + # deepcopy user's task, so we don't modify / enrich the user's object + task = copy.deepcopy(task) + + if isinstance(task, str): + task = ast.literal_eval(task) + + if not isinstance(task, valid_task_types): + raise mlrun.errors.MLRunInvalidArgumentError( + f"Task is not a valid object, type={type(task)}, expected types={valid_task_types}" + ) + + if isinstance(task, mlrun.run.RunTemplate): + return mlrun.run.RunObject.from_template(task) + elif isinstance(task, dict): + return mlrun.run.RunObject.from_dict(task) + + # task is already a RunObject + return task + + def _enrich_run( + self, + runtime, + run, + handler=None, + project_name=None, + name=None, + params=None, + inputs=None, + returns=None, + hyperparams=None, + hyper_param_options=None, + verbose=None, + scrape_metrics=None, + out_path=None, + artifact_path=None, + workdir=None, + notifications: List[mlrun.model.Notification] = None, + ): + run.spec.handler = ( + handler or run.spec.handler or runtime.spec.default_handler or "" + ) + # callable handlers are valid for handler and dask runtimes, + # for other runtimes we need to convert the handler 
to a string + if run.spec.handler and runtime.kind not in ["handler", "dask"]: + run.spec.handler = run.spec.handler_name + + def_name = runtime.metadata.name + if run.spec.handler_name: + short_name = run.spec.handler_name + for separator in ["#", "::", "."]: + # drop paths, module or class name from short name + if separator in short_name: + short_name = short_name.split(separator)[-1] + def_name += "-" + short_name + + run.metadata.name = mlrun.utils.normalize_name( + name=name or run.metadata.name or def_name, + # if name or runspec.metadata.name are set then it means that is user defined name and we want to warn the + # user that the passed name needs to be set without underscore, if its not user defined but rather enriched + # from the handler(function) name then we replace the underscore without warning the user. + # most of the time handlers will have `_` in the handler name (python convention is to separate function + # words with `_`), therefore we don't want to be noisy when normalizing the run name + verbose=bool(name or run.metadata.name), + ) + mlrun.utils.verify_field_regex( + "run.metadata.name", run.metadata.name, mlrun.utils.regex.run_name + ) + run.metadata.project = ( + project_name + or run.metadata.project + or runtime.metadata.project + or mlrun.mlconf.default_project + ) + run.spec.parameters = params or run.spec.parameters + run.spec.inputs = inputs or run.spec.inputs + run.spec.returns = returns or run.spec.returns + run.spec.hyperparams = hyperparams or run.spec.hyperparams + run.spec.hyper_param_options = ( + hyper_param_options or run.spec.hyper_param_options + ) + run.spec.verbose = verbose or run.spec.verbose + if scrape_metrics is None: + if run.spec.scrape_metrics is None: + scrape_metrics = mlrun.mlconf.scrape_metrics + else: + scrape_metrics = run.spec.scrape_metrics + run.spec.scrape_metrics = scrape_metrics + run.spec.input_path = workdir or run.spec.input_path or runtime.spec.workdir + if runtime.spec.allow_empty_resources: + 
run.spec.allow_empty_resources = runtime.spec.allow_empty_resources + + spec = run.spec + if spec.secret_sources: + runtime._secrets = mlrun.secrets.SecretsStore.from_list(spec.secret_sources) + + # update run metadata (uid, labels) and store in DB + meta = run.metadata + meta.uid = meta.uid or uuid.uuid4().hex + + run.spec.output_path = out_path or artifact_path or run.spec.output_path + + if not run.spec.output_path: + if run.metadata.project: + if ( + mlrun.pipeline_context.project + and run.metadata.project + == mlrun.pipeline_context.project.metadata.name + ): + run.spec.output_path = ( + mlrun.pipeline_context.project.spec.artifact_path + or mlrun.pipeline_context.workflow_artifact_path + ) + + # get_db might be None when no rundb is set on runtime + if not run.spec.output_path and runtime._get_db(): + try: + # not passing or loading the DB before the enrichment on purpose, because we want to enrich the + # spec first as get_db() depends on it + project = runtime._get_db().get_project(run.metadata.project) + # this is mainly for tests, so we won't need to mock get_project for so many tests + # in normal use cases if no project is found we will get an error + if project: + run.spec.output_path = project.spec.artifact_path + except mlrun.errors.MLRunNotFoundError: + logger.warning( + f"project {project_name} is not saved in DB yet, " + f"enriching output path with default artifact path: {mlrun.mlconf.artifact_path}" + ) + + if not run.spec.output_path: + run.spec.output_path = mlrun.mlconf.artifact_path + + if run.spec.output_path: + run.spec.output_path = run.spec.output_path.replace("{{run.uid}}", meta.uid) + run.spec.output_path = mlrun.utils.helpers.fill_artifact_path_template( + run.spec.output_path, run.metadata.project + ) + + notifications = notifications or run.spec.notifications or [] + mlrun.model.Notification.validate_notification_uniqueness(notifications) + for notification in notifications: + notification.validate_notification() + + 
run.spec.notifications = notifications + + return run + + @staticmethod + def _run_has_valid_notifications(runobj) -> bool: + if not runobj.spec.notifications: + logger.debug( + "No notifications to push for run", run_uid=runobj.metadata.uid + ) + return False + + # TODO: add support for other notifications per run iteration + if runobj.metadata.iteration and runobj.metadata.iteration > 0: + logger.debug( + "Notifications per iteration are not supported, skipping", + run_uid=runobj.metadata.uid, + ) + return False + + return True + + def _wrap_run_result( + self, + runtime: "mlrun.runtimes.BaseRuntime", + result: dict, + run: "mlrun.run.RunObject", + schedule: Optional[mlrun.common.schemas.ScheduleCronTrigger] = None, + err: Optional[Exception] = None, + ): + # if the purpose was to schedule (and not to run) nothing to wrap + if schedule: + return + + if result and runtime.kfp and err is None: + mlrun.kfpops.write_kfpmeta(result) + + self._log_track_results(runtime, result, run) + + if result: + run = mlrun.run.RunObject.from_dict(result) + logger.info( + "Run execution finished", + status=run.status.state, + name=run.metadata.name, + ) + if run.status.state in [ + mlrun.runtimes.base.RunStates.error, + mlrun.runtimes.base.RunStates.aborted, + ]: + if runtime._is_remote and not runtime.is_child: + logger.error( + "Run did not finish successfully", + state=run.status.state, + status=run.status.to_dict(), + ) + raise mlrun.runtimes.utils.RunError(run.error) + return run + + return None + + @staticmethod + def _refresh_function_metadata(runtime: "mlrun.runtimes.BaseRuntime"): + pass + + @staticmethod + def prepare_image_for_deploy(runtime: "mlrun.runtimes.BaseRuntime"): + """Check if the runtime requires to build the image and updates the spec accordingly""" + pass + + @staticmethod + @abc.abstractmethod + def enrich_runtime( + runtime: "mlrun.runtimes.base.BaseRuntime", + project_name: Optional[str] = "", + ): + pass + + @staticmethod + @abc.abstractmethod + def 
_store_function( + runtime: "mlrun.runtimes.BaseRuntime", run: "mlrun.run.RunObject" + ): + pass + + @staticmethod + def _log_track_results( + runtime: "mlrun.runtimes.BaseRuntime", result: dict, run: "mlrun.run.RunObject" + ): + pass diff --git a/mlrun/launcher/client.py b/mlrun/launcher/client.py new file mode 100644 index 000000000000..c4024740e8c4 --- /dev/null +++ b/mlrun/launcher/client.py @@ -0,0 +1,159 @@ +# Copyright 2023 MLRun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import abc +import getpass +import os +from typing import Optional + +import IPython + +import mlrun.errors +import mlrun.launcher.base +import mlrun.lists +import mlrun.model +import mlrun.runtimes +from mlrun.utils import logger + + +class ClientBaseLauncher(mlrun.launcher.base.BaseLauncher, abc.ABC): + """ + Abstract class for common code between client launchers + """ + + @staticmethod + def enrich_runtime( + runtime: "mlrun.runtimes.base.BaseRuntime", project_name: Optional[str] = "" + ): + runtime.try_auto_mount_based_on_config() + runtime._fill_credentials() + + @staticmethod + def prepare_image_for_deploy(runtime: "mlrun.runtimes.BaseRuntime"): + """ + Check if the runtime requires to build the image. + If build is needed, set the image as the base_image for the build. + If image is not given set the default one. 
+ """ + if runtime.kind in mlrun.runtimes.RuntimeKinds.nuclio_runtimes(): + return + + build = runtime.spec.build + require_build = ( + build.commands + or build.requirements + or (build.source and not build.load_source_on_run) + ) + image = runtime.spec.image + # we allow users to not set an image, in that case we'll use the default + if ( + not image + and runtime.kind in mlrun.mlconf.function_defaults.image_by_kind.to_dict() + ): + image = mlrun.mlconf.function_defaults.image_by_kind.to_dict()[runtime.kind] + + # TODO: need a better way to decide whether a function requires a build + if require_build and image and not runtime.spec.build.base_image: + # when the function require build use the image as the base_image for the build + runtime.spec.build.base_image = image + runtime.spec.image = "" + + @staticmethod + def _store_function( + runtime: "mlrun.runtimes.BaseRuntime", run: "mlrun.run.RunObject" + ): + run.metadata.labels["kind"] = runtime.kind + if "owner" not in run.metadata.labels: + run.metadata.labels["owner"] = ( + os.environ.get("V3IO_USERNAME") or getpass.getuser() + ) + if run.spec.output_path: + run.spec.output_path = run.spec.output_path.replace( + "{{run.user}}", run.metadata.labels["owner"] + ) + db = runtime._get_db() + if db and runtime.kind != "handler": + struct = runtime.to_dict() + hash_key = db.store_function( + struct, runtime.metadata.name, runtime.metadata.project, versioned=True + ) + run.spec.function = runtime._function_uri(hash_key=hash_key) + + @staticmethod + def _refresh_function_metadata(runtime: "mlrun.runtimes.BaseRuntime"): + try: + meta = runtime.metadata + db = runtime._get_db() + db_func = db.get_function(meta.name, meta.project, meta.tag) + if db_func and "status" in db_func: + runtime.status = db_func["status"] + if ( + runtime.status.state + and runtime.status.state == "ready" + and runtime.kind + # We don't want to override the nuclio image here because the build happens in nuclio + # TODO: have a better way to check 
if nuclio function deploy started + and not hasattr(runtime.status, "nuclio_name") + ): + runtime.spec.image = mlrun.utils.get_in( + db_func, "spec.image", runtime.spec.image + ) + except mlrun.errors.MLRunNotFoundError: + pass + + @staticmethod + def _log_track_results( + runtime: "mlrun.runtimes.BaseRuntime", result: dict, run: "mlrun.run.RunObject" + ): + """ + log commands to track results + in jupyter, displays a table widget with the result + else, logs CLI commands to track results and a link to the results in UI + + :param: runtime: runtime object + :param result: run result dict + :param run: run object + """ + uid = run.metadata.uid + project = run.metadata.project + + # show ipython/jupyter result table widget + results_tbl = mlrun.lists.RunList() + if result: + results_tbl.append(result) + else: + logger.info("no returned result (job may still be in progress)") + results_tbl.append(run.to_dict()) + + if mlrun.utils.is_ipython and mlrun.config.config.ipython_widget: + results_tbl.show() + print() + ui_url = mlrun.utils.get_ui_url(project, uid) + if ui_url: + ui_url = f' or click here to open in UI' + IPython.display.display( + IPython.display.HTML( + f" > to track results use the .show() or .logs() methods {ui_url}" + ) + ) + elif not runtime.is_child: + # TODO: Log sdk commands to track results instead of CLI commands + project_flag = f"-p {project}" if project else "" + info_cmd = f"mlrun get run {uid} {project_flag}" + logs_cmd = f"mlrun logs {uid} {project_flag}" + logger.info( + "To track results use the CLI", info_cmd=info_cmd, logs_cmd=logs_cmd + ) + ui_url = mlrun.utils.get_ui_url(project, uid) + if ui_url: + logger.info("Or click for UI", ui_url=ui_url) diff --git a/mlrun/launcher/factory.py b/mlrun/launcher/factory.py new file mode 100644 index 000000000000..e434c34b1136 --- /dev/null +++ b/mlrun/launcher/factory.py @@ -0,0 +1,50 @@ +# Copyright 2023 MLRun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may 
not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import mlrun.config +import mlrun.errors +import mlrun.launcher.base +import mlrun.launcher.local +import mlrun.launcher.remote + + +class LauncherFactory(object): + @staticmethod + def create_launcher( + is_remote: bool, local: bool = False + ) -> mlrun.launcher.base.BaseLauncher: + """ + Creates the appropriate launcher for the specified run. + ServerSideLauncher - if running as API. + ClientRemoteLauncher - if the run is remote and local was not specified. + ClientLocalLauncher - if the run is not remote or local was specified. + + :param is_remote: Whether the runtime requires remote execution. + :param local: Run the function locally vs on the Runtime/Cluster + + :return: The appropriate launcher for the specified run. + """ + if mlrun.config.is_running_as_api(): + if local: + raise mlrun.errors.MLRunInternalServerError( + "Launch of local run inside the server is not allowed" + ) + + from mlrun.api.launcher import ServerSideLauncher + + return ServerSideLauncher() + + if is_remote and not local: + return mlrun.launcher.remote.ClientRemoteLauncher() + + return mlrun.launcher.local.ClientLocalLauncher(local) diff --git a/mlrun/launcher/local.py b/mlrun/launcher/local.py new file mode 100644 index 000000000000..bbf63f64bb3c --- /dev/null +++ b/mlrun/launcher/local.py @@ -0,0 +1,276 @@ +# Copyright 2023 MLRun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import pathlib +from typing import Callable, Dict, List, Optional, Union + +import mlrun.common.schemas.schedule +import mlrun.errors +import mlrun.launcher.client +import mlrun.run +import mlrun.runtimes.generators +import mlrun.utils.clones +import mlrun.utils.notifications +from mlrun.utils import logger + + +class ClientLocalLauncher(mlrun.launcher.client.ClientBaseLauncher): + """ + ClientLocalLauncher is a launcher that runs the job locally. + Either on the user's machine (_is_run_local is True) or on a remote machine (_is_run_local is False). + """ + + def __init__(self, local: bool): + """ + Initialize a ClientLocalLauncher. + :param local: True if the job runs on the user's local machine, + False if it runs on a remote machine (e.g. a dedicated k8s pod). 
+ """ + super().__init__() + self._is_run_local = local + + def launch( + self, + runtime: "mlrun.runtimes.BaseRuntime", + task: Optional[ + Union["mlrun.run.RunTemplate", "mlrun.run.RunObject", dict] + ] = None, + handler: Optional[Union[str, Callable]] = None, + name: Optional[str] = "", + project: Optional[str] = "", + params: Optional[dict] = None, + inputs: Optional[Dict[str, str]] = None, + out_path: Optional[str] = "", + workdir: Optional[str] = "", + artifact_path: Optional[str] = "", + watch: Optional[bool] = True, + schedule: Optional[ + Union[str, mlrun.common.schemas.schedule.ScheduleCronTrigger] + ] = None, + hyperparams: Dict[str, list] = None, + hyper_param_options: Optional[mlrun.model.HyperParamOptions] = None, + verbose: Optional[bool] = None, + scrape_metrics: Optional[bool] = None, + local_code_path: Optional[str] = None, + auto_build: Optional[bool] = None, + param_file_secrets: Optional[Dict[str, str]] = None, + notifications: Optional[List[mlrun.model.Notification]] = None, + returns: Optional[List[Union[str, Dict[str, str]]]] = None, + ) -> "mlrun.run.RunObject": + + # do not allow local function to be scheduled + if self._is_run_local and schedule is not None: + raise mlrun.errors.MLRunInvalidArgumentError( + "local and schedule cannot be used together" + ) + + self.enrich_runtime(runtime) + run = self._create_run_object(task) + + if self._is_run_local: + runtime = self._create_local_function_for_execution( + runtime=runtime, + run=run, + local_code_path=local_code_path, + project=project, + name=name, + workdir=workdir, + handler=handler, + ) + + # sanity check + elif runtime._is_remote: + message = "Remote function cannot be executed locally" + logger.error( + message, + is_remote=runtime._is_remote, + local=self._is_run_local, + runtime=runtime.to_dict(), + ) + raise mlrun.errors.MLRunRuntimeError(message) + + run = self._enrich_run( + runtime=runtime, + run=run, + handler=handler, + project_name=project, + name=name, + params=params, + 
inputs=inputs, + returns=returns, + hyperparams=hyperparams, + hyper_param_options=hyper_param_options, + verbose=verbose, + scrape_metrics=scrape_metrics, + out_path=out_path, + artifact_path=artifact_path, + workdir=workdir, + notifications=notifications, + ) + self._validate_runtime(runtime, run) + result = self.execute( + runtime=runtime, + run=run, + ) + + return result + + def execute( + self, + runtime: "mlrun.runtimes.BaseRuntime", + run: Optional[Union["mlrun.run.RunTemplate", "mlrun.run.RunObject"]] = None, + ): + + if "V3IO_USERNAME" in os.environ and "v3io_user" not in run.metadata.labels: + run.metadata.labels["v3io_user"] = os.environ.get("V3IO_USERNAME") + + # store function object in db unless running from within a run pod + if not runtime.is_child: + logger.info( + "Storing function", + name=run.metadata.name, + uid=run.metadata.uid, + db=runtime.spec.rundb, + ) + self._store_function(runtime, run) + + execution = mlrun.run.MLClientCtx.from_dict( + run.to_dict(), + runtime._get_db(), + autocommit=False, + is_api=False, + store_run=False, + ) + + # create task generator (for child runs) from spec + task_generator = mlrun.runtimes.generators.get_generator(run.spec, execution) + if task_generator: + # verify valid task parameters + tasks = task_generator.generate(run) + for task in tasks: + self._validate_run_params(task.spec.parameters) + + # post verifications, store execution in db and run pre run hooks + execution.store_run() + runtime._pre_run(run, execution) # hook for runtime specific prep + + last_err = None + # If the runtime is nested, it means the hyper-run will run within a single instance of the run. 
+ # So while in the API, we consider the hyper-run as a single run, and then in the runtime itself when the + # runtime is now a local runtime and therefore `self._is_nested == False`, we run each task as a separate run by + # using the task generator + # TODO client-server separation might not need the not runtime._is_nested anymore as this executed local func + if task_generator and not runtime._is_nested: + # multiple runs (based on hyper params or params file) + runner = runtime._run_many + if hasattr(runtime, "_parallel_run_many") and task_generator.use_parallel(): + runner = runtime._parallel_run_many + results = runner(task_generator, execution, run) + mlrun.runtimes.utils.results_to_iter(results, run, execution) + result = execution.to_dict() + result = runtime._update_run_state(result, task=run) + + else: + # single run + try: + resp = runtime._run(run, execution) + result = runtime._update_run_state(resp, task=run) + except mlrun.runtimes.base.RunError as err: + last_err = err + result = runtime._update_run_state(task=run, err=err) + + self._push_notifications(run, runtime) + + # run post run hooks + runtime._post_run(result, execution) # hook for runtime specific cleanup + + return self._wrap_run_result(runtime, result, run, err=last_err) + + def _create_local_function_for_execution( + self, + runtime: "mlrun.runtimes.BaseRuntime", + run: "mlrun.run.RunObject", + local_code_path: Optional[str] = None, + project: Optional[str] = "", + name: Optional[str] = "", + workdir: Optional[str] = "", + handler: Optional[str] = None, + ): + + project = project or runtime.metadata.project + function_name = name or runtime.metadata.name + command, args = self._resolve_local_code_path(local_code_path) + if command: + function_name = name or pathlib.Path(command).stem + + meta = mlrun.model.BaseMetadata(function_name, project=project) + + command, loaded_runtime = mlrun.run.load_func_code( + command or runtime, workdir, name=name + ) + # loaded_runtime is loaded from 
runtime or yaml file, if passed a command it should be None, + # so we keep the current runtime for enrichment + runtime = loaded_runtime or runtime + if loaded_runtime: + if run: + handler = handler or run.spec.handler + handler = handler or runtime.spec.default_handler or "" + meta = runtime.metadata.copy() + meta.name = function_name or meta.name + meta.project = project or meta.project + + # if the handler has module prefix force "local" (vs "handler") runtime + kind = "local" if isinstance(handler, str) and "." in handler else "" + fn = mlrun.new_function(meta.name, command=command, args=args, kind=kind) + fn.metadata = meta + setattr(fn, "_is_run_local", True) + if workdir: + fn.spec.workdir = str(workdir) + + fn.spec.allow_empty_resources = runtime.spec.allow_empty_resources + if runtime: + # copy the code/base-spec to the local function (for the UI and code logging) + fn.spec.description = runtime.spec.description + fn.spec.build = runtime.spec.build + + run.spec.handler = handler + return fn + + @staticmethod + def _resolve_local_code_path(local_code_path: str) -> (str, List[str]): + command = None + args = [] + if local_code_path: + command = local_code_path + if command: + sp = command.split() + # split command and args + command = sp[0] + if len(sp) > 1: + args = sp[1:] + return command, args + + def _push_notifications( + self, runobj: "mlrun.run.RunObject", runtime: "mlrun.runtimes.BaseRuntime" + ): + if not self._run_has_valid_notifications(runobj): + return + # TODO: add store_notifications API endpoint so we can store notifications pushed from the + # SDK for documentation purposes. + # The run is local, so we can assume that watch=True, therefore this code runs + # once the run is completed, and we can just push the notifications. + # Only push from jupyter, not from the CLI. 
+ # "handler" and "dask" kinds are special cases of local runs which don't set local=True + if self._is_run_local or runtime.kind in ["handler", "dask"]: + mlrun.utils.notifications.NotificationPusher([runobj]).push() diff --git a/mlrun/launcher/remote.py b/mlrun/launcher/remote.py new file mode 100644 index 000000000000..463d67ddbfc7 --- /dev/null +++ b/mlrun/launcher/remote.py @@ -0,0 +1,178 @@ +# Copyright 2023 MLRun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Dict, List, Optional, Union + +import requests + +import mlrun.common.schemas.schedule +import mlrun.db +import mlrun.errors +import mlrun.launcher.client +import mlrun.run +import mlrun.runtimes +import mlrun.runtimes.generators +import mlrun.utils.clones +import mlrun.utils.notifications +from mlrun.utils import logger + + +class ClientRemoteLauncher(mlrun.launcher.client.ClientBaseLauncher): + def launch( + self, + runtime: "mlrun.runtimes.KubejobRuntime", + task: Optional[ + Union["mlrun.run.RunTemplate", "mlrun.run.RunObject", dict] + ] = None, + handler: Optional[str] = None, + name: Optional[str] = "", + project: Optional[str] = "", + params: Optional[dict] = None, + inputs: Optional[Dict[str, str]] = None, + out_path: Optional[str] = "", + workdir: Optional[str] = "", + artifact_path: Optional[str] = "", + watch: Optional[bool] = True, + schedule: Optional[ + Union[str, mlrun.common.schemas.schedule.ScheduleCronTrigger] + ] = None, + hyperparams: Dict[str, 
list] = None, + hyper_param_options: Optional[mlrun.model.HyperParamOptions] = None, + verbose: Optional[bool] = None, + scrape_metrics: Optional[bool] = None, + local_code_path: Optional[str] = None, + auto_build: Optional[bool] = None, + param_file_secrets: Optional[Dict[str, str]] = None, + notifications: Optional[List[mlrun.model.Notification]] = None, + returns: Optional[List[Union[str, Dict[str, str]]]] = None, + ) -> "mlrun.run.RunObject": + self.enrich_runtime(runtime) + run = self._create_run_object(task) + + run = self._enrich_run( + runtime=runtime, + run=run, + handler=handler, + project_name=project, + name=name, + params=params, + inputs=inputs, + returns=returns, + hyperparams=hyperparams, + hyper_param_options=hyper_param_options, + verbose=verbose, + scrape_metrics=scrape_metrics, + out_path=out_path, + artifact_path=artifact_path, + workdir=workdir, + notifications=notifications, + ) + self._validate_runtime(runtime, run) + + if not runtime.is_deployed(): + if runtime.spec.build.auto_build or auto_build: + logger.info( + "Function is not deployed and auto_build flag is set, starting deploy..." 
+ ) + runtime.deploy(skip_deployed=True, show_on_failure=True) + + else: + raise mlrun.errors.MLRunRuntimeError( + "function image is not built/ready, set auto_build=True or use .deploy() method first" + ) + + if runtime.verbose: + logger.info(f"runspec:\n{run.to_yaml()}") + + if "V3IO_USERNAME" in os.environ and "v3io_user" not in run.metadata.labels: + run.metadata.labels["v3io_user"] = os.environ.get("V3IO_USERNAME") + + logger.info( + "Storing function", + name=run.metadata.name, + uid=run.metadata.uid, + db=runtime.spec.rundb, + ) + self._store_function(runtime, run) + + return self.submit_job(runtime, run, schedule, watch) + + def submit_job( + self, + runtime: "mlrun.runtimes.KubejobRuntime", + run: "mlrun.run.RunObject", + schedule: Optional[mlrun.common.schemas.ScheduleCronTrigger] = None, + watch: Optional[bool] = None, + ): + if runtime._secrets: + run.spec.secret_sources = runtime._secrets.to_serial() + try: + db = runtime._get_db() + resp = db.submit_job(run, schedule=schedule) + if schedule: + action = resp.pop("action", "created") + logger.info(f"task schedule {action}", **resp) + return + + except (requests.HTTPError, Exception) as err: + logger.error(f"got remote run err, {mlrun.errors.err_to_str(err)}") + + if isinstance(err, requests.HTTPError): + runtime._handle_submit_job_http_error(err) + + result = None + # if we got a schedule no reason to do post_run stuff (it purposed to update the run status with error, + # but there's no run in case of schedule) + if not schedule: + result = runtime._update_run_state( + task=run, err=mlrun.errors.err_to_str(err) + ) + return self._wrap_run_result( + runtime, result, run, schedule=schedule, err=err + ) + + if resp: + txt = mlrun.runtimes.utils.helpers.get_in(resp, "status.status_text") + if txt: + logger.info(txt) + + # watch is None only in scenario where we run from pipeline step, in this case we don't want to watch the run + # logs too frequently but rather just pull the state of the run from the DB 
and pull the logs every x seconds + # which ideally greater than the pull state interval, this reduces unnecessary load on the API server, as + # running a pipeline is mostly not an interactive process which means the logs pulling doesn't need to be pulled + # in real time + if ( + watch is None + and runtime.kfp + and mlrun.mlconf.httpdb.logs.pipelines.pull_state.mode == "enabled" + ): + state_interval = int( + mlrun.mlconf.httpdb.logs.pipelines.pull_state.pull_state_interval + ) + logs_interval = int( + mlrun.mlconf.httpdb.logs.pipelines.pull_state.pull_logs_interval + ) + run.wait_for_completion( + show_logs=True, + sleep=state_interval, + logs_interval=logs_interval, + raise_on_failure=False, + ) + resp = runtime._get_db_run(run) + + elif watch or runtime.kfp: + run.logs(True, runtime._get_db()) + resp = runtime._get_db_run(run) + + return self._wrap_run_result(runtime, resp, run, schedule=schedule) diff --git a/mlrun/model.py b/mlrun/model.py index 3ec204b20b97..390142d41459 100644 --- a/mlrun/model.py +++ b/mlrun/model.py @@ -16,13 +16,18 @@ import pathlib import re import time +import typing +import warnings from collections import OrderedDict from copy import deepcopy from datetime import datetime from os import environ from typing import Any, Dict, List, Optional, Tuple, Union +import pydantic.error_wrappers + import mlrun +import mlrun.common.schemas.notification from .utils import ( dict_to_json, @@ -338,6 +343,7 @@ def __init__( origin_filename=None, with_mlrun=None, auto_build=None, + requirements: list = None, ): self.functionSourceCode = functionSourceCode #: functionSourceCode self.codeEntryType = "" #: codeEntryType @@ -355,6 +361,7 @@ def __init__( self.with_mlrun = with_mlrun #: with_mlrun self.auto_build = auto_build #: auto_build self.build_pod = None + self.requirements = requirements or [] #: pip requirements @property def source(self): @@ -369,12 +376,153 @@ def source(self, source): or source in [".", "./"] ): raise 
mlrun.errors.MLRunInvalidArgumentError( - "source must be a compressed (tar.gz / zip) file, a git repo, " - "a file path or in the project's context (.)" + f"source ({source}) must be a compressed (tar.gz / zip) file, a git repo, " + f"a file path or in the project's context (.)" ) self._source = source + def build_config( + self, + image="", + base_image=None, + commands: list = None, + secret=None, + source=None, + extra=None, + load_source_on_run=None, + with_mlrun=None, + auto_build=None, + requirements=None, + requirements_file=None, + overwrite=False, + ): + if image: + self.image = image + if base_image: + self.base_image = base_image + if commands: + self.with_commands(commands, overwrite=overwrite) + if requirements: + self.with_requirements(requirements, requirements_file, overwrite=overwrite) + if extra: + self.extra = extra + if secret is not None: + self.secret = secret + if source: + self.source = source + if load_source_on_run: + self.load_source_on_run = load_source_on_run + if with_mlrun is not None: + self.with_mlrun = with_mlrun + if auto_build: + self.auto_build = auto_build + + def with_commands( + self, + commands: List[str], + overwrite: bool = False, + ): + """add commands to build spec. 
+ + :param commands: list of commands to run during build + :param overwrite: whether to overwrite the existing commands or add to them (the default) + + :return: function object + """ + if not isinstance(commands, list) or not all( + isinstance(item, str) for item in commands + ): + raise ValueError("commands must be a string list") + if not self.commands or overwrite: + self.commands = commands + else: + # add commands to existing build commands + for command in commands: + if command not in self.commands: + self.commands.append(command) + # using list(set(x)) won't retain order, + # solution inspired from https://stackoverflow.com/a/17016257/8116661 + self.commands = list(dict.fromkeys(self.commands)) + + def with_requirements( + self, + requirements: Union[str, List[str]], + requirements_file: str = "", + overwrite: bool = False, + ): + """add package requirements from file or list to build spec. + + :param requirements: a list of python packages + :param requirements_file: path to a python requirements file + :param overwrite: overwrite existing requirements, + when False (default) will append to existing requirements + :return: function object + """ + if isinstance(requirements, str) and mlrun.utils.is_file_path(requirements): + # TODO: remove in 1.6.0 + warnings.warn( + "Passing a requirements file path as a string in the 'requirements' argument is deprecated " + "and will be removed in 1.6.0, use 'requirements_file' instead", + FutureWarning, + ) + + resolved_requirements = self._resolve_requirements( + requirements, requirements_file + ) + requirements = self.requirements or [] if not overwrite else [] + + # make sure we don't append the same line twice + for requirement in resolved_requirements: + if requirement not in requirements: + requirements.append(requirement) + + self.requirements = requirements + + @staticmethod + def _resolve_requirements( + requirements: typing.Union[str, list], requirements_file: str = "" + ) -> list: + requirements_to_resolve 
= [] + + # handle the requirements_file argument + if requirements_file: + with open(requirements_file, "r") as fp: + requirements_to_resolve.extend(fp.read().splitlines()) + + # handle the requirements argument + # TODO: remove in 1.6.0, when requirements can only be a list + if isinstance(requirements, str): + # if it's a file path, read the file and add its content to the list + if mlrun.utils.is_file_path(requirements): + with open(requirements, "r") as fp: + requirements_to_resolve.extend(fp.read().splitlines()) + else: + # it's a string but not a file path, split it by lines and add it to the list + requirements_to_resolve.append(requirements) + else: + # it's a list, add it to the list + requirements_to_resolve.extend(requirements) + + requirements = [] + for requirement in requirements_to_resolve: + # clean redundant leading and trailing whitespaces + requirement = requirement.strip() + + # ignore empty lines + # ignore comments + if not requirement or requirement.startswith("#"): + continue + + # ignore inline comments as well + inline_comment = requirement.split(" #") + if len(inline_comment) > 1: + requirement = inline_comment[0].strip() + + requirements.append(requirement) + + return requirements + class Notification(ModelObj): """Notification specification""" @@ -401,6 +549,25 @@ def __init__( self.status = status self.sent_time = sent_time + self.validate_notification() + + def validate_notification(self): + try: + mlrun.common.schemas.notification.Notification(**self.to_dict()) + except pydantic.error_wrappers.ValidationError as exc: + raise mlrun.errors.MLRunInvalidArgumentError( + "Invalid notification object" + ) from exc + + @staticmethod + def validate_notification_uniqueness(notifications: List["Notification"]): + """Validate that all notifications in the list are unique by name""" + names = [notification.name for notification in notifications] + if len(names) != len(set(names)): + raise mlrun.errors.MLRunInvalidArgumentError( + "Notification 
names must be unique" + ) + class RunMetadata(ModelObj): """Run metadata""" @@ -636,6 +803,9 @@ def returns(self, returns: List[Union[str, Dict[str, str]]]): :raise MLRunInvalidArgumentError: In case one of the values in the list is invalid. """ + # This import is located in the method due to circular imports error. + from mlrun.package.utils import LogHintUtils + if returns is None: self._returns = None return @@ -643,7 +813,7 @@ def returns(self, returns: List[Union[str, Dict[str, str]]]): # Validate: for log_hint in returns: - mlrun.run._parse_log_hint(log_hint=log_hint) + LogHintUtils.parse_log_hint(log_hint=log_hint) # Store the results: self._returns = returns @@ -831,6 +1001,7 @@ def __init__( iterations=None, ui_url=None, reason: str = None, + notifications: Dict[str, Notification] = None, ): self.state = state or "created" self.status_text = status_text @@ -844,6 +1015,7 @@ def __init__( self.iterations = iterations self.ui_url = ui_url self.reason = reason + self.notifications = notifications or {} class RunTemplate(ModelObj): @@ -1014,6 +1186,20 @@ def status(self) -> RunStatus: def status(self, status): self._status = self._verify_dict(status, "status", RunStatus) + @property + def error(self) -> str: + """error string if failed""" + if self.status: + if self.status.state != "error": + return f"Run state ({self.status.state}) is not in error state" + return ( + self.status.error + or self.status.reason + or self.status.status_text + or "Unknown error" + ) + return "" + def output(self, key): """return the value of a specific result or artifact by key""" self._outputs_wait_for_completion() @@ -1185,7 +1371,6 @@ def wait_for_completion( ) if logs_enabled and not logs_interval: self.logs(watch=False) - if raise_on_failure and state != mlrun.runtimes.constants.RunStates.completed: raise mlrun.errors.MLRunRuntimeError( f"task {self.metadata.name} did not complete (state={state})" diff --git a/mlrun/model_monitoring/__init__.py 
b/mlrun/model_monitoring/__init__.py new file mode 100644 index 000000000000..8a3c19723b7b --- /dev/null +++ b/mlrun/model_monitoring/__init__.py @@ -0,0 +1,41 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx +# for backwards compatibility + +__all__ = [ + "ModelEndpoint", + "EventFieldType", + "EventLiveStats", + "EventKeyMetrics", + "TimeSeriesTarget", + "ModelEndpointTarget", + "FileTargetKind", + "ProjectSecretKeys", + "ModelMonitoringStoreKinds", +] + +from mlrun.common.model_monitoring import ( + EventFieldType, + EventKeyMetrics, + EventLiveStats, + FileTargetKind, + ModelEndpointTarget, + ModelMonitoringStoreKinds, + ProjectSecretKeys, + TimeSeriesTarget, +) + +from .model_endpoint import ModelEndpoint diff --git a/mlrun/model_monitoring/constants.py b/mlrun/model_monitoring/constants.py deleted file mode 100644 index bf3c616f36d9..000000000000 --- a/mlrun/model_monitoring/constants.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2018 Iguazio -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -class EventFieldType: - FUNCTION_URI = "function_uri" - MODEL = "model" - VERSION = "version" - VERSIONED_MODEL = "versioned_model" - MODEL_CLASS = "model_class" - TIMESTAMP = "timestamp" - ENDPOINT_ID = "endpoint_id" - REQUEST_ID = "request_id" - RECORD_TYPE = "record_type" - FEATURES = "features" - FEATURE_NAMES = "feature_names" - NAMED_FEATURES = "named_features" - LABELS = "labels" - LATENCY = "latency" - UNPACKED_LABELS = "unpacked_labels" - LABEL_COLUMNS = "label_columns" - LABEL_NAMES = "label_names" - PREDICTION = "prediction" - PREDICTIONS = "predictions" - NAMED_PREDICTIONS = "named_predictions" - ERROR_COUNT = "error_count" - ENTITIES = "entities" - FIRST_REQUEST = "first_request" - LAST_REQUEST = "last_request" - METRICS = "metrics" - BATCH_TIMESTAMP = "batch_timestamp" - TIME_FORMAT = "%Y-%m-%d %H:%M:%S.%f" - BATCH_INTERVALS_DICT = "batch_intervals_dict" - DEFAULT_BATCH_INTERVALS = "default_batch_intervals" - DEFAULT_BATCH_IMAGE = "default_batch_image" - STREAM_IMAGE = "stream_image" - MINUTES = "minutes" - HOURS = "hours" - DAYS = "days" - - -class EventLiveStats: - LATENCY_AVG_5M = "latency_avg_5m" - LATENCY_AVG_1H = "latency_avg_1h" - PREDICTIONS_PER_SECOND = "predictions_per_second" - PREDICTIONS_COUNT_5M = "predictions_count_5m" - PREDICTIONS_COUNT_1H = "predictions_count_1h" - - -class EventKeyMetrics: - BASE_METRICS = "base_metrics" - CUSTOM_METRICS = "custom_metrics" - ENDPOINT_FEATURES = "endpoint_features" - - -class StoreTarget: - TSDB = "tsdb" diff --git a/mlrun/model_monitoring/helpers.py b/mlrun/model_monitoring/helpers.py 
index d35ecd28ebed..2c1bde2235a7 100644 --- a/mlrun/model_monitoring/helpers.py +++ b/mlrun/model_monitoring/helpers.py @@ -13,18 +13,25 @@ # limitations under the License. # import pathlib +import typing import sqlalchemy.orm +from fastapi import Depends import mlrun import mlrun.api.api.utils import mlrun.api.crud.secrets import mlrun.api.utils.singletons.db +import mlrun.api.utils.singletons.k8s +import mlrun.common.model_monitoring as model_monitoring_constants +import mlrun.common.schemas import mlrun.config import mlrun.feature_store as fstore import mlrun.model_monitoring.stream_processing_fs import mlrun.runtimes import mlrun.utils.helpers +import mlrun.utils.model_monitoring +from mlrun.api.api import deps _CURRENT_FILE_PATH = pathlib.Path(__file__) _STREAM_PROCESSING_FUNCTION_PATH = _CURRENT_FILE_PATH.parent / "stream_processing_fs.py" @@ -36,16 +43,20 @@ def initial_model_monitoring_stream_processing_function( project: str, model_monitoring_access_key: str, - db_session: sqlalchemy.orm.Session, tracking_policy: mlrun.utils.model_monitoring.TrackingPolicy, + auth_info: mlrun.common.schemas.AuthInfo, + parquet_target: str, ): """ Initialize model monitoring stream processing function. - :param project: project name. - :param model_monitoring_access_key: access key to apply the model monitoring process. - :param db_session: A session that manages the current dialog with the database. + :param project: Project name. + :param model_monitoring_access_key: Access key to apply the model monitoring process. Please note that in CE + deployments this parameter will be None. :param tracking_policy: Model monitoring configurations. + :param auth_info: The auth info of the request. + :parquet_target: Path to model monitoring parquet file that will be generated by the monitoring + stream nuclio function. 
:return: A function object from a mlrun runtime class @@ -54,12 +65,11 @@ def initial_model_monitoring_stream_processing_function( # Initialize Stream Processor object stream_processor = mlrun.model_monitoring.stream_processing_fs.EventStreamProcessor( project=project, - model_monitoring_access_key=model_monitoring_access_key, parquet_batching_max_events=mlrun.mlconf.model_endpoint_monitoring.parquet_batching_max_events, + parquet_target=parquet_target, + model_monitoring_access_key=model_monitoring_access_key, ) - http_source = mlrun.datastore.sources.HttpSource() - # Create a new serving function for the streaming process function = mlrun.code_to_function( name="model-monitoring-stream", @@ -75,46 +85,34 @@ def initial_model_monitoring_stream_processing_function( # Set the project to the serving function function.metadata.project = project - # Add v3io stream trigger - stream_path = mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default.format( - project=project, kind="stream" - ) - function.add_v3io_stream_trigger( - stream_path=stream_path, name="monitoring_stream_trigger" - ) - - # Set model monitoring access key for managing permissions - function.set_env_from_secret( - "MODEL_MONITORING_ACCESS_KEY", - mlrun.api.utils.singletons.k8s.get_k8s().get_project_secret_name(project), - mlrun.api.crud.secrets.Secrets().generate_client_project_secret_key( - mlrun.api.crud.secrets.SecretsClientType.model_monitoring, - "MODEL_MONITORING_ACCESS_KEY", - ), + # Add stream triggers + function = _apply_stream_trigger( + project=project, + function=function, + model_monitoring_access_key=model_monitoring_access_key, + auth_info=auth_info, ) + # Apply feature store run configurations on the serving function run_config = fstore.RunConfig(function=function, local=False) function.spec.parameters = run_config.parameters - func = http_source.add_nuclio_trigger(function) - func.metadata.credentials.access_key = model_monitoring_access_key - func.apply(mlrun.v3io_cred()) - - 
return func + return function def get_model_monitoring_batch_function( project: str, model_monitoring_access_key: str, db_session: sqlalchemy.orm.Session, - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, tracking_policy: mlrun.utils.model_monitoring.TrackingPolicy, ): """ Initialize model monitoring batch function. :param project: project name. - :param model_monitoring_access_key: access key to apply the model monitoring process. + :param model_monitoring_access_key: access key to apply the model monitoring process. Please note that in CE + deployments this parameter will be None. :param db_session: A session that manages the current dialog with the database. :param auth_info: The auth info of the request. :param tracking_policy: Model monitoring configurations. @@ -137,20 +135,107 @@ def get_model_monitoring_batch_function( # Set the project to the job function function.metadata.project = project + if not mlrun.mlconf.is_ce_mode(): + function = _apply_access_key_and_mount_function( + project=project, + function=function, + model_monitoring_access_key=model_monitoring_access_key, + auth_info=auth_info, + ) + + # Enrich runtime with the required configurations + mlrun.api.api.utils.apply_enrichment_and_validation_on_function(function, auth_info) + + return function + + +def _apply_stream_trigger( + project: str, + function: mlrun.runtimes.ServingRuntime, + model_monitoring_access_key: str = None, + auth_info: mlrun.common.schemas.AuthInfo = Depends(deps.authenticate_request), +) -> mlrun.runtimes.ServingRuntime: + """Adding stream source for the nuclio serving function. By default, the function has HTTP stream trigger along + with another supported stream source that can be either Kafka or V3IO, depends on the stream path schema that is + defined under mlrun.mlconf.model_endpoint_monitoring.store_prefixes. Note that if no valid stream path has been + provided then the function will have a single HTTP stream source. 
+ + :param project: Project name. + :param function: The serving function object that will be applied with the stream trigger. + :param model_monitoring_access_key: Access key to apply the model monitoring stream function when the stream is + schema is V3IO. + :param auth_info: The auth info of the request. + + :return: ServingRuntime object with stream trigger. + """ + + # Get the stream path from the configuration + # stream_path = mlrun.mlconf.get_file_target_path(project=project, kind="stream", target="stream") + stream_path = mlrun.utils.model_monitoring.get_stream_path(project=project) + + if stream_path.startswith("kafka://"): + + topic, brokers = mlrun.datastore.utils.parse_kafka_url(url=stream_path) + # Generate Kafka stream source + stream_source = mlrun.datastore.sources.KafkaSource( + brokers=brokers, + topics=[topic], + ) + function = stream_source.add_nuclio_trigger(function) + + if not mlrun.mlconf.is_ce_mode(): + function = _apply_access_key_and_mount_function( + project=project, + function=function, + model_monitoring_access_key=model_monitoring_access_key, + auth_info=auth_info, + ) + if stream_path.startswith("v3io://"): + # Generate V3IO stream trigger + function.add_v3io_stream_trigger( + stream_path=stream_path, name="monitoring_stream_trigger" + ) + # Add the default HTTP source + http_source = mlrun.datastore.sources.HttpSource() + function = http_source.add_nuclio_trigger(function) + + return function + + +def _apply_access_key_and_mount_function( + project: str, + function: typing.Union[ + mlrun.runtimes.KubejobRuntime, mlrun.runtimes.ServingRuntime + ], + model_monitoring_access_key: str, + auth_info: mlrun.common.schemas.AuthInfo, +) -> typing.Union[mlrun.runtimes.KubejobRuntime, mlrun.runtimes.ServingRuntime]: + """Applying model monitoring access key on the provided function when using V3IO path. In addition, this method + mount the V3IO path for the provided function to configure the access to the system files. 
+ + :param project: Project name. + :param function: Model monitoring function object that will be filled with the access key and + the access to the system files. + :param model_monitoring_access_key: Access key to apply the model monitoring stream function when the stream is + schema is V3IO. + :param auth_info: The auth info of the request. + + :return: function runtime object with access key and access to system files. + """ + # Set model monitoring access key for managing permissions function.set_env_from_secret( - "MODEL_MONITORING_ACCESS_KEY", - mlrun.api.utils.singletons.k8s.get_k8s().get_project_secret_name(project), + model_monitoring_constants.ProjectSecretKeys.ACCESS_KEY, + mlrun.api.utils.singletons.k8s.get_k8s_helper().get_project_secret_name( + project + ), mlrun.api.crud.secrets.Secrets().generate_client_project_secret_key( mlrun.api.crud.secrets.SecretsClientType.model_monitoring, - "MODEL_MONITORING_ACCESS_KEY", + model_monitoring_constants.ProjectSecretKeys.ACCESS_KEY, ), ) - - function.apply(mlrun.mount_v3io()) - - # Needs to be a member of the project and have access to project data path function.metadata.credentials.access_key = model_monitoring_access_key + function.apply(mlrun.mount_v3io()) # Ensure that the auth env vars are set mlrun.api.api.utils.ensure_function_has_auth_set(function, auth_info) diff --git a/mlrun/model_monitoring/model_endpoint.py b/mlrun/model_monitoring/model_endpoint.py new file mode 100644 index 000000000000..991158e24d44 --- /dev/null +++ b/mlrun/model_monitoring/model_endpoint.py @@ -0,0 +1,144 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Any, Dict, List, Optional + +import mlrun.model +from mlrun.common.model_monitoring import ( + EndpointType, + EventKeyMetrics, + EventLiveStats, + ModelMonitoringMode, +) + + +class ModelEndpointSpec(mlrun.model.ModelObj): + def __init__( + self, + function_uri: Optional[str] = "", + model: Optional[str] = "", + model_class: Optional[str] = "", + model_uri: Optional[str] = "", + feature_names: Optional[List[str]] = None, + label_names: Optional[List[str]] = None, + stream_path: Optional[str] = "", + algorithm: Optional[str] = "", + monitor_configuration: Optional[dict] = None, + active: Optional[bool] = True, + monitoring_mode: Optional[ModelMonitoringMode] = ModelMonitoringMode.disabled, + ): + self.function_uri = function_uri # /: + self.model = model # : + self.model_class = model_class + self.model_uri = model_uri + self.feature_names = feature_names or [] + self.label_names = label_names or [] + self.stream_path = stream_path + self.algorithm = algorithm + self.monitor_configuration = monitor_configuration or {} + self.active = active + self.monitoring_mode = monitoring_mode + + +class ModelEndpointStatus(mlrun.model.ModelObj): + def __init__( + self, + feature_stats: Optional[dict] = None, + current_stats: Optional[dict] = None, + first_request: Optional[str] = "", + last_request: Optional[str] = "", + error_count: Optional[int] = 0, + drift_status: Optional[str] = "", + drift_measures: Optional[dict] = None, + metrics: Optional[Dict[str, Dict[str, Any]]] = None, + features: Optional[List[Dict[str, Any]]] = None, + children: 
Optional[List[str]] = None, + children_uids: Optional[List[str]] = None, + endpoint_type: Optional[EndpointType] = EndpointType.NODE_EP.value, + monitoring_feature_set_uri: Optional[str] = "", + state: Optional[str] = "", + ): + self.feature_stats = feature_stats or {} + self.current_stats = current_stats or {} + self.first_request = first_request + self.last_request = last_request + self.error_count = error_count + self.drift_status = drift_status + self.drift_measures = drift_measures or {} + self.features = features or [] + self.children = children or [] + self.children_uids = children_uids or [] + self.endpoint_type = endpoint_type + self.monitoring_feature_set_uri = monitoring_feature_set_uri + if metrics is None: + self.metrics = { + EventKeyMetrics.GENERIC: { + EventLiveStats.LATENCY_AVG_1H: 0, + EventLiveStats.PREDICTIONS_PER_SECOND: 0, + } + } + self.state = state + + +class ModelEndpoint(mlrun.model.ModelObj): + kind = "model-endpoint" + _dict_fields = ["kind", "metadata", "spec", "status"] + + def __init__(self): + self._status: ModelEndpointStatus = ModelEndpointStatus() + self._spec: ModelEndpointSpec = ModelEndpointSpec() + self._metadata: mlrun.model.VersionedObjMetadata = ( + mlrun.model.VersionedObjMetadata() + ) + + @property + def status(self) -> ModelEndpointStatus: + return self._status + + @status.setter + def status(self, status): + self._status = self._verify_dict(status, "status", ModelEndpointStatus) + + @property + def spec(self) -> ModelEndpointSpec: + return self._spec + + @spec.setter + def spec(self, spec): + self._spec = self._verify_dict(spec, "spec", ModelEndpointSpec) + + @property + def metadata(self) -> mlrun.model.VersionedObjMetadata: + return self._metadata + + @metadata.setter + def metadata(self, metadata): + self._metadata = self._verify_dict( + metadata, "metadata", mlrun.model.VersionedObjMetadata + ) + + @classmethod + def from_flat_dict(cls, struct=None, fields=None, deprecated_fields: dict = None): + new_obj = cls() + 
new_obj._metadata = mlrun.model.VersionedObjMetadata().from_dict( + struct=struct, fields=fields, deprecated_fields=deprecated_fields + ) + new_obj._status = ModelEndpointStatus().from_dict( + struct=struct, fields=fields, deprecated_fields=deprecated_fields + ) + new_obj._spec = ModelEndpointSpec().from_dict( + struct=struct, fields=fields, deprecated_fields=deprecated_fields + ) + return new_obj diff --git a/mlrun/model_monitoring/model_monitoring_batch.py b/mlrun/model_monitoring/model_monitoring_batch.py index b8ad312af65e..9179f05a6809 100644 --- a/mlrun/model_monitoring/model_monitoring_batch.py +++ b/mlrun/model_monitoring/model_monitoring_batch.py @@ -28,14 +28,16 @@ import v3io_frames import mlrun -import mlrun.api.schemas +import mlrun.common.model_monitoring +import mlrun.common.schemas import mlrun.data_types.infer import mlrun.feature_store as fstore +import mlrun.model_monitoring +import mlrun.model_monitoring.stores import mlrun.run import mlrun.utils.helpers import mlrun.utils.model_monitoring import mlrun.utils.v3io_clients -from mlrun.model_monitoring.constants import EventFieldType from mlrun.utils import logger @@ -461,6 +463,7 @@ def calculate_inputs_statistics( :returns: The calculated statistics of the inputs data. """ + # Use `DFDataInfer` to calculate the statistics over the inputs: inputs_statistics = mlrun.data_types.infer.DFDataInfer.get_stats( df=inputs, @@ -493,8 +496,6 @@ def __init__( self, context: mlrun.run.MLClientCtx, project: str, - model_monitoring_access_key: str, - v3io_access_key: str, ): """ @@ -502,60 +503,16 @@ def __init__( :param context: An MLRun context. :param project: Project name. - :param model_monitoring_access_key: Access key to apply the model monitoring process. - :param v3io_access_key: Token key for v3io. 
""" self.context = context self.project = project - self.v3io_access_key = v3io_access_key - self.model_monitoring_access_key = ( - model_monitoring_access_key or v3io_access_key - ) - # Initialize virtual drift object self.virtual_drift = VirtualDrift(inf_capping=10) - # Define the required paths for the project objects. - # Note that the kv table, tsdb, and the input stream paths are located at the default location - # while the parquet path is located at the user-space location - template = mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default - kv_path = template.format(project=self.project, kind="endpoints") - ( - _, - self.kv_container, - self.kv_path, - ) = mlrun.utils.model_monitoring.parse_model_endpoint_store_prefix(kv_path) - tsdb_path = template.format(project=project, kind="events") - ( - _, - self.tsdb_container, - self.tsdb_path, - ) = mlrun.utils.model_monitoring.parse_model_endpoint_store_prefix(tsdb_path) - stream_path = template.format(project=self.project, kind="log_stream") - ( - _, - self.stream_container, - self.stream_path, - ) = mlrun.utils.model_monitoring.parse_model_endpoint_store_prefix(stream_path) - self.parquet_path = ( - mlrun.mlconf.model_endpoint_monitoring.store_prefixes.user_space.format( - project=project, kind="parquet" - ) - ) - logger.info( "Initializing BatchProcessor", project=project, - model_monitoring_access_key_initalized=bool(model_monitoring_access_key), - v3io_access_key_initialized=bool(v3io_access_key), - parquet_path=self.parquet_path, - kv_container=self.kv_container, - kv_path=self.kv_path, - tsdb_container=self.tsdb_container, - tsdb_path=self.tsdb_path, - stream_container=self.stream_container, - stream_path=self.stream_path, ) # Get drift thresholds from the model monitoring configuration @@ -567,46 +524,87 @@ def __init__( ) # Get a runtime database - self.db = mlrun.get_run_db() - # Get the frames clients based on the v3io configuration - # it will be used later for writing the results into the 
tsdb - self.v3io = mlrun.utils.v3io_clients.get_v3io_client( - access_key=self.v3io_access_key - ) - self.frames = mlrun.utils.v3io_clients.get_frames_client( - address=mlrun.mlconf.v3io_framesd, - container=self.tsdb_container, - token=self.v3io_access_key, + self.db = mlrun.model_monitoring.stores.get_model_endpoint_store( + project=project ) + if not mlrun.mlconf.is_ce_mode(): + # TODO: Once there is a time series DB alternative in a non-CE deployment, we need to update this if + # statement to be applied only for V3IO TSDB + self._initialize_v3io_configurations() + # If an error occurs, it will be raised using the following argument self.exception = None # Get the batch interval range - self.batch_dict = context.parameters[EventFieldType.BATCH_INTERVALS_DICT] + self.batch_dict = context.parameters[ + mlrun.common.model_monitoring.EventFieldType.BATCH_INTERVALS_DICT + ] - # TODO: This will be removed in 1.2.0 once the job params can be parsed with different types + # TODO: This will be removed in 1.5.0 once the job params can be parsed with different types # Convert batch dict string into a dictionary if isinstance(self.batch_dict, str): self._parse_batch_dict_str() + def _initialize_v3io_configurations(self): + self.v3io_access_key = os.environ.get("V3IO_ACCESS_KEY") + self.model_monitoring_access_key = ( + os.environ.get("MODEL_MONITORING_ACCESS_KEY") or self.v3io_access_key + ) + + # Define the required paths for the project objects + tsdb_path = mlrun.mlconf.get_model_monitoring_file_target_path( + project=self.project, + kind=mlrun.common.model_monitoring.FileTargetKind.EVENTS, + ) + ( + _, + self.tsdb_container, + self.tsdb_path, + ) = mlrun.utils.model_monitoring.parse_model_endpoint_store_prefix(tsdb_path) + # stream_path = template.format(project=self.project, kind="log_stream") + stream_path = mlrun.mlconf.get_model_monitoring_file_target_path( + project=self.project, + kind=mlrun.common.model_monitoring.FileTargetKind.LOG_STREAM, + ) + ( + _, + 
self.stream_container, + self.stream_path, + ) = mlrun.utils.model_monitoring.parse_model_endpoint_store_prefix(stream_path) + + # Get the frames clients based on the v3io configuration + # it will be used later for writing the results into the tsdb + self.v3io = mlrun.utils.v3io_clients.get_v3io_client( + access_key=self.v3io_access_key + ) + self.frames = mlrun.utils.v3io_clients.get_frames_client( + address=mlrun.mlconf.v3io_framesd, + container=self.tsdb_container, + token=self.v3io_access_key, + ) + def post_init(self): """ Preprocess of the batch processing. """ - # create v3io stream based on the input stream - response = self.v3io.create_stream( - container=self.stream_container, - path=self.stream_path, - shard_count=1, - raise_for_status=v3io.dataplane.RaiseForStatus.never, - access_key=self.v3io_access_key, - ) + if not mlrun.mlconf.is_ce_mode(): + # Create v3io stream based on the input stream + response = self.v3io.create_stream( + container=self.stream_container, + path=self.stream_path, + shard_count=1, + raise_for_status=v3io.dataplane.RaiseForStatus.never, + access_key=self.v3io_access_key, + ) - if not (response.status_code == 400 and "ResourceInUse" in str(response.body)): - response.raise_for_status([409, 204, 403]) + if not ( + response.status_code == 400 and "ResourceInUse" in str(response.body) + ): + response.raise_for_status([409, 204, 403]) + pass def run(self): """ @@ -614,231 +612,218 @@ def run(self): """ # Get model endpoints (each deployed project has at least 1 serving model): try: - endpoints = self.db.list_model_endpoints(self.project) + endpoints = self.db.list_model_endpoints() except Exception as e: logger.error("Failed to list endpoints", exc=e) return - active_endpoints = set() - for endpoint in endpoints.endpoints: + for endpoint in endpoints: if ( - endpoint.spec.active - and endpoint.spec.monitoring_mode - == mlrun.api.schemas.ModelMonitoringMode.enabled.value + endpoint[mlrun.common.model_monitoring.EventFieldType.ACTIVE] 
+ and endpoint[ + mlrun.common.model_monitoring.EventFieldType.MONITORING_MODE + ] + == mlrun.common.model_monitoring.ModelMonitoringMode.enabled.value ): - active_endpoints.add(endpoint.metadata.uid) - - # perform drift analysis for each model endpoint - for endpoint_id in active_endpoints: - try: - - # Get model endpoint object: - endpoint = self.db.get_model_endpoint( - project=self.project, endpoint_id=endpoint_id - ) - # Skip router endpoint: if ( - endpoint.status.endpoint_type - == mlrun.utils.model_monitoring.EndpointType.ROUTER + int( + endpoint[ + mlrun.common.model_monitoring.EventFieldType.ENDPOINT_TYPE + ] + ) + == mlrun.common.model_monitoring.EndpointType.ROUTER ): - # endpoint.status.feature_stats is None - logger.info(f"{endpoint_id} is router skipping") + # Router endpoint has no feature stats + logger.info( + f"{endpoint[mlrun.common.model_monitoring.EventFieldType.UID]} is router skipping" + ) continue + self.update_drift_metrics(endpoint=endpoint) - # convert feature set into dataframe and get the latest dataset - ( - _, - serving_function_name, - _, - _, - ) = mlrun.utils.helpers.parse_versioned_object_uri( - endpoint.spec.function_uri - ) - - model_name = endpoint.spec.model.replace(":", "-") + def update_drift_metrics(self, endpoint: dict): + try: + # Convert feature set into dataframe and get the latest dataset + ( + _, + serving_function_name, + _, + _, + ) = mlrun.utils.helpers.parse_versioned_object_uri( + endpoint[mlrun.common.model_monitoring.EventFieldType.FUNCTION_URI] + ) - m_fs = fstore.get_feature_set( - f"store://feature-sets/{self.project}/monitoring-{serving_function_name}-{model_name}" - ) + model_name = endpoint[ + mlrun.common.model_monitoring.EventFieldType.MODEL + ].replace(":", "-") - # Getting batch interval start time and end time - start_time, end_time = self.get_interval_range() + m_fs = fstore.get_feature_set( + f"store://feature-sets/{self.project}/monitoring-{serving_function_name}-{model_name}" + ) - try: - df = 
m_fs.to_dataframe( - start_time=start_time, - end_time=end_time, - time_column="timestamp", - ) + # Getting batch interval start time and end time + start_time, end_time = self._get_interval_range() - if len(df) == 0: - logger.warn( - "Not enough model events since the beginning of the batch interval", - parquet_target=m_fs.status.targets[0].path, - endpoint=endpoint_id, - min_rqeuired_events=mlrun.mlconf.model_endpoint_monitoring.parquet_batching_max_events, - start_time=str( - datetime.datetime.now() - datetime.timedelta(hours=1) - ), - end_time=str(datetime.datetime.now()), - ) - continue + try: + df = m_fs.to_dataframe( + start_time=start_time, + end_time=end_time, + time_column=mlrun.common.model_monitoring.EventFieldType.TIMESTAMP, + ) - # TODO: The below warn will be removed once the state of the Feature Store target is updated - # as expected. In that case, the existence of the file will be checked before trying to get - # the offline data from the feature set. - # Continue if not enough events provided since the deployment of the model endpoint - except FileNotFoundError: + if len(df) == 0: logger.warn( - "Parquet not found, probably due to not enough model events", + "Not enough model events since the beginning of the batch interval", parquet_target=m_fs.status.targets[0].path, - endpoint=endpoint_id, + endpoint=endpoint[ + mlrun.common.model_monitoring.EventFieldType.UID + ], min_rqeuired_events=mlrun.mlconf.model_endpoint_monitoring.parquet_batching_max_events, + start_time=str( + datetime.datetime.now() - datetime.timedelta(hours=1) + ), + end_time=str(datetime.datetime.now()), ) - continue + return + + # TODO: The below warn will be removed once the state of the Feature Store target is updated + # as expected. In that case, the existence of the file will be checked before trying to get + # the offline data from the feature set. 
+ # Continue if not enough events provided since the deployment of the model endpoint + except FileNotFoundError: + logger.warn( + "Parquet not found, probably due to not enough model events", + parquet_target=m_fs.status.targets[0].path, + endpoint=endpoint[mlrun.common.model_monitoring.EventFieldType.UID], + min_rqeuired_events=mlrun.mlconf.model_endpoint_monitoring.parquet_batching_max_events, + ) + return - # Get feature names from monitoring feature set - feature_names = [ - feature_name["name"] - for feature_name in m_fs.spec.features.to_dict() - ] + # Get feature names from monitoring feature set + feature_names = [ + feature_name["name"] for feature_name in m_fs.spec.features.to_dict() + ] - # Create DataFrame based on the input features - stats_columns = [ - "timestamp", - *feature_names, - ] + # Create DataFrame based on the input features + stats_columns = [ + mlrun.common.model_monitoring.EventFieldType.TIMESTAMP, + *feature_names, + ] - # Add label names if provided - if endpoint.spec.label_names: - stats_columns.extend(endpoint.spec.label_names) + # Add label names if provided + if endpoint[mlrun.common.model_monitoring.EventFieldType.LABEL_NAMES]: + labels = endpoint[ + mlrun.common.model_monitoring.EventFieldType.LABEL_NAMES + ] + if isinstance(labels, str): + labels = json.loads(labels) + stats_columns.extend(labels) + named_features_df = df[stats_columns].copy() + + # Infer feature set stats and schema + fstore.api._infer_from_static_df( + named_features_df, + m_fs, + options=mlrun.data_types.infer.InferOptions.all_stats(), + ) - named_features_df = df[stats_columns].copy() + # Save feature set to apply changes + m_fs.save() - # Infer feature set stats and schema - fstore.api._infer_from_static_df( - named_features_df, - m_fs, - options=mlrun.data_types.infer.InferOptions.all_stats(), - ) + # Get the timestamp of the latest request: + timestamp = df[mlrun.common.model_monitoring.EventFieldType.TIMESTAMP].iloc[ + -1 + ] - # Save feature set to 
apply changes - m_fs.save() + # Get the feature stats from the model endpoint for reference data + feature_stats = json.loads( + endpoint[mlrun.common.model_monitoring.EventFieldType.FEATURE_STATS] + ) - # Get the timestamp of the latest request: - timestamp = df["timestamp"].iloc[-1] + # Get the current stats: + current_stats = calculate_inputs_statistics( + sample_set_statistics=feature_stats, + inputs=named_features_df, + ) - # Get the current stats: - current_stats = calculate_inputs_statistics( - sample_set_statistics=endpoint.status.feature_stats, - inputs=named_features_df, + # Compute the drift based on the histogram of the current stats and the histogram of the original + # feature stats that can be found in the model endpoint object: + drift_result = self.virtual_drift.compute_drift_from_histograms( + feature_stats=feature_stats, + current_stats=current_stats, + ) + logger.info("Drift result", drift_result=drift_result) + + # Get drift thresholds from the model configuration: + monitor_configuration = ( + json.loads( + endpoint[ + mlrun.common.model_monitoring.EventFieldType.MONITOR_CONFIGURATION + ] ) + or {} + ) + possible_drift = monitor_configuration.get( + "possible_drift", self.default_possible_drift_threshold + ) + drift_detected = monitor_configuration.get( + "drift_detected", self.default_drift_detected_threshold + ) - # Compute the drift based on the histogram of the current stats and the histogram of the original - # feature stats that can be found in the model endpoint object: - drift_result = self.virtual_drift.compute_drift_from_histograms( - feature_stats=endpoint.status.feature_stats, - current_stats=current_stats, - ) - logger.info("Drift result", drift_result=drift_result) + # Check for possible drift based on the results of the statistical metrics defined above: + drift_status, drift_measure = self.virtual_drift.check_for_drift( + metrics_results_dictionary=drift_result, + possible_drift_threshold=possible_drift, + 
drift_detected_threshold=drift_detected, + ) + logger.info( + "Drift status", + endpoint_id=endpoint[mlrun.common.model_monitoring.EventFieldType.UID], + drift_status=drift_status.value, + drift_measure=drift_measure, + ) - # Get drift thresholds from the model configuration: - monitor_configuration = endpoint.spec.monitor_configuration or {} - possible_drift = monitor_configuration.get( - "possible_drift", self.default_possible_drift_threshold - ) - drift_detected = monitor_configuration.get( - "drift_detected", self.default_drift_detected_threshold - ) + attributes = { + "current_stats": json.dumps(current_stats), + "drift_measures": json.dumps(drift_result), + "drift_status": drift_status.value, + } - # Check for possible drift based on the results of the statistical metrics defined above: - drift_status, drift_measure = self.virtual_drift.check_for_drift( - metrics_results_dictionary=drift_result, - possible_drift_threshold=possible_drift, - drift_detected_threshold=drift_detected, - ) - logger.info( - "Drift status", - endpoint_id=endpoint_id, - drift_status=drift_status.value, + self.db.update_model_endpoint( + endpoint_id=endpoint[mlrun.common.model_monitoring.EventFieldType.UID], + attributes=attributes, + ) + + if not mlrun.mlconf.is_ce_mode(): + # Update drift results in TSDB + self._update_drift_in_input_stream( + endpoint_id=endpoint[ + mlrun.common.model_monitoring.EventFieldType.UID + ], + drift_status=drift_status, drift_measure=drift_measure, + drift_result=drift_result, + timestamp=timestamp, ) - - # If drift was detected, add the results to the input stream - if ( - drift_status == DriftStatus.POSSIBLE_DRIFT - or drift_status == DriftStatus.DRIFT_DETECTED - ): - self.v3io.stream.put_records( - container=self.stream_container, - stream_path=self.stream_path, - records=[ - { - "data": json.dumps( - { - "endpoint_id": endpoint_id, - "drift_status": drift_status.value, - "drift_measure": drift_measure, - "drift_per_feature": {**drift_result}, - } - ) 
- } - ], - ) - - attributes = { - "current_stats": json.dumps(current_stats), - "drift_measures": json.dumps(drift_result), - "drift_status": drift_status.value, - } - - self.db.patch_model_endpoint( - project=self.project, - endpoint_id=endpoint_id, - attributes=attributes, + logger.info( + "Done updating drift measures", + endpoint_id=endpoint[ + mlrun.common.model_monitoring.EventFieldType.UID + ], ) - # Update the results in tsdb: - tsdb_drift_measures = { - "endpoint_id": endpoint_id, - "timestamp": pd.to_datetime( - timestamp, - format=EventFieldType.TIME_FORMAT, - ), - "record_type": "drift_measures", - "tvd_mean": drift_result["tvd_mean"], - "kld_mean": drift_result["kld_mean"], - "hellinger_mean": drift_result["hellinger_mean"], - } - - try: - self.frames.write( - backend="tsdb", - table=self.tsdb_path, - dfs=pd.DataFrame.from_dict([tsdb_drift_measures]), - index_cols=["timestamp", "endpoint_id", "record_type"], - ) - except v3io_frames.errors.Error as err: - logger.warn( - "Could not write drift measures to TSDB", - err=err, - tsdb_path=self.tsdb_path, - endpoint=endpoint_id, - ) - - logger.info("Done updating drift measures", endpoint_id=endpoint_id) - - except Exception as e: - logger.error(f"Exception for endpoint {endpoint_id}") - self.exception = e + except Exception as e: + logger.error( + f"Exception for endpoint {endpoint[mlrun.common.model_monitoring.EventFieldType.UID]}" + ) + self.exception = e - def get_interval_range(self) -> Tuple[datetime.datetime, datetime.datetime]: + def _get_interval_range(self) -> Tuple[datetime.datetime, datetime.datetime]: """Getting batch interval time range""" minutes, hours, days = ( - self.batch_dict[EventFieldType.MINUTES], - self.batch_dict[EventFieldType.HOURS], - self.batch_dict[EventFieldType.DAYS], + self.batch_dict[mlrun.common.model_monitoring.EventFieldType.MINUTES], + self.batch_dict[mlrun.common.model_monitoring.EventFieldType.HOURS], + 
self.batch_dict[mlrun.common.model_monitoring.EventFieldType.DAYS], ) start_time = datetime.datetime.now() - datetime.timedelta( minutes=minutes, hours=hours, days=days @@ -858,13 +843,79 @@ def _parse_batch_dict_str(self): pair_list = pair.split(":") self.batch_dict[pair_list[0]] = float(pair_list[1]) + def _update_drift_in_input_stream( + self, + endpoint_id: str, + drift_status: DriftStatus, + drift_measure: float, + drift_result: Dict[str, Dict[str, Any]], + timestamp: pd._libs.tslibs.timestamps.Timestamp, + ): + """Update drift results in input stream. + + :param endpoint_id: The unique id of the model endpoint. + :param drift_status: Drift status result. Possible values can be found under DriftStatus enum class. + :param drift_measure: The drift result (float) based on the mean of the Total Variance Distance and the + Hellinger distance. + :param drift_result: A dictionary that includes the drift results for each feature. + :param timestamp: Pandas Timestamp value. + + """ + + if ( + drift_status == DriftStatus.POSSIBLE_DRIFT + or drift_status == DriftStatus.DRIFT_DETECTED + ): + self.v3io.stream.put_records( + container=self.stream_container, + stream_path=self.stream_path, + records=[ + { + "data": json.dumps( + { + "endpoint_id": endpoint_id, + "drift_status": drift_status.value, + "drift_measure": drift_measure, + "drift_per_feature": {**drift_result}, + } + ) + } + ], + ) + + # Update the results in tsdb: + tsdb_drift_measures = { + "endpoint_id": endpoint_id, + "timestamp": pd.to_datetime( + timestamp, + format=mlrun.common.model_monitoring.EventFieldType.TIME_FORMAT, + ), + "record_type": "drift_measures", + "tvd_mean": drift_result["tvd_mean"], + "kld_mean": drift_result["kld_mean"], + "hellinger_mean": drift_result["hellinger_mean"], + } + + try: + self.frames.write( + backend="tsdb", + table=self.tsdb_path, + dfs=pd.DataFrame.from_dict([tsdb_drift_measures]), + index_cols=["timestamp", "endpoint_id", "record_type"], + ) + except 
v3io_frames.errors.Error as err: + logger.warn( + "Could not write drift measures to TSDB", + err=err, + tsdb_path=self.tsdb_path, + endpoint=endpoint_id, + ) + def handler(context: mlrun.run.MLClientCtx): batch_processor = BatchProcessor( context=context, project=context.project, - model_monitoring_access_key=os.environ.get("MODEL_MONITORING_ACCESS_KEY"), - v3io_access_key=os.environ.get("V3IO_ACCESS_KEY"), ) batch_processor.post_init() batch_processor.run() diff --git a/mlrun/model_monitoring/stores/__init__.py b/mlrun/model_monitoring/stores/__init__.py new file mode 100644 index 000000000000..b36430ebc676 --- /dev/null +++ b/mlrun/model_monitoring/stores/__init__.py @@ -0,0 +1,106 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx + +import enum +import typing + +import mlrun + +from .model_endpoint_store import ModelEndpointStore + + +class ModelEndpointStoreType(enum.Enum): + """Enum class to handle the different store type values for saving a model endpoint record.""" + + v3io_nosql = "v3io-nosql" + SQL = "sql" + + def to_endpoint_store( + self, + project: str, + access_key: str = None, + endpoint_store_connection: str = None, + ) -> ModelEndpointStore: + """ + Return a ModelEndpointStore object based on the provided enum value. + + :param project: The name of the project. 
+ :param access_key: Access key with permission to the DB table. Note that if access key is None + and the endpoint target is from type KV then the access key will be + retrieved from the environment variable. + :param endpoint_store_connection: A valid connection string for model endpoint target. Contains several + key-value pairs that required for the database connection. + e.g. A root user with password 1234, tries to connect a schema called + mlrun within a local MySQL DB instance: + 'mysql+pymysql://root:1234@localhost:3306/mlrun'. + + :return: `ModelEndpointStore` object. + + """ + + if self.value == ModelEndpointStoreType.v3io_nosql.value: + + from .kv_model_endpoint_store import KVModelEndpointStore + + # Get V3IO access key from env + access_key = access_key or mlrun.mlconf.get_v3io_access_key() + + return KVModelEndpointStore(project=project, access_key=access_key) + + # Assuming SQL store target if store type is not KV. + # Update these lines once there are more than two store target types. + from mlrun.utils.model_monitoring import get_connection_string + + sql_connection_string = endpoint_store_connection or get_connection_string( + project=project + ) + from .sql_model_endpoint_store import SQLModelEndpointStore + + return SQLModelEndpointStore( + project=project, sql_connection_string=sql_connection_string + ) + + @classmethod + def _missing_(cls, value: typing.Any): + """A lookup function to handle an invalid value. + :param value: Provided enum (invalid) value. + """ + valid_values = list(cls.__members__.keys()) + raise mlrun.errors.MLRunInvalidArgumentError( + f"{value} is not a valid endpoint store, please choose a valid value: %{valid_values}." + ) + + +def get_model_endpoint_store( + project: str, access_key: str = None +) -> ModelEndpointStore: + """ + Getting the DB target type based on mlrun.config.model_endpoint_monitoring.store_type. + + :param project: The name of the project. 
+ :param access_key: Access key with permission to the DB table. + + :return: `ModelEndpointStore` object. Using this object, the user can apply different operations on the + model endpoint record such as write, update, get and delete. + """ + + # Get store type value from ModelEndpointStoreType enum class + model_endpoint_store_type = ModelEndpointStoreType( + mlrun.mlconf.model_endpoint_monitoring.store_type + ) + + # Convert into model endpoint store target object + return model_endpoint_store_type.to_endpoint_store(project, access_key) diff --git a/mlrun/model_monitoring/stores/kv_model_endpoint_store.py b/mlrun/model_monitoring/stores/kv_model_endpoint_store.py new file mode 100644 index 000000000000..3c9f8c14549f --- /dev/null +++ b/mlrun/model_monitoring/stores/kv_model_endpoint_store.py @@ -0,0 +1,448 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import typing + +import v3io.dataplane +import v3io_frames + +import mlrun +import mlrun.common.model_monitoring as model_monitoring_constants +import mlrun.utils.model_monitoring +import mlrun.utils.v3io_clients +from mlrun.utils import logger + +from .model_endpoint_store import ModelEndpointStore + + +class KVModelEndpointStore(ModelEndpointStore): + """ + Handles the DB operations when the DB target is from type KV. 
For the KV operations, we use an instance of V3IO + client and usually the KV table can be found under v3io:///users/pipelines/project-name/model-endpoints/endpoints/. + """ + + def __init__(self, project: str, access_key: str): + super().__init__(project=project) + # Initialize a V3IO client instance + self.access_key = access_key or os.environ.get("V3IO_ACCESS_KEY") + self.client = mlrun.utils.v3io_clients.get_v3io_client( + endpoint=mlrun.mlconf.v3io_api, access_key=self.access_key + ) + # Get the KV table path and container + self.path, self.container = self._get_path_and_container() + + def write_model_endpoint(self, endpoint: typing.Dict[str, typing.Any]): + """ + Create a new endpoint record in the KV table. + + :param endpoint: model endpoint dictionary that will be written into the DB. + """ + + self.client.kv.put( + container=self.container, + table_path=self.path, + key=endpoint[model_monitoring_constants.EventFieldType.UID], + attributes=endpoint, + ) + + def update_model_endpoint( + self, endpoint_id: str, attributes: typing.Dict[str, typing.Any] + ): + """ + Update a model endpoint record with a given attributes. + + :param endpoint_id: The unique id of the model endpoint. + :param attributes: Dictionary of attributes that will be used for update the model endpoint. Note that the keys + of the attributes dictionary should exist in the KV table. + + """ + + self.client.kv.update( + container=self.container, + table_path=self.path, + key=endpoint_id, + attributes=attributes, + ) + + def delete_model_endpoint( + self, + endpoint_id: str, + ): + """ + Deletes the KV record of a given model endpoint id. + + :param endpoint_id: The unique id of the model endpoint. + """ + + self.client.kv.delete( + container=self.container, + table_path=self.path, + key=endpoint_id, + ) + + def get_model_endpoint( + self, + endpoint_id: str, + ) -> typing.Dict[str, typing.Any]: + """ + Get a single model endpoint record. 
+ + :param endpoint_id: The unique id of the model endpoint. + + :return: A model endpoint record as a dictionary. + + :raise MLRunNotFoundError: If the endpoint was not found. + """ + + # Getting the raw data from the KV table + endpoint = self.client.kv.get( + container=self.container, + table_path=self.path, + key=endpoint_id, + raise_for_status=v3io.dataplane.RaiseForStatus.never, + access_key=self.access_key, + ) + endpoint = endpoint.output.item + + if not endpoint: + raise mlrun.errors.MLRunNotFoundError(f"Endpoint {endpoint_id} not found") + + # For backwards compatability: replace null values for `error_count` and `metrics` + mlrun.utils.model_monitoring.validate_old_schema_fields(endpoint=endpoint) + + return endpoint + + def _get_path_and_container(self): + """Getting path and container based on the model monitoring configurations""" + path = mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default.format( + project=self.project, + kind=model_monitoring_constants.ModelMonitoringStoreKinds.ENDPOINTS, + ) + ( + _, + container, + path, + ) = mlrun.utils.model_monitoring.parse_model_endpoint_store_prefix(path) + return path, container + + def list_model_endpoints( + self, + model: str = None, + function: str = None, + labels: typing.List[str] = None, + top_level: bool = None, + uids: typing.List = None, + ) -> typing.List[typing.Dict[str, typing.Any]]: + """ + Returns a list of model endpoint dictionaries, supports filtering by model, function, labels or top level. + By default, when no filters are applied, all available model endpoints for the given project will + be listed. + + :param model: The name of the model to filter by. + :param function: The name of the function to filter by. + :param labels: A list of labels to filter by. Label filters work by either filtering a specific value + of a label (i.e. list("key=value")) or by looking for the existence of a given + key (i.e. "key"). 
+ :param top_level: If True will return only routers and endpoint that are NOT children of any router. + :param uids: List of model endpoint unique ids to include in the result. + + + :return: A list of model endpoint dictionaries. + """ + + # # Initialize an empty model endpoints list + endpoint_list = [] + + # Retrieve the raw data from the KV table and get the endpoint ids + try: + cursor = self.client.kv.new_cursor( + container=self.container, + table_path=self.path, + filter_expression=self._build_kv_cursor_filter_expression( + self.project, + function, + model, + labels, + top_level, + ), + raise_for_status=v3io.dataplane.RaiseForStatus.never, + ) + items = cursor.all() + + except Exception as exc: + logger.warning("Failed retrieving raw data from kv table", exc=exc) + return endpoint_list + + # Create a list of model endpoints unique ids + if uids is None: + uids = [] + for item in items: + if model_monitoring_constants.EventFieldType.UID not in item: + # This is kept for backwards compatibility - in old versions the key column named endpoint_id + uids.append( + item[model_monitoring_constants.EventFieldType.ENDPOINT_ID] + ) + else: + uids.append(item[model_monitoring_constants.EventFieldType.UID]) + + # Add each relevant model endpoint to the model endpoints list + for endpoint_id in uids: + endpoint = self.get_model_endpoint( + endpoint_id=endpoint_id, + ) + endpoint_list.append(endpoint) + + return endpoint_list + + def delete_model_endpoints_resources( + self, endpoints: typing.List[typing.Dict[str, typing.Any]] + ): + """ + Delete all model endpoints resources in both KV and the time series DB. + + :param endpoints: A list of model endpoints flattened dictionaries. 
+ """ + + # Delete model endpoint record from KV table + for endpoint_dict in endpoints: + if model_monitoring_constants.EventFieldType.UID not in endpoint_dict: + # This is kept for backwards compatibility - in old versions the key column named endpoint_id + endpoint_id = endpoint_dict[ + model_monitoring_constants.EventFieldType.ENDPOINT_ID + ] + else: + endpoint_id = endpoint_dict[ + model_monitoring_constants.EventFieldType.UID + ] + self.delete_model_endpoint( + endpoint_id, + ) + + # Delete remain records in the KV + all_records = self.client.kv.new_cursor( + container=self.container, + table_path=self.path, + raise_for_status=v3io.dataplane.RaiseForStatus.never, + ).all() + + all_records = [r["__name"] for r in all_records] + + # Cleanup KV + for record in all_records: + self.client.kv.delete( + container=self.container, + table_path=self.path, + key=record, + raise_for_status=v3io.dataplane.RaiseForStatus.never, + ) + + # Cleanup TSDB + frames = mlrun.utils.v3io_clients.get_frames_client( + token=self.access_key, + address=mlrun.mlconf.v3io_framesd, + container=self.container, + ) + + # Generate the required tsdb paths + tsdb_path, filtered_path = self._generate_tsdb_paths() + + # Delete time series DB resources + try: + frames.delete( + backend=model_monitoring_constants.TimeSeriesTarget.TSDB, + table=filtered_path, + ) + except (v3io_frames.errors.DeleteError, v3io_frames.errors.CreateError) as e: + # Frames might raise an exception if schema file does not exist. 
+ logger.warning("Failed to delete TSDB schema file:", err=e) + pass + + # Final cleanup of tsdb path + tsdb_path.replace("://u", ":///u") + store, _ = mlrun.store_manager.get_or_create_store(tsdb_path) + store.rm(tsdb_path, recursive=True) + + def get_endpoint_real_time_metrics( + self, + endpoint_id: str, + metrics: typing.List[str], + start: str = "now-1h", + end: str = "now", + access_key: str = None, + ) -> typing.Dict[str, typing.List[typing.Tuple[str, float]]]: + """ + Getting metrics from the time series DB. There are pre-defined metrics for model endpoints such as + `predictions_per_second` and `latency_avg_5m` but also custom metrics defined by the user. + + :param endpoint_id: The unique id of the model endpoint. + :param metrics: A list of real-time metrics to return for the model endpoint. + :param start: The start time of the metrics. Can be represented by a string containing an RFC 3339 + time, a Unix timestamp in milliseconds, a relative time (`'now'` or + `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the + earliest time. + :param end: The end time of the metrics. Can be represented by a string containing an RFC 3339 + time, a Unix timestamp in milliseconds, a relative time (`'now'` or + `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the + earliest time. + :param access_key: V3IO access key that will be used for generating Frames client object. If not + provided, the access key will be retrieved from the environment variables. + + :return: A dictionary of metrics in which the key is a metric name and the value is a list of tuples that + includes timestamps and the values. 
+ """ + + # Initialize access key + access_key = access_key or mlrun.mlconf.get_v3io_access_key() + + if not metrics: + raise mlrun.errors.MLRunInvalidArgumentError( + "Metric names must be provided" + ) + + # Initialize metrics mapping dictionary + metrics_mapping = {} + + # Getting the path for the time series DB + events_path = ( + mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default.format( + project=self.project, + kind=model_monitoring_constants.ModelMonitoringStoreKinds.EVENTS, + ) + ) + ( + _, + container, + events_path, + ) = mlrun.utils.model_monitoring.parse_model_endpoint_store_prefix(events_path) + + # Retrieve the raw data from the time series DB based on the provided metrics and time ranges + frames_client = mlrun.utils.v3io_clients.get_frames_client( + token=access_key, + address=mlrun.mlconf.v3io_framesd, + container=container, + ) + + try: + data = frames_client.read( + backend=model_monitoring_constants.TimeSeriesTarget.TSDB, + table=events_path, + columns=["endpoint_id", *metrics], + filter=f"endpoint_id=='{endpoint_id}'", + start=start, + end=end, + ) + + # Fill the metrics mapping dictionary with the metric name and values + data_dict = data.to_dict() + for metric in metrics: + metric_data = data_dict.get(metric) + if metric_data is None: + continue + + values = [ + (str(timestamp), value) for timestamp, value in metric_data.items() + ] + metrics_mapping[metric] = values + + except v3io_frames.errors.ReadError: + logger.warn("Failed to read tsdb", endpoint=endpoint_id) + + return metrics_mapping + + def _generate_tsdb_paths(self) -> typing.Tuple[str, str]: + """Generate a short path to the TSDB resources and a filtered path for the frames object + :return: A tuple of: + [0] = Short path to the TSDB resources + [1] = Filtered path to TSDB events without schema and container + """ + # Full path for the time series DB events + full_path = ( + mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default.format( + project=self.project, + 
kind=model_monitoring_constants.ModelMonitoringStoreKinds.EVENTS, + ) + ) + + # Generate the main directory with the TSDB resources + tsdb_path = mlrun.utils.model_monitoring.parse_model_endpoint_project_prefix( + full_path, self.project + ) + + # Generate filtered path without schema and container as required by the frames object + ( + _, + _, + filtered_path, + ) = mlrun.utils.model_monitoring.parse_model_endpoint_store_prefix(full_path) + return tsdb_path, filtered_path + + @staticmethod + def _build_kv_cursor_filter_expression( + project: str, + function: str = None, + model: str = None, + labels: typing.List[str] = None, + top_level: bool = False, + ) -> str: + """ + Convert the provided filters into a valid filter expression. The expected filter expression includes different + conditions, divided by ' AND '. + + :param project: The name of the project. + :param model: The name of the model to filter by. + :param function: The name of the function to filter by. + :param labels: A list of labels to filter by. Label filters work by either filtering a specific value of + a label (i.e. list("key=value")) or by looking for the existence of a given + key (i.e. "key"). + :param top_level: If True will return only routers and endpoint that are NOT children of any router. + + :return: A valid filter expression as a string. + + :raise MLRunInvalidArgumentError: If project value is None. 
+ """ + + if not project: + raise mlrun.errors.MLRunInvalidArgumentError("project can't be empty") + + # Add project filter + filter_expression = [f"project=='{project}'"] + + # Add function and model filters + if function: + filter_expression.append(f"function=='{function}'") + if model: + filter_expression.append(f"model=='{model}'") + + # Add labels filters + if labels: + for label in labels: + if not label.startswith("_"): + label = f"_{label}" + + if "=" in label: + lbl, value = list(map(lambda x: x.strip(), label.split("="))) + filter_expression.append(f"{lbl}=='{value}'") + else: + filter_expression.append(f"exists({label})") + + # Apply top_level filter (remove endpoints that considered a child of a router) + if top_level: + filter_expression.append( + f"(endpoint_type=='{str(model_monitoring_constants.EndpointType.NODE_EP.value)}' " + f"OR endpoint_type=='{str(model_monitoring_constants.EndpointType.ROUTER.value)}')" + ) + + return " AND ".join(filter_expression) diff --git a/mlrun/model_monitoring/stores/model_endpoint_store.py b/mlrun/model_monitoring/stores/model_endpoint_store.py new file mode 100644 index 000000000000..6aaa51081328 --- /dev/null +++ b/mlrun/model_monitoring/stores/model_endpoint_store.py @@ -0,0 +1,147 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import typing +from abc import ABC, abstractmethod + + +class ModelEndpointStore(ABC): + """ + An abstract class to handle the model endpoint in the DB target. + """ + + def __init__(self, project: str): + """ + Initialize a new model endpoint target. + + :param project: The name of the project. + """ + self.project = project + + @abstractmethod + def write_model_endpoint(self, endpoint: typing.Dict[str, typing.Any]): + """ + Create a new endpoint record in the DB table. + + :param endpoint: model endpoint dictionary that will be written into the DB. + """ + pass + + @abstractmethod + def update_model_endpoint( + self, endpoint_id: str, attributes: typing.Dict[str, typing.Any] + ): + """ + Update a model endpoint record with a given attributes. + + :param endpoint_id: The unique id of the model endpoint. + :param attributes: Dictionary of attributes that will be used for update the model endpoint. Note that the keys + of the attributes dictionary should exist in the DB table. + + """ + pass + + @abstractmethod + def delete_model_endpoint(self, endpoint_id: str): + """ + Deletes the record of a given model endpoint id. + + :param endpoint_id: The unique id of the model endpoint. + """ + pass + + @abstractmethod + def delete_model_endpoints_resources( + self, endpoints: typing.List[typing.Dict[str, typing.Any]] + ): + """ + Delete all model endpoints resources. + + :param endpoints: A list of model endpoints flattened dictionaries. + + """ + pass + + @abstractmethod + def get_model_endpoint( + self, + endpoint_id: str, + ) -> typing.Dict[str, typing.Any]: + """ + Get a single model endpoint record. + + :param endpoint_id: The unique id of the model endpoint. + + :return: A model endpoint record as a dictionary. 
+ """ + pass + + @abstractmethod + def list_model_endpoints( + self, + model: str = None, + function: str = None, + labels: typing.List[str] = None, + top_level: bool = None, + uids: typing.List = None, + ) -> typing.List[typing.Dict[str, typing.Any]]: + """ + Returns a list of model endpoint dictionaries, supports filtering by model, function, labels or top level. + By default, when no filters are applied, all available model endpoints for the given project will + be listed. + + :param model: The name of the model to filter by. + :param function: The name of the function to filter by. + :param labels: A list of labels to filter by. Label filters work by either filtering a specific value + of a label (i.e. list("key=value")) or by looking for the existence of a given + key (i.e. "key"). + :param top_level: If True will return only routers and endpoint that are NOT children of any router. + :param uids: List of model endpoint unique ids to include in the result. + + :return: A list of model endpoint dictionaries. + """ + pass + + @abstractmethod + def get_endpoint_real_time_metrics( + self, + endpoint_id: str, + metrics: typing.List[str], + start: str = "now-1h", + end: str = "now", + access_key: str = None, + ) -> typing.Dict[str, typing.List[typing.Tuple[str, float]]]: + """ + Getting metrics from the time series DB. There are pre-defined metrics for model endpoints such as + `predictions_per_second` and `latency_avg_5m` but also custom metrics defined by the user. + + :param endpoint_id: The unique id of the model endpoint. + :param metrics: A list of real-time metrics to return for the model endpoint. + :param start: The start time of the metrics. Can be represented by a string containing an RFC 3339 + time, a Unix timestamp in milliseconds, a relative time (`'now'` or + `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the + earliest time. + :param end: The end time of the metrics. 
Can be represented by a string containing an RFC 3339 + time, a Unix timestamp in milliseconds, a relative time (`'now'` or + `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the + earliest time. + :param access_key: V3IO access key that will be used for generating Frames client object. If not + provided, the access key will be retrieved from the environment variables. + + :return: A dictionary of metrics in which the key is a metric name and the value is a list of tuples that + includes timestamps and the values. + """ + + pass diff --git a/mlrun/model_monitoring/stores/models/__init__.py b/mlrun/model_monitoring/stores/models/__init__.py new file mode 100644 index 000000000000..4329738e5bbb --- /dev/null +++ b/mlrun/model_monitoring/stores/models/__init__.py @@ -0,0 +1,23 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +def get_ModelEndpointsTable(connection_string: str = None): + """Return ModelEndpointsTable based on the provided connection string""" + if "mysql:" in connection_string: + from .mysql import ModelEndpointsTable + else: + from .sqlite import ModelEndpointsTable + return ModelEndpointsTable diff --git a/mlrun/model_monitoring/stores/models/base.py b/mlrun/model_monitoring/stores/models/base.py new file mode 100644 index 000000000000..ad3a65122cbc --- /dev/null +++ b/mlrun/model_monitoring/stores/models/base.py @@ -0,0 +1,18 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from sqlalchemy.ext.declarative import declarative_base + +Base = declarative_base() diff --git a/mlrun/model_monitoring/stores/models/mysql.py b/mlrun/model_monitoring/stores/models/mysql.py new file mode 100644 index 000000000000..d9edc57583b1 --- /dev/null +++ b/mlrun/model_monitoring/stores/models/mysql.py @@ -0,0 +1,100 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + + +import sqlalchemy.dialects +from sqlalchemy import Boolean, Column, Integer, String, Text + +import mlrun.common.model_monitoring as model_monitoring_constants +from mlrun.utils.db import BaseModel + +from .base import Base + + +class ModelEndpointsTable(Base, BaseModel): + __tablename__ = model_monitoring_constants.EventFieldType.MODEL_ENDPOINTS + + uid = Column( + model_monitoring_constants.EventFieldType.UID, + String(40), + primary_key=True, + ) + state = Column(model_monitoring_constants.EventFieldType.STATE, String(10)) + project = Column(model_monitoring_constants.EventFieldType.PROJECT, String(40)) + function_uri = Column( + model_monitoring_constants.EventFieldType.FUNCTION_URI, + String(255), + ) + model = Column(model_monitoring_constants.EventFieldType.MODEL, String(255)) + model_class = Column( + model_monitoring_constants.EventFieldType.MODEL_CLASS, + String(255), + ) + labels = Column(model_monitoring_constants.EventFieldType.LABELS, Text) + model_uri = Column(model_monitoring_constants.EventFieldType.MODEL_URI, String(255)) + stream_path = Column(model_monitoring_constants.EventFieldType.STREAM_PATH, Text) + algorithm = Column( + model_monitoring_constants.EventFieldType.ALGORITHM, + String(255), + ) + active = Column(model_monitoring_constants.EventFieldType.ACTIVE, Boolean) + monitoring_mode = Column( + model_monitoring_constants.EventFieldType.MONITORING_MODE, + String(10), + ) + feature_stats = Column( + model_monitoring_constants.EventFieldType.FEATURE_STATS, Text + ) + current_stats = Column( + model_monitoring_constants.EventFieldType.CURRENT_STATS, Text + ) + feature_names = Column( + model_monitoring_constants.EventFieldType.FEATURE_NAMES, Text + ) + children = Column(model_monitoring_constants.EventFieldType.CHILDREN, Text) + label_names = Column(model_monitoring_constants.EventFieldType.LABEL_NAMES, Text) + + endpoint_type = 
Column( + model_monitoring_constants.EventFieldType.ENDPOINT_TYPE, + String(10), + ) + children_uids = Column( + model_monitoring_constants.EventFieldType.CHILDREN_UIDS, Text + ) + drift_measures = Column( + model_monitoring_constants.EventFieldType.DRIFT_MEASURES, Text + ) + drift_status = Column( + model_monitoring_constants.EventFieldType.DRIFT_STATUS, + String(40), + ) + monitor_configuration = Column( + model_monitoring_constants.EventFieldType.MONITOR_CONFIGURATION, + Text, + ) + monitoring_feature_set_uri = Column( + model_monitoring_constants.EventFieldType.FEATURE_SET_URI, + String(255), + ) + first_request = Column( + model_monitoring_constants.EventFieldType.FIRST_REQUEST, + sqlalchemy.dialects.mysql.TIMESTAMP(fsp=3), + ) + last_request = Column( + model_monitoring_constants.EventFieldType.LAST_REQUEST, + sqlalchemy.dialects.mysql.TIMESTAMP(fsp=3), + ) + error_count = Column(model_monitoring_constants.EventFieldType.ERROR_COUNT, Integer) + metrics = Column(model_monitoring_constants.EventFieldType.METRICS, Text) diff --git a/mlrun/model_monitoring/stores/models/sqlite.py b/mlrun/model_monitoring/stores/models/sqlite.py new file mode 100644 index 000000000000..e790b50d6925 --- /dev/null +++ b/mlrun/model_monitoring/stores/models/sqlite.py @@ -0,0 +1,98 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + + +from sqlalchemy import TIMESTAMP, Boolean, Column, Integer, String, Text + +import mlrun.common.model_monitoring as model_monitoring_constants +from mlrun.utils.db import BaseModel + +from .base import Base + + +class ModelEndpointsTable(Base, BaseModel): + __tablename__ = model_monitoring_constants.EventFieldType.MODEL_ENDPOINTS + + uid = Column( + model_monitoring_constants.EventFieldType.UID, + String(40), + primary_key=True, + ) + state = Column(model_monitoring_constants.EventFieldType.STATE, String(10)) + project = Column(model_monitoring_constants.EventFieldType.PROJECT, String(40)) + function_uri = Column( + model_monitoring_constants.EventFieldType.FUNCTION_URI, + String(255), + ) + model = Column(model_monitoring_constants.EventFieldType.MODEL, String(255)) + model_class = Column( + model_monitoring_constants.EventFieldType.MODEL_CLASS, + String(255), + ) + labels = Column(model_monitoring_constants.EventFieldType.LABELS, Text) + model_uri = Column(model_monitoring_constants.EventFieldType.MODEL_URI, String(255)) + stream_path = Column(model_monitoring_constants.EventFieldType.STREAM_PATH, Text) + algorithm = Column( + model_monitoring_constants.EventFieldType.ALGORITHM, + String(255), + ) + active = Column(model_monitoring_constants.EventFieldType.ACTIVE, Boolean) + monitoring_mode = Column( + model_monitoring_constants.EventFieldType.MONITORING_MODE, + String(10), + ) + feature_stats = Column( + model_monitoring_constants.EventFieldType.FEATURE_STATS, Text + ) + current_stats = Column( + model_monitoring_constants.EventFieldType.CURRENT_STATS, Text + ) + feature_names = Column( + model_monitoring_constants.EventFieldType.FEATURE_NAMES, Text + ) + children = Column(model_monitoring_constants.EventFieldType.CHILDREN, Text) + label_names = Column(model_monitoring_constants.EventFieldType.LABEL_NAMES, Text) + endpoint_type = Column( + model_monitoring_constants.EventFieldType.ENDPOINT_TYPE, + String(10), + ) + children_uids = Column( + 
model_monitoring_constants.EventFieldType.CHILDREN_UIDS, Text + ) + drift_measures = Column( + model_monitoring_constants.EventFieldType.DRIFT_MEASURES, Text + ) + drift_status = Column( + model_monitoring_constants.EventFieldType.DRIFT_STATUS, + String(40), + ) + monitor_configuration = Column( + model_monitoring_constants.EventFieldType.MONITOR_CONFIGURATION, + Text, + ) + monitoring_feature_set_uri = Column( + model_monitoring_constants.EventFieldType.FEATURE_SET_URI, + String(255), + ) + first_request = Column( + model_monitoring_constants.EventFieldType.FIRST_REQUEST, + TIMESTAMP, + ) + last_request = Column( + model_monitoring_constants.EventFieldType.LAST_REQUEST, + TIMESTAMP, + ) + error_count = Column(model_monitoring_constants.EventFieldType.ERROR_COUNT, Integer) + metrics = Column(model_monitoring_constants.EventFieldType.METRICS, Text) diff --git a/mlrun/model_monitoring/stores/sql_model_endpoint_store.py b/mlrun/model_monitoring/stores/sql_model_endpoint_store.py new file mode 100644 index 000000000000..fc69c4ffffe7 --- /dev/null +++ b/mlrun/model_monitoring/stores/sql_model_endpoint_store.py @@ -0,0 +1,370 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import json +import typing +from datetime import datetime, timezone + +import pandas as pd +import sqlalchemy as db + +import mlrun +import mlrun.common.model_monitoring as model_monitoring_constants +import mlrun.model_monitoring.model_endpoint +import mlrun.utils.model_monitoring +import mlrun.utils.v3io_clients +from mlrun.api.db.sqldb.session import create_session, get_engine +from mlrun.utils import logger + +from .model_endpoint_store import ModelEndpointStore +from .models import get_ModelEndpointsTable +from .models.base import Base + + +class SQLModelEndpointStore(ModelEndpointStore): + + """ + Handles the DB operations when the DB target is from type SQL. For the SQL operations, we use SQLAlchemy, a Python + SQL toolkit that handles the communication with the database. When using SQL for storing the model endpoints + record, the user needs to provide a valid connection string for the database. + """ + + _engine = None + + def __init__( + self, + project: str, + sql_connection_string: str = None, + ): + """ + Initialize SQL store target object. + + :param project: The name of the project. + :param sql_connection_string: Valid connection string or a path to SQL database with model endpoints table. + """ + + super().__init__(project=project) + + self.sql_connection_string = ( + sql_connection_string + or mlrun.utils.model_monitoring.get_connection_string(project=self.project) + ) + + self.table_name = model_monitoring_constants.EventFieldType.MODEL_ENDPOINTS + + self._engine = get_engine(dsn=self.sql_connection_string) + self.ModelEndpointsTable = get_ModelEndpointsTable( + connection_string=self.sql_connection_string + ) + # Create table if not exist. 
The `metadata` contains the `ModelEndpointsTable` + if not self._engine.has_table(self.table_name): + Base.metadata.create_all(bind=self._engine) + self.model_endpoints_table = self.ModelEndpointsTable.__table__ + + def write_model_endpoint(self, endpoint: typing.Dict[str, typing.Any]): + """ + Create a new endpoint record in the SQL table. This method also creates the model endpoints table within the + SQL database if not exist. + + :param endpoint: model endpoint dictionary that will be written into the DB. + """ + + with self._engine.connect() as connection: + # Adjust timestamps fields + endpoint[ + model_monitoring_constants.EventFieldType.FIRST_REQUEST + ] = datetime.now(timezone.utc) + endpoint[ + model_monitoring_constants.EventFieldType.LAST_REQUEST + ] = datetime.now(timezone.utc) + + # Convert the result into a pandas Dataframe and write it into the database + endpoint_df = pd.DataFrame([endpoint]) + + endpoint_df.to_sql( + self.table_name, con=connection, index=False, if_exists="append" + ) + + def update_model_endpoint( + self, endpoint_id: str, attributes: typing.Dict[str, typing.Any] + ): + """ + Update a model endpoint record with a given attributes. + + :param endpoint_id: The unique id of the model endpoint. + :param attributes: Dictionary of attributes that will be used for update the model endpoint. Note that the keys + of the attributes dictionary should exist in the SQL table. + + """ + + # Update the model endpoint record using sqlalchemy ORM + with create_session(dsn=self.sql_connection_string) as session: + # Remove endpoint id (foreign key) from the update query + attributes.pop(model_monitoring_constants.EventFieldType.ENDPOINT_ID, None) + + # Generate and commit the update session query + session.query(self.ModelEndpointsTable).filter( + self.ModelEndpointsTable.uid == endpoint_id + ).update(attributes) + session.commit() + + def delete_model_endpoint(self, endpoint_id: str): + """ + Deletes the SQL record of a given model endpoint id. 
+ + :param endpoint_id: The unique id of the model endpoint. + """ + + # Delete the model endpoint record using sqlalchemy ORM + with create_session(dsn=self.sql_connection_string) as session: + # Generate and commit the delete query + session.query(self.ModelEndpointsTable).filter_by(uid=endpoint_id).delete() + session.commit() + + def get_model_endpoint( + self, + endpoint_id: str, + ) -> typing.Dict[str, typing.Any]: + """ + Get a single model endpoint record. + + :param endpoint_id: The unique id of the model endpoint. + + :return: A model endpoint record as a dictionary. + + :raise MLRunNotFoundError: If the model endpoints table was not found or the model endpoint id was not found. + """ + + # Get the model endpoint record using sqlalchemy ORM + with create_session(dsn=self.sql_connection_string) as session: + # Generate the get query + endpoint_record = ( + session.query(self.ModelEndpointsTable) + .filter_by(uid=endpoint_id) + .one_or_none() + ) + + if not endpoint_record: + raise mlrun.errors.MLRunNotFoundError(f"Endpoint {endpoint_id} not found") + + # Convert the database values and the table columns into a python dictionary + return endpoint_record.to_dict() + + def list_model_endpoints( + self, + model: str = None, + function: str = None, + labels: typing.List[str] = None, + top_level: bool = None, + uids: typing.List = None, + ) -> typing.List[typing.Dict[str, typing.Any]]: + """ + Returns a list of model endpoint dictionaries, supports filtering by model, function, labels or top level. + By default, when no filters are applied, all available model endpoints for the given project will + be listed. + + :param model: The name of the model to filter by. + :param function: The name of the function to filter by. + :param labels: A list of labels to filter by. Label filters work by either filtering a specific value + of a label (i.e. list("key=value")) or by looking for the existence of a given + key (i.e. "key"). 
+ :param top_level: If True will return only routers and endpoint that are NOT children of any router. + :param uids: List of model endpoint unique ids to include in the result. + + :return: A list of model endpoint dictionaries. + """ + + # Generate an empty model endpoints that will be filled afterwards with model endpoint dictionaries + endpoint_list = [] + + # Get the model endpoints records using sqlalchemy ORM + with create_session(dsn=self.sql_connection_string) as session: + # Generate the list query + query = session.query(self.ModelEndpointsTable).filter_by( + project=self.project + ) + + # Apply filters + if model: + query = self._filter_values( + query=query, + model_endpoints_table=self.model_endpoints_table, + key_filter=model_monitoring_constants.EventFieldType.MODEL, + filtered_values=[model], + ) + if function: + query = self._filter_values( + query=query, + model_endpoints_table=self.model_endpoints_table, + key_filter=model_monitoring_constants.EventFieldType.FUNCTION, + filtered_values=[function], + ) + if uids: + query = self._filter_values( + query=query, + model_endpoints_table=self.model_endpoints_table, + key_filter=model_monitoring_constants.EventFieldType.UID, + filtered_values=uids, + combined=False, + ) + if top_level: + node_ep = str(mlrun.common.model_monitoring.EndpointType.NODE_EP.value) + router_ep = str(mlrun.common.model_monitoring.EndpointType.ROUTER.value) + endpoint_types = [node_ep, router_ep] + query = self._filter_values( + query=query, + model_endpoints_table=self.model_endpoints_table, + key_filter=model_monitoring_constants.EventFieldType.ENDPOINT_TYPE, + filtered_values=endpoint_types, + combined=False, + ) + # Convert the results from the DB into a ModelEndpoint object and append it to the model endpoints list + for endpoint_record in query.all(): + endpoint_dict = endpoint_record.to_dict() + + # Filter labels + if labels and not self._validate_labels( + endpoint_dict=endpoint_dict, labels=labels + ): + continue + + 
endpoint_list.append(endpoint_dict) + + return endpoint_list + + @staticmethod + def _filter_values( + query: db.orm.query.Query, + model_endpoints_table: db.Table, + key_filter: str, + filtered_values: typing.List, + combined=True, + ) -> db.orm.query.Query: + """Filtering the SQL query object according to the provided filters. + + :param query: SQLAlchemy ORM query object. Includes the SELECT statements generated by the ORM + for getting the model endpoint data from the SQL table. + :param model_endpoints_table: SQLAlchemy table object that represents the model endpoints table. + :param key_filter: Key column to filter by. + :param filtered_values: List of values to filter the query the result. + :param combined: If true, then apply AND operator on the filtered values list. Otherwise, apply OR + operator. + + return: SQLAlchemy ORM query object that represents the updated query with the provided + filters. + """ + + if combined and len(filtered_values) > 1: + raise mlrun.errors.MLRunInvalidArgumentError( + "Can't apply combined policy with multiple values" + ) + + if not combined: + return query.filter( + model_endpoints_table.c[key_filter].in_(filtered_values) + ) + + # Generating a tuple with the relevant filters + filter_query = [] + for _filter in filtered_values: + filter_query.append(model_endpoints_table.c[key_filter] == _filter) + + # Apply AND operator on the SQL query object with the filters tuple + return query.filter(db.and_(*filter_query)) + + @staticmethod + def _validate_labels( + endpoint_dict: dict, + labels: typing.List, + ) -> bool: + """Validate that the model endpoint dictionary has the provided labels. There are 2 possible cases: + 1 - Labels were provided as a list of key-values pairs (e.g. ['label_1=value_1', 'label_2=value_2']): Validate + that each pair exist in the endpoint dictionary. + 2 - Labels were provided as a list of key labels (e.g. ['label_1', 'label_2']): Validate that each key exist in + the endpoint labels dictionary. 
+ + :param endpoint_dict: Dictionary of the model endpoint records. + :param labels: List of dictionary of required labels. + + :return: True if the labels exist in the endpoint labels dictionary, otherwise False. + """ + + # Convert endpoint labels into dictionary + endpoint_labels = json.loads( + endpoint_dict.get(model_monitoring_constants.EventFieldType.LABELS) + ) + + for label in labels: + # Case 1 - label is a key=value pair + if "=" in label: + lbl, value = list(map(lambda x: x.strip(), label.split("="))) + if lbl not in endpoint_labels or str(endpoint_labels[lbl]) != value: + return False + # Case 2 - label is just a key + else: + if label not in endpoint_labels: + return False + + return True + + def delete_model_endpoints_resources( + self, endpoints: typing.List[typing.Dict[str, typing.Any]] + ): + """ + Delete all model endpoints resources in both SQL and the time series DB. + + :param endpoints: A list of model endpoints flattened dictionaries. + """ + + for endpoint_dict in endpoints: + # Delete model endpoint record from SQL table + self.delete_model_endpoint( + endpoint_dict[model_monitoring_constants.EventFieldType.UID], + ) + + def get_endpoint_real_time_metrics( + self, + endpoint_id: str, + metrics: typing.List[str], + start: str = "now-1h", + end: str = "now", + access_key: str = None, + ) -> typing.Dict[str, typing.List[typing.Tuple[str, float]]]: + """ + Getting metrics from the time series DB. There are pre-defined metrics for model endpoints such as + `predictions_per_second` and `latency_avg_5m` but also custom metrics defined by the user. + + :param endpoint_id: The unique id of the model endpoint. + :param metrics: A list of real-time metrics to return for the model endpoint. + :param start: The start time of the metrics. 
Can be represented by a string containing an RFC 3339 + time, a Unix timestamp in milliseconds, a relative time (`'now'` or + `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the + earliest time. + :param end: The end time of the metrics. Can be represented by a string containing an RFC 3339 + time, a Unix timestamp in milliseconds, a relative time (`'now'` or + `'now-[0-9]+[mhd]'`, where `m` = minutes, `h` = hours, and `'d'` = days), or 0 for the + earliest time. + :param access_key: V3IO access key that will be used for generating Frames client object. If not + provided, the access key will be retrieved from the environment variables. + + :return: A dictionary of metrics in which the key is a metric name and the value is a list of tuples that + includes timestamps and the values. + """ + # # TODO : Implement this method once Perometheus is supported + logger.warning( + "Real time metrics service using Prometheus will be implemented in 1.4.0" + ) + + return {} diff --git a/mlrun/model_monitoring/stream_processing_fs.py b/mlrun/model_monitoring/stream_processing_fs.py index ff2c2a0f854e..fa5ff0d4a253 100644 --- a/mlrun/model_monitoring/stream_processing_fs.py +++ b/mlrun/model_monitoring/stream_processing_fs.py @@ -19,23 +19,25 @@ import typing import pandas as pd - -# Constants import storey -import v3io -import v3io.dataplane +import mlrun +import mlrun.common.model_monitoring import mlrun.config import mlrun.datastore.targets import mlrun.feature_store.steps import mlrun.utils import mlrun.utils.model_monitoring import mlrun.utils.v3io_clients -from mlrun.model_monitoring.constants import ( +from mlrun.common.model_monitoring import ( EventFieldType, EventKeyMetrics, EventLiveStats, + FileTargetKind, + ModelEndpointTarget, + ProjectSecretKeys, ) +from mlrun.model_monitoring.stores import get_model_endpoint_store from mlrun.utils import logger @@ -45,81 +47,90 @@ def __init__( self, project: str, parquet_batching_max_events: int, 
+ parquet_target: str, sample_window: int = 10, - tsdb_batching_max_events: int = 10, - tsdb_batching_timeout_secs: int = 60 * 5, # Default 5 minutes parquet_batching_timeout_secs: int = 30 * 60, # Default 30 minutes aggregate_count_windows: typing.Optional[typing.List[str]] = None, aggregate_count_period: str = "30s", aggregate_avg_windows: typing.Optional[typing.List[str]] = None, aggregate_avg_period: str = "30s", - v3io_access_key: typing.Optional[str] = None, - v3io_framesd: typing.Optional[str] = None, - v3io_api: typing.Optional[str] = None, model_monitoring_access_key: str = None, ): + # General configurations, mainly used for the storey steps in the future serving graph self.project = project self.sample_window = sample_window - self.tsdb_batching_max_events = tsdb_batching_max_events - self.tsdb_batching_timeout_secs = tsdb_batching_timeout_secs - self.parquet_batching_max_events = parquet_batching_max_events - self.parquet_batching_timeout_secs = parquet_batching_timeout_secs self.aggregate_count_windows = aggregate_count_windows or ["5m", "1h"] self.aggregate_count_period = aggregate_count_period self.aggregate_avg_windows = aggregate_avg_windows or ["5m", "1h"] self.aggregate_avg_period = aggregate_avg_period + # Parquet path and configurations + self.parquet_path = parquet_target + self.parquet_batching_max_events = parquet_batching_max_events + self.parquet_batching_timeout_secs = parquet_batching_timeout_secs + + self.model_endpoint_store_target = ( + mlrun.mlconf.model_endpoint_monitoring.store_type + ) + + logger.info( + "Initializing model monitoring event stream processor", + parquet_path=self.parquet_path, + parquet_batching_max_events=self.parquet_batching_max_events, + ) + + self.storage_options = None + if not mlrun.mlconf.is_ce_mode(): + self._initialize_v3io_configurations( + model_monitoring_access_key=model_monitoring_access_key + ) + + def _initialize_v3io_configurations( + self, + tsdb_batching_max_events: int = 10, + 
tsdb_batching_timeout_secs: int = 60 * 5, # Default 5 minutes + v3io_access_key: typing.Optional[str] = None, + v3io_framesd: typing.Optional[str] = None, + v3io_api: typing.Optional[str] = None, + model_monitoring_access_key: str = None, + ): + # Get the V3IO configurations self.v3io_framesd = v3io_framesd or mlrun.mlconf.v3io_framesd self.v3io_api = v3io_api or mlrun.mlconf.v3io_api self.v3io_access_key = v3io_access_key or os.environ.get("V3IO_ACCESS_KEY") self.model_monitoring_access_key = ( model_monitoring_access_key - or os.environ.get("MODEL_MONITORING_ACCESS_KEY") + or os.environ.get(ProjectSecretKeys.ACCESS_KEY) or self.v3io_access_key ) self.storage_options = dict( v3io_access_key=self.model_monitoring_access_key, v3io_api=self.v3io_api ) - template = mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default - - kv_path = template.format(project=project, kind="endpoints") + # KV path + kv_path = mlrun.mlconf.get_model_monitoring_file_target_path( + project=self.project, kind=FileTargetKind.ENDPOINTS + ) ( _, self.kv_container, self.kv_path, ) = mlrun.utils.model_monitoring.parse_model_endpoint_store_prefix(kv_path) - tsdb_path = template.format(project=project, kind="events") + # TSDB path and configurations + tsdb_path = mlrun.mlconf.get_model_monitoring_file_target_path( + project=self.project, kind=FileTargetKind.EVENTS + ) ( _, self.tsdb_container, self.tsdb_path, ) = mlrun.utils.model_monitoring.parse_model_endpoint_store_prefix(tsdb_path) - self.tsdb_path = f"{self.tsdb_container}/{self.tsdb_path}" - self.parquet_path = ( - mlrun.mlconf.model_endpoint_monitoring.store_prefixes.user_space.format( - project=project, kind="parquet" - ) - ) - - logger.info( - "Initializing model monitoring event stream processor", - parquet_batching_max_events=self.parquet_batching_max_events, - v3io_access_key=self.v3io_access_key, - model_monitoring_access_key=self.model_monitoring_access_key, - 
default_store_prefix=mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default, - user_space_store_prefix=mlrun.mlconf.model_endpoint_monitoring.store_prefixes.user_space, - v3io_api=self.v3io_api, - v3io_framesd=self.v3io_framesd, - kv_container=self.kv_container, - kv_path=self.kv_path, - tsdb_container=self.tsdb_container, - tsdb_path=self.tsdb_path, - parquet_path=self.parquet_path, - ) + self.tsdb_path = f"{self.tsdb_container}/{self.tsdb_path}" + self.tsdb_batching_max_events = tsdb_batching_max_events + self.tsdb_batching_timeout_secs = tsdb_batching_timeout_secs def apply_monitoring_serving_graph(self, fn): """ @@ -127,20 +138,23 @@ def apply_monitoring_serving_graph(self, fn): of different operations that are executed on the events from the model server. Each event has metadata (function_uri, timestamp, class, etc.) but also inputs and predictions from the model server. Throughout the serving graph, the results are written to 3 different databases: - 1. KV (steps 7-9): Stores metadata and stats about the average latency and the amount of predictions over time - per endpoint. for example the amount of predictions of endpoint x in the last 5 min. This data is used by - the monitoring dashboards in grafana. Please note that the KV table, which can be found under - v3io:///users/pipelines/project-name/model-endpoints/endpoints/ also contains data on the model endpoint - from other processes, such as current_stats that is being calculated by the monitoring batch job - process. + 1. KV/SQL (steps 7-9): Stores metadata and stats about the average latency and the amount of predictions over + time per endpoint. for example the amount of predictions of endpoint x in the last 5 min. This data is used + by the monitoring dashboards in grafana. The model endpoints table also contains data on the model endpoint + from other processes, such as current_stats that is being calculated by the monitoring batch job + process. 
If the target is from type KV, then the model endpoints table can be found under + v3io:///users/pipelines/project-name/model-endpoints/endpoints/. If the target is SQL, then the table + is stored within the database that was defined in the provided connection string and can be found + under mlrun.mlconf.model_endpoint_monitoring.endpoint_store_connection. 2. TSDB (steps 12-18): Stores live data of different key metric dictionaries in tsdb target. Results can be found under v3io:///users/pipelines/project-name/model-endpoints/events/. At the moment, this part supports 3 different key metric dictionaries: base_metrics (average latency and predictions over time), endpoint_features (Prediction and feature names and values), and custom_metrics (user-defined metrics). This data is also being used by the monitoring dashboards in grafana. 3. Parquet (steps 19-20): This Parquet file includes the required data for the model monitoring batch job - that run every hour by default. The parquet target can be found under - v3io:///projects/{project}/model-endpoints/. + that run every hour by default. If defined, the parquet target path can be found under + mlrun.mlconf.model_endpoint_monitoring.offline. Otherwise, the default parquet path is under + mlrun.mlconf.model_endpoint_monitoring.user_space. :param fn: A serving function. 
""" @@ -151,9 +165,6 @@ def apply_monitoring_serving_graph(self, fn): def apply_process_endpoint_event(): graph.add_step( "ProcessEndpointEvent", - kv_container=self.kv_container, - kv_path=self.kv_path, - v3io_access_key=self.v3io_access_key, full_event=True, project=self.project, ) @@ -182,10 +193,8 @@ def apply_map_feature_names(): graph.add_step( "MapFeatureNames", name="MapFeatureNames", - kv_container=self.kv_container, - kv_path=self.kv_path, - access_key=self.v3io_access_key, infer_columns_from_data=True, + project=self.project, after="flatten_events", ) @@ -209,7 +218,6 @@ def apply_storey_aggregations(): after="MapFeatureNames", step_name="Aggregates", table=".", - v3io_access_key=self.v3io_access_key, ) # Step 5.2 - Calculate average latency time for each window (5 min and 1 hour by default) graph.add_step( @@ -226,7 +234,6 @@ def apply_storey_aggregations(): name=EventFieldType.LATENCY, after=EventFieldType.PREDICTIONS, table=".", - v3io_access_key=self.v3io_access_key, ) apply_storey_aggregations() @@ -239,117 +246,121 @@ def apply_storey_sample_window(): after=EventFieldType.LATENCY, window_size=self.sample_window, key=EventFieldType.ENDPOINT_ID, - v3io_access_key=self.v3io_access_key, ) apply_storey_sample_window() - # Steps 7-9 - KV branch - # Step 7 - Filter relevant keys from the event before writing the data into KV - def apply_process_before_kv(): - graph.add_step("ProcessBeforeKV", name="ProcessBeforeKV", after="sample") + # Steps 7-9 - KV/SQL branch + # Step 7 - Filter relevant keys from the event before writing the data into the database table + def apply_process_before_endpoint_update(): + graph.add_step( + "ProcessBeforeEndpointUpdate", + name="ProcessBeforeEndpointUpdate", + after="sample", + ) - apply_process_before_kv() + apply_process_before_endpoint_update() - # Step 8 - Write the filtered event to KV table. At this point, the serving graph updates the stats + # Step 8 - Write the filtered event to KV/SQL table. 
At this point, the serving graph updates the stats # about average latency and the amount of predictions over time - def apply_write_to_kv(): + def apply_update_endpoint(): graph.add_step( - "WriteToKV", - name="WriteToKV", - after="ProcessBeforeKV", - container=self.kv_container, - table=self.kv_path, - v3io_access_key=self.v3io_access_key, + "UpdateEndpoint", + name="UpdateEndpoint", + after="ProcessBeforeEndpointUpdate", + project=self.project, + model_endpoint_store_target=self.model_endpoint_store_target, ) - apply_write_to_kv() + apply_update_endpoint() - # Step 9 - Apply infer_schema on the KB table for generating schema file + # Step 9 (only for KV target) - Apply infer_schema on the model endpoints table for generating schema file # which will be used by Grafana monitoring dashboards def apply_infer_schema(): graph.add_step( "InferSchema", name="InferSchema", - after="WriteToKV", - v3io_access_key=self.v3io_access_key, + after="UpdateEndpoint", v3io_framesd=self.v3io_framesd, container=self.kv_container, table=self.kv_path, ) - apply_infer_schema() + if self.model_endpoint_store_target == ModelEndpointTarget.V3IO_NOSQL: + apply_infer_schema() - # Steps 11-18 - TSDB branch - # Step 11 - Before writing data to TSDB, create dictionary of 2-3 dictionaries that contains - # stats and details about the events - def apply_process_before_tsdb(): - graph.add_step( - "ProcessBeforeTSDB", name="ProcessBeforeTSDB", after="sample" - ) + # Steps 11-18 - TSDB branch (not supported in CE environment at the moment) - apply_process_before_tsdb() + if not mlrun.mlconf.is_ce_mode(): + # Step 11 - Before writing data to TSDB, create dictionary of 2-3 dictionaries that contains + # stats and details about the events + def apply_process_before_tsdb(): + graph.add_step( + "ProcessBeforeTSDB", name="ProcessBeforeTSDB", after="sample" + ) - # Steps 12-18: - Unpacked keys from each dictionary and write to TSDB target - def apply_filter_and_unpacked_keys(name, keys): - 
graph.add_step( - "FilterAndUnpackKeys", - name=name, - after="ProcessBeforeTSDB", - keys=[keys], - ) + apply_process_before_tsdb() - def apply_tsdb_target(name, after): - graph.add_step( - "storey.TSDBTarget", - name=name, - after=after, - path=self.tsdb_path, - rate="10/m", - time_col=EventFieldType.TIMESTAMP, - container=self.tsdb_container, - access_key=self.v3io_access_key, - v3io_frames=self.v3io_framesd, - infer_columns_from_data=True, - index_cols=[ - EventFieldType.ENDPOINT_ID, - EventFieldType.RECORD_TYPE, - ], - max_events=self.tsdb_batching_max_events, - flush_after_seconds=self.tsdb_batching_timeout_secs, - key=EventFieldType.ENDPOINT_ID, - ) + # Steps 12-18: - Unpacked keys from each dictionary and write to TSDB target + def apply_filter_and_unpacked_keys(name, keys): + graph.add_step( + "FilterAndUnpackKeys", + name=name, + after="ProcessBeforeTSDB", + keys=[keys], + ) - # Steps 12-13 - unpacked base_metrics dictionary - apply_filter_and_unpacked_keys( - name="FilterAndUnpackKeys1", - keys=EventKeyMetrics.BASE_METRICS, - ) - apply_tsdb_target(name="tsdb1", after="FilterAndUnpackKeys1") + def apply_tsdb_target(name, after): + graph.add_step( + "storey.TSDBTarget", + name=name, + after=after, + path=self.tsdb_path, + rate="10/m", + time_col=EventFieldType.TIMESTAMP, + container=self.tsdb_container, + access_key=self.v3io_access_key, + v3io_frames=self.v3io_framesd, + infer_columns_from_data=True, + index_cols=[ + EventFieldType.ENDPOINT_ID, + EventFieldType.RECORD_TYPE, + ], + max_events=self.tsdb_batching_max_events, + flush_after_seconds=self.tsdb_batching_timeout_secs, + key=EventFieldType.ENDPOINT_ID, + ) - # Steps 14-15 - unpacked endpoint_features dictionary - apply_filter_and_unpacked_keys( - name="FilterAndUnpackKeys2", - keys=EventKeyMetrics.ENDPOINT_FEATURES, - ) - apply_tsdb_target(name="tsdb2", after="FilterAndUnpackKeys2") + # Steps 12-13 - unpacked base_metrics dictionary + apply_filter_and_unpacked_keys( + name="FilterAndUnpackKeys1", + 
keys=EventKeyMetrics.BASE_METRICS, + ) + apply_tsdb_target(name="tsdb1", after="FilterAndUnpackKeys1") - # Steps 16-18 - unpacked custom_metrics dictionary. In addition, use storey.Filter remove none values - apply_filter_and_unpacked_keys( - name="FilterAndUnpackKeys3", - keys=EventKeyMetrics.CUSTOM_METRICS, - ) + # Steps 14-15 - unpacked endpoint_features dictionary + apply_filter_and_unpacked_keys( + name="FilterAndUnpackKeys2", + keys=EventKeyMetrics.ENDPOINT_FEATURES, + ) + apply_tsdb_target(name="tsdb2", after="FilterAndUnpackKeys2") - def apply_storey_filter(): - graph.add_step( - "storey.Filter", - "FilterNotNone", - after="FilterAndUnpackKeys3", - _fn="(event is not None)", + # Steps 16-18 - unpacked custom_metrics dictionary. In addition, use storey.Filter remove none values + apply_filter_and_unpacked_keys( + name="FilterAndUnpackKeys3", + keys=EventKeyMetrics.CUSTOM_METRICS, ) - apply_storey_filter() - apply_tsdb_target(name="tsdb3", after="FilterNotNone") + def apply_storey_filter(): + graph.add_step( + "storey.Filter", + "FilterNotNone", + after="FilterAndUnpackKeys3", + _fn="(event is not None)", + ) + + apply_storey_filter() + apply_tsdb_target(name="tsdb3", after="FilterNotNone") # Steps 19-20 - Parquet branch # Step 19 - Filter and validate different keys before writing the data to Parquet target @@ -384,19 +395,18 @@ def apply_parquet_target(): apply_parquet_target() -class ProcessBeforeKV(mlrun.feature_store.steps.MapClass): +class ProcessBeforeEndpointUpdate(mlrun.feature_store.steps.MapClass): def __init__(self, **kwargs): """ - Filter relevant keys from the event before writing the data to KV table (in WriteToKV step). Note that in KV - we only keep metadata (function_uri, model_class, etc.) and stats about the average latency and the number - of predictions (per 5min and 1hour). + Filter relevant keys from the event before writing the data to database table (in EndpointUpdate step). 
+ Note that in the endpoint table we only keep metadata (function_uri, model_class, etc.) and stats about the + average latency and the number of predictions (per 5min and 1hour). - :returns: A filtered event as a dictionary which will be written to KV table in the next step. + :returns: A filtered event as a dictionary which will be written to the endpoint table in the next step. """ super().__init__(**kwargs) def do(self, event): - # Compute prediction per second event[EventLiveStats.PREDICTIONS_PER_SECOND] = ( float(event[EventLiveStats.PREDICTIONS_COUNT_5M]) / 300 @@ -408,26 +418,31 @@ def do(self, event): EventFieldType.FUNCTION_URI, EventFieldType.MODEL, EventFieldType.MODEL_CLASS, - EventFieldType.TIMESTAMP, EventFieldType.ENDPOINT_ID, EventFieldType.LABELS, - EventFieldType.UNPACKED_LABELS, + EventFieldType.FIRST_REQUEST, + EventFieldType.LAST_REQUEST, + EventFieldType.ERROR_COUNT, + ] + } + + # Add generic metrics statistics + generic_metrics = { + k: event[k] + for k in [ EventLiveStats.LATENCY_AVG_5M, EventLiveStats.LATENCY_AVG_1H, EventLiveStats.PREDICTIONS_PER_SECOND, EventLiveStats.PREDICTIONS_COUNT_5M, EventLiveStats.PREDICTIONS_COUNT_1H, - EventFieldType.FIRST_REQUEST, - EventFieldType.LAST_REQUEST, - EventFieldType.ERROR_COUNT, ] } - # Unpack labels dictionary - e = { - **e.pop(EventFieldType.UNPACKED_LABELS, {}), - **e, - } - # Write labels to kv as json string to be presentable later + + e[EventFieldType.METRICS] = json.dumps( + {EventKeyMetrics.GENERIC: generic_metrics} + ) + + # Write labels as json string as required by the DB format e[EventFieldType.LABELS] = json.dumps(e[EventFieldType.LABELS]) return e @@ -449,7 +464,6 @@ def __init__(self, **kwargs): super().__init__(**kwargs) def do(self, event): - # Compute prediction per second event[EventLiveStats.PREDICTIONS_PER_SECOND] = ( float(event[EventLiveStats.PREDICTIONS_COUNT_5M]) / 300 @@ -519,11 +533,9 @@ def __init__(self, **kwargs): super().__init__(**kwargs) def do(self, event): - 
logger.info("ProcessBeforeParquet1", event=event) # Remove the following keys from the event for key in [ - EventFieldType.UNPACKED_LABELS, EventFieldType.FEATURES, EventFieldType.NAMED_FEATURES, ]: @@ -549,32 +561,23 @@ def do(self, event): class ProcessEndpointEvent(mlrun.feature_store.steps.MapClass): def __init__( self, - kv_container: str, - kv_path: str, - v3io_access_key: str, + project: str, **kwargs, ): """ Process event or batch of events as part of the first step of the monitoring serving graph. It includes - Adding important details to the event such as endpoint_id, handling errors coming from the stream, Validation + Adding important details to the event such as endpoint_id, handling errors coming from the stream, validation of event data such as inputs and outputs, and splitting model event into sub-events. - :param kv_container: Name of the container that will be used to retrieve the endpoint id. For model - endpoints it is usually 'users'. - :param kv_path: KV table path that will be used to retrieve the endpoint id. For model endpoints - it is usually pipelines/project-name/model-endpoints/endpoints/ - :param v3io_access_key: Access key with permission to read from a KV table. - :param project: Project name. - + :param project: Project name. :returns: A Storey event object which is the basic unit of data in Storey. Note that the next steps of the monitoring serving graph are based on Storey operations. 
""" super().__init__(**kwargs) - self.kv_container: str = kv_container - self.kv_path: str = kv_path - self.v3io_access_key: str = v3io_access_key + + self.project: str = project # First and last requests timestamps (value) of each endpoint (key) self.first_request: typing.Dict[str, str] = dict() @@ -602,7 +605,7 @@ def do(self, full_event): version = event.get(EventFieldType.VERSION) versioned_model = f"{model}:{version}" if version else f"{model}:latest" - endpoint_id = mlrun.utils.model_monitoring.create_model_endpoint_id( + endpoint_id = mlrun.common.model_monitoring.create_model_endpoint_uid( function_uri=function_uri, versioned_model=versioned_model, ) @@ -615,10 +618,12 @@ def do(self, full_event): # In case this process fails, resume state from existing record self.resume_state(endpoint_id) - # Handle errors coming from stream - found_errors = self.handle_errors(endpoint_id, event) - if found_errors: - return None + # If error key has been found in the current event, + # increase the error counter by 1 and raise the error description + error = event.get("error") + if error: + self.error_count[endpoint_id] += 1 + raise mlrun.errors.MLRunInvalidArgumentError(str(error)) # Validate event fields model_class = event.get("model_class") or event.get("class") @@ -679,11 +684,6 @@ def do(self, full_event): ): return None - # Get labels from event (if exist) - unpacked_labels = { - f"_{k}": v for k, v in event.get(EventFieldType.LABELS, {}).items() - } - # Adjust timestamp format timestamp = datetime.datetime.strptime(timestamp[:-6], "%Y-%m-%d %H:%M:%S.%f") @@ -722,7 +722,6 @@ def do(self, full_event): EventFieldType.ENTITIES: event.get("request", {}).get( EventFieldType.ENTITIES, {} ), - EventFieldType.UNPACKED_LABELS: unpacked_labels, } ) @@ -745,14 +744,13 @@ def _validate_last_request_timestamp(self, endpoint_id: str, timestamp: str): endpoint_id in self.last_request and self.last_request[endpoint_id] > timestamp ): - logger.error( f"current event request time 
{timestamp} is earlier than the last request time " f"{self.last_request[endpoint_id]} - write to TSDB will be rejected" ) + @staticmethod def is_list_of_numerics( - self, field: typing.List[typing.Union[int, float, dict, list]], dict_path: typing.List[str], ): @@ -769,10 +767,8 @@ def resume_state(self, endpoint_id): if endpoint_id not in self.endpoints: logger.info("Trying to resume state", endpoint_id=endpoint_id) endpoint_record = get_endpoint_record( - kv_container=self.kv_container, - kv_path=self.kv_path, + project=self.project, endpoint_id=endpoint_id, - access_key=self.v3io_access_key, ) # If model endpoint found, get first_request, last_request and error_count values @@ -784,13 +780,12 @@ def resume_state(self, endpoint_id): last_request = endpoint_record.get(EventFieldType.LAST_REQUEST) if last_request: - self.last_request[endpoint_id] = last_request error_count = endpoint_record.get(EventFieldType.ERROR_COUNT) if error_count: - self.error_count[endpoint_id] = error_count + self.error_count[endpoint_id] = int(error_count) # add endpoint to endpoints set self.endpoints.add(endpoint_id) @@ -807,13 +802,6 @@ def is_valid( self.error_count[endpoint_id] += 1 return False - def handle_errors(self, endpoint_id, event) -> bool: - if "error" in event: - self.error_count[endpoint_id] += 1 - return True - - return False - def is_not_none(field: typing.Any, dict_path: typing.List[str]): if field is not None: @@ -857,9 +845,7 @@ def do(self, event): class MapFeatureNames(mlrun.feature_store.steps.MapClass): def __init__( self, - kv_container: str, - kv_path: str, - access_key: str, + project: str, infer_columns_from_data: bool = False, **kwargs, ): @@ -867,11 +853,7 @@ def __init__( Validating feature names and label columns and map each feature to its value. In the end of this step, the event should have key-value pairs of (feature name: feature value). - :param kv_container: Name of the container that will be used to retrieve the endpoint id. 
For model - endpoints it is usually 'users'. - :param kv_path: KV table path that will be used to retrieve the endpoint id. For model endpoints - it is usually pipelines/project-name/model-endpoints/endpoints/ - :param v3io_access_key: Access key with permission to read from a KV table. + :param project: Project name. :param infer_columns_from_data: If true and features or labels names were not found, then try to retrieve them from data that was stored in the previous events of the current process. This data can be found under self.feature_names and @@ -882,10 +864,9 @@ def __init__( feature names and values (as well as the prediction results). """ super().__init__(**kwargs) - self.kv_container = kv_container - self.kv_path = kv_path - self.access_key = access_key + self._infer_columns_from_data = infer_columns_from_data + self.project = project # Dictionaries that will be used in case features names # and labels columns were not found in the current event @@ -914,10 +895,8 @@ def do(self, event: typing.Dict): # Get feature names and label columns if endpoint_id not in self.feature_names: endpoint_record = get_endpoint_record( - kv_container=self.kv_container, - kv_path=self.kv_path, + project=self.project, endpoint_id=endpoint_id, - access_key=self.access_key, ) feature_names = endpoint_record.get(EventFieldType.FEATURE_NAMES) feature_names = json.loads(feature_names) if feature_names else None @@ -940,15 +919,12 @@ def do(self, event: typing.Dict): ] # Update the endpoint record with the generated features - mlrun.utils.v3io_clients.get_v3io_client().kv.update( - container=self.kv_container, - table_path=self.kv_path, - access_key=self.access_key, - key=event[EventFieldType.ENDPOINT_ID], + update_endpoint_record( + project=self.project, + endpoint_id=endpoint_id, attributes={ EventFieldType.FEATURE_NAMES: json.dumps(feature_names) }, - raise_for_status=v3io.dataplane.RaiseForStatus.always, ) # Similar process with label columns @@ -963,15 +939,11 @@ def do(self, 
event: typing.Dict): label_columns = [ f"p{i}" for i, _ in enumerate(event[EventFieldType.PREDICTION]) ] - mlrun.utils.v3io_clients.get_v3io_client().kv.update( - container=self.kv_container, - table_path=self.kv_path, - access_key=self.access_key, - key=event[EventFieldType.ENDPOINT_ID], - attributes={ - EventFieldType.LABEL_COLUMNS: json.dumps(label_columns) - }, - raise_for_status=v3io.dataplane.RaiseForStatus.always, + + update_endpoint_record( + project=self.project, + endpoint_id=endpoint_id, + attributes={EventFieldType.LABEL_NAMES: json.dumps(label_columns)}, ) self.label_columns[endpoint_id] = label_columns @@ -1033,33 +1005,24 @@ def _map_dictionary_values( event[mapping_dictionary][name] = value -class WriteToKV(mlrun.feature_store.steps.MapClass): - def __init__(self, container: str, table: str, v3io_access_key: str, **kwargs): +class UpdateEndpoint(mlrun.feature_store.steps.MapClass): + def __init__(self, project: str, model_endpoint_store_target: str, **kwargs): """ - Writes the event to KV table. Note that the event at this point includes metadata and stats about the - average latency and the amount of predictions over time. This data will be used in the monitoring dashboards + Update the model endpoint record in the DB. Note that the event at this point includes metadata and stats about + the average latency and the amount of predictions over time. This data will be used in the monitoring dashboards such as "Model Monitoring - Performance" which can be found in Grafana. - :param kv_container: Name of the container that will be used to retrieve the endpoint id. For model - endpoints it is usually 'users'. - :param table: KV table path that will be used to retrieve the endpoint id. For model endpoints - it is usually pipelines/project-name/model-endpoints/endpoints/. - :param v3io_access_key: Access key with permission to read from a KV table. - :returns: Event as a dictionary (without any changes) for the next step (InferSchema). 
""" super().__init__(**kwargs) - self.container = container - self.table = table - self.v3io_access_key = v3io_access_key + self.project = project + self.model_endpoint_store_target = model_endpoint_store_target def do(self, event: typing.Dict): - mlrun.utils.v3io_clients.get_v3io_client().kv.update( - container=self.container, - table_path=self.table, - key=event[EventFieldType.ENDPOINT_ID], + update_endpoint_record( + project=self.project, + endpoint_id=event.pop(EventFieldType.ENDPOINT_ID), attributes=event, - access_key=self.v3io_access_key, ) return event @@ -1067,7 +1030,6 @@ def do(self, event: typing.Dict): class InferSchema(mlrun.feature_store.steps.MapClass): def __init__( self, - v3io_access_key: str, v3io_framesd: str, container: str, table: str, @@ -1087,7 +1049,6 @@ def __init__( """ super().__init__(**kwargs) self.container = container - self.v3io_access_key = v3io_access_key self.v3io_framesd = v3io_framesd self.table = table self.keys = set() @@ -1098,34 +1059,29 @@ def do(self, event: typing.Dict): self.keys.update(key_set) # Apply infer_schema on the kv table for generating the schema file mlrun.utils.v3io_clients.get_frames_client( - token=self.v3io_access_key, container=self.container, address=self.v3io_framesd, ).execute(backend="kv", table=self.table, command="infer_schema") + return event -def get_endpoint_record( - kv_container: str, kv_path: str, endpoint_id: str, access_key: str -) -> typing.Optional[dict]: - logger.info( - "Grabbing endpoint data", - container=kv_container, - table_path=kv_path, - key=endpoint_id, +def update_endpoint_record( + project: str, + endpoint_id: str, + attributes: dict, +): + model_endpoint_store = get_model_endpoint_store( + project=project, ) - try: - endpoint_record = ( - mlrun.utils.v3io_clients.get_v3io_client() - .kv.get( - container=kv_container, - table_path=kv_path, - key=endpoint_id, - access_key=access_key, - raise_for_status=v3io.dataplane.RaiseForStatus.always, - ) - .output.item - ) - return 
endpoint_record - except Exception: - return None + + model_endpoint_store.update_model_endpoint( + endpoint_id=endpoint_id, attributes=attributes + ) + + +def get_endpoint_record(project: str, endpoint_id: str): + model_endpoint_store = get_model_endpoint_store( + project=project, + ) + return model_endpoint_store.get_model_endpoint(endpoint_id=endpoint_id) diff --git a/mlrun/package/__init__.py b/mlrun/package/__init__.py new file mode 100644 index 000000000000..71680331cbc3 --- /dev/null +++ b/mlrun/package/__init__.py @@ -0,0 +1,163 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx + +import functools +import inspect +from collections import OrderedDict +from typing import Callable, Dict, List, Type, Union + +from ..config import config +from .context_handler import ContextHandler +from .errors import ( + MLRunPackageCollectionError, + MLRunPackageError, + MLRunPackagePackingError, + MLRunPackageUnpackingError, +) +from .packager import Packager +from .packagers import DefaultPackager +from .packagers_manager import PackagersManager +from .utils import ( + ArchiveSupportedFormat, + ArtifactType, + LogHintKey, + StructFileSupportedFormat, +) + + +def handler( + labels: Dict[str, str] = None, + outputs: List[Union[str, Dict[str, str]]] = None, + inputs: Union[bool, Dict[str, Union[str, Type]]] = True, +): + """ + MLRun's handler is a decorator to wrap a function and enable setting labels, parsing inputs (`mlrun.DataItem`) using + type hints and log returning outputs using log hints. + + Notice: this decorator is now appplied automatically with the release of `mlrun.package`. It should not be used + manually. + + :param labels: Labels to add to the run. Expecting a dictionary with the labels names as keys. Default: None. + :param outputs: Log hints (logging configurations) for the function's returned values. Expecting a list of the + following values: + + * `str` - A string in the format of '{key}:{artifact_type}'. If a string was given without ':' it + will indicate the key, and the artifact type will be according to the returned value type's + default artifact type. The artifact types supported are listed in the relevant type packager. + * `Dict[str, str]` - A dictionary of logging configuration. the key 'key' is mandatory for the + logged artifact key. + * None - Do not log the output. + + If the list length is not equal to the total amount of returned values from the function, those + without log hints will be ignored. 
+ + Default: None - meaning no outputs will be logged. + + :param inputs: Type hints (parsing configurations) for the arguments passed as inputs via the `run` method of an + MLRun function. Can be passed as a boolean value or a dictionary: + + * True - Parse all found inputs to the assigned type hint in the function's signature. If there is no + type hint assigned, the value will remain an `mlrun.DataItem`. + * False - Do not parse inputs, leaving the inputs as `mlrun.DataItem`. + * Dict[str, Union[Type, str]] - A dictionary with argument name as key and the expected type to parse + the `mlrun.DataItem` to. The expected type can be a string as well, idicating the full module path. + + Default: True - meaning inputs will be parsed from `DataItem`s as long as they are type hinted. + + Example:: + + import mlrun + + @mlrun.handler( + outputs=[ + "my_string", + None, + {"key": "my_array", "artifact_type": "file", "file_format": "npy"}, + "my_multiplier: reuslt" + ] + ) + def my_handler(array: np.ndarray, m: int): + m += 1 + array = array * m + return "I will be logged", "I won't be logged", array, m + + >>> mlrun_function = mlrun.code_to_function("my_code.py", kind="job") + >>> run_object = mlrun_function.run( + ... handler="my_handler", + ... inputs={"array": "store://my_array_Artifact"}, + ... params={"m": 2} + ... 
) + >>> run_object.outputs + {'my_string': 'I will be logged', 'my_array': 'store://...', 'my_multiplier': 3} + """ + + def decorator(func: Callable): + def wrapper(*args: tuple, **kwargs: dict): + nonlocal labels + nonlocal outputs + nonlocal inputs + + # Set default `inputs` - inspect the full signature and add the user's input on top of it: + if inputs: + # Get the available parameters type hints from the function's signature: + func_signature = inspect.signature(func) + parameters = OrderedDict( + { + parameter.name: parameter.annotation + for parameter in func_signature.parameters.values() + } + ) + # If user input is given, add it on top of the collected defaults (from signature): + if isinstance(inputs, dict): + parameters.update(inputs) + inputs = parameters + + # Create a context handler and look for a context: + cxt_handler = ContextHandler() + cxt_handler.look_for_context(args=args, kwargs=kwargs) + + # If an MLRun context is found, parse arguments pre-run (kwargs are parsed inplace): + if cxt_handler.is_context_available() and inputs: + args = cxt_handler.parse_inputs( + args=args, kwargs=kwargs, type_hints=inputs + ) + + # Call the original function and get the returning values: + func_outputs = func(*args, **kwargs) + + # If an MLRun context is found, set the given labels and log the returning values to MLRun via the context: + if cxt_handler.is_context_available(): + if labels: + # TODO: Should deprecate this labels + cxt_handler.set_labels(labels=labels) + if outputs: + cxt_handler.log_outputs( + outputs=func_outputs + if type(func_outputs) is tuple + and not config.packagers.pack_tuples + else [func_outputs], + log_hints=outputs, + ) + return # Do not return any values as the returning values were logged to MLRun. 
+ return func_outputs + + # Make sure to pass the wrapped function's signature (argument list, type hints and doc strings) to the wrapper: + wrapper = functools.wraps(func)(wrapper) + + return wrapper + + return decorator diff --git a/mlrun/package/context_handler.py b/mlrun/package/context_handler.py new file mode 100644 index 000000000000..f193c9b4c277 --- /dev/null +++ b/mlrun/package/context_handler.py @@ -0,0 +1,325 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import inspect +import os +from collections import OrderedDict +from typing import Dict, List, Union + +from mlrun.datastore import DataItem +from mlrun.errors import MLRunInvalidArgumentError +from mlrun.execution import MLClientCtx +from mlrun.run import get_or_create_ctx + +from .errors import MLRunPackageCollectionError, MLRunPackagePackingError +from .packagers_manager import PackagersManager +from .utils import ArtifactType, LogHintKey, LogHintUtils, TypeHintUtils + + +class ContextHandler: + """ + A class for handling a MLRun context of a function that is wrapped in MLRun's `handler` decorator. + + The context handler have 3 duties: + 1. Check if the user used MLRun to run the wrapped function and if so, get the MLRun context. + 2. Parse the user's inputs (MLRun `DataItem`) to the function. + 3. Log the function's outputs to MLRun. + + The context handler uses a packagers manager to unpack (parse) the inputs and pack (log) the outputs. 
It sets up a + manager with all the packagers in the `mlrun.package.packagers` directory. Packagers whom are in charge of modules + that are in the MLRun requirements are mandatory and additional extensions packagers for non-required modules are + added if the modules are available in the user's interpreter. Once a context is found, project custom packagers will + be added as well. + """ + + # Mandatory packagers to be collected at initialization time: + _MLRUN_REQUIREMENTS_PACKAGERS = [ + "python_standard_library", + "pandas", + "numpy", + ] + # Optional packagers to be collected at initialization time: + _EXTENDED_PACKAGERS = [] # TODO: Create "matplotlib", "plotly", "bokeh" packagers. + # Optional packagers from the `mlrun.frameworks` package: + _MLRUN_FRAMEWORKS_PACKAGERS = [] # TODO: Create frameworks packagers. + # Default priority values for packagers: + _BUILTIN_PACKAGERS_DEFAULT_PRIORITY = 5 + _CUSTOM_PACKAGERS_DEFAULT_PRIORITY = 3 + + def __init__(self): + """ + Initialize a context handler. + """ + # Set up a variable to hold the context: + self._context: MLClientCtx = None + + # Initialize a packagers manager: + self._packagers_manager = PackagersManager() + + # Prepare the manager (collect the MLRun builtin standard and optional packagers): + self._collect_mlrun_packagers() + + def look_for_context(self, args: tuple, kwargs: dict): + """ + Look for an MLRun context (`mlrun.MLClientCtx`). The handler will look for a context in the given order: + 1. The given arguments. + 2. The given keyword arguments. + 3. If an MLRun RunTime was used the context will be located via the `mlrun.get_or_create_ctx` method. + + :param args: The arguments tuple passed to the function. + :param kwargs: The keyword arguments dictionary passed to the function. 
+ """ + # Search in the given arguments: + for argument in args: + if isinstance(argument, MLClientCtx): + self._context = argument + break + + # Search in the given keyword arguments: + if self._context is None: + for argument_name, argument_value in kwargs.items(): + if isinstance(argument_value, MLClientCtx): + self._context = argument_value + break + + # Search if the function was triggered from an MLRun RunTime object by looking at the call stack: + # Index 0: the current frame. + # Index 1: the decorator's frame. + # Index 2-...: If it is from mlrun.runtimes we can be sure it ran via MLRun, otherwise not. + if self._context is None: + for callstack_frame in inspect.getouterframes(inspect.currentframe()): + if ( + os.path.join("mlrun", "runtimes", "local") + in callstack_frame.filename + ): + self._context = get_or_create_ctx("context") + break + + # Give the packagers manager custom packagers to collect (if a context is found and a project is available): + if self._context is not None and self._context.project: + # Get the custom packagers property from the project's spec: + project = self._context.get_project_object() + if project and project.spec.custom_packagers: + # Add the custom packagers taking into account the mandatory flag: + for custom_packager, is_mandatory in project.spec.custom_packagers: + self._collect_packagers( + packagers=[custom_packager], + is_mandatory=is_mandatory, + is_custom_packagers=True, + ) + + def is_context_available(self) -> bool: + """ + Check if a context was found by the method `look_for_context`. + + :returns: True if a context was found and False otherwise. + """ + return self._context is not None + + def parse_inputs( + self, + args: tuple, + kwargs: dict, + type_hints: OrderedDict, + ) -> tuple: + """ + Parse the given arguments and keyword arguments data items to the expected types. + + :param args: The arguments tuple passed to the function. + :param kwargs: The keyword arguments dictionary passed to the function. 
+ :param type_hints: An ordered dictionary of the expected types of arguments. + + :returns: The parsed args (kwargs are parsed inplace). + """ + # Parse the type hints (in case some were given as strings): + type_hints = { + key: TypeHintUtils.parse_type_hint(type_hint=value) + for key, value in type_hints.items() + } + + # Parse the arguments: + parsed_args = [] + type_hints_keys = list(type_hints.keys()) + for i, argument in enumerate(args): + if ( + isinstance(argument, DataItem) + and type_hints[type_hints_keys[i]] is not inspect.Parameter.empty + ): + parsed_args.append( + self._packagers_manager.unpack( + data_item=argument, + type_hint=type_hints[type_hints_keys[i]], + ) + ) + else: + parsed_args.append(argument) + parsed_args = tuple(parsed_args) # `args` is expected to be a tuple. + + # Parse the keyword arguments: + for key, value in kwargs.items(): + if ( + isinstance(value, DataItem) + and type_hints[key] is not inspect.Parameter.empty + ): + kwargs[key] = self._packagers_manager.unpack( + data_item=value, type_hint=type_hints[key] + ) + + return parsed_args + + def log_outputs( + self, + outputs: list, + log_hints: List[Union[Dict[str, str], str, None]], + ): + """ + Log the given outputs as artifacts (or results) with the stored context. Errors raised during the packing will + be ignored to not fail a run. A warning with the error wil be printed. + + :param outputs: List of outputs to log. + :param log_hints: List of log hints (logging configurations) to use. + """ + # Verify the outputs and log hints are the same length: + if len(outputs) != len(log_hints): + self._context.logger.warn( + f"The amount of outputs objects returned from the function ({len(outputs)}) does not match the amount " + f"of provided log hints ({len(log_hints)})." 
+ ) + if len(outputs) > len(log_hints): + ignored_outputs = [str(output) for output in outputs[len(log_hints) :]] + self._context.logger.warn( + f"The following outputs will not be logged: {', '.join(ignored_outputs)}" + ) + if len(outputs) < len(log_hints): + ignored_log_hints = [ + str(log_hint) for log_hint in log_hints[len(outputs) :] + ] + self._context.logger.warn( + f"The following log hints will be ignored: {', '.join(ignored_log_hints)}" + ) + + # Go over the outputs and pack them: + for obj, log_hint in zip(outputs, log_hints): + try: + # Check if needed to log (not None): + if log_hint is None: + continue + # Parse the log hint: + log_hint = LogHintUtils.parse_log_hint(log_hint=log_hint) + # Check if the object to log is None (None values are only logged if the artifact type is Result): + if ( + obj is None + and log_hint.get(LogHintKey.ARTIFACT_TYPE, ArtifactType.RESULT) + != ArtifactType.RESULT + ): + continue + # Pack the object (we don't catch the returned package as we log it after we pack all the outputs to + # enable linking extra data of some artifacts): + self._packagers_manager.pack(obj=obj, log_hint=log_hint) + except (MLRunInvalidArgumentError, MLRunPackagePackingError) as error: + self._context.logger.warn( + f"Skipping logging an object with the log hint '{log_hint}' due to the following error:\n{error}" + ) + + # Link packages: + self._packagers_manager.link_packages( + additional_artifacts=self._context.artifacts, + additional_results=self._context.results, + ) + + # Log the packed results and artifacts: + self._context.log_results(results=self._packagers_manager.results) + for artifact in self._packagers_manager.artifacts: + self._context.log_artifact(item=artifact) + + # Clear packagers outputs: + self._packagers_manager.clear_packagers_outputs() + + def set_labels(self, labels: Dict[str, str]): + """ + Set the given labels with the stored context. + + :param labels: The labels to set. 
+ """ + for key, value in labels.items(): + self._context.set_label(key=key, value=value) + + def _collect_packagers( + self, packagers: List[str], is_mandatory: bool, is_custom_packagers: bool + ): + """ + Collect packagers with the stored manager. The collection can ignore errors raised by setting the mandatory flag + to False. + + :param packagers: The list of packagers to collect. + :param is_mandatory: Whether the packagers are mandatory for the context run. + :param is_custom_packagers: Whether the packagers to collect are user's custom or MLRun's builtins. + """ + try: + self._packagers_manager.collect_packagers( + packagers=packagers, + default_priority=self._CUSTOM_PACKAGERS_DEFAULT_PRIORITY + if is_custom_packagers + else self._BUILTIN_PACKAGERS_DEFAULT_PRIORITY, + ) + except MLRunPackageCollectionError as error: + if is_mandatory: + raise error + else: + # If the packagers to collect were added manually by the user, the logger should write the collection + # issue as a warning. Otherwise - for mlrun builtin packagers, a debug message will do. + message = ( + f"The given optional packagers '{packagers}' could not be imported due to the following error:\n" + f"'{error}'" + ) + if is_custom_packagers: + self._context.logger.warn(message) + else: + self._context.logger.debug(message) + + def _collect_mlrun_packagers(self): + """ + Collect MLRun's builtin packagers. That include all mandatory packagers whom in charge of MLRun's requirements + libraries, more optional commonly used libraries packagers and more `mlrun.frameworks` packagers. The priority + will be as follows (from higher to lower priority): + + 1. Optional `mlrun.frameworks` packagers + 2. MLRun's optional packagers + 3. 
MLRun's mandatory packagers (MLRun's requirements) + """ + # Collect MLRun's requirements packagers (mandatory): + self._collect_packagers( + packagers=[ + f"mlrun.package.packagers.{module_name}_packagers.*" + for module_name in self._MLRUN_REQUIREMENTS_PACKAGERS + ], + is_mandatory=True, + is_custom_packagers=False, + ) + + # Add extra packagers for optional libraries: + for module_name in self._EXTENDED_PACKAGERS: + self._collect_packagers( + packagers=[f"mlrun.package.packagers.{module_name}_packagers.*"], + is_mandatory=False, + is_custom_packagers=False, + ) + + # Add extra packagers from `mlrun.frameworks` package: + for module_name in self._MLRUN_FRAMEWORKS_PACKAGERS: + self._collect_packagers( + packagers=[f"mlrun.frameworks.{module_name}.packagers.*"], + is_mandatory=False, + is_custom_packagers=False, + ) diff --git a/mlrun/package/errors.py b/mlrun/package/errors.py new file mode 100644 index 000000000000..8ab0f119fe9e --- /dev/null +++ b/mlrun/package/errors.py @@ -0,0 +1,47 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from mlrun.errors import MLRunBaseError + + +class MLRunPackageError(MLRunBaseError): + """ + General error from `mlrun.package`. + """ + + pass + + +class MLRunPackageCollectionError(MLRunPackageError): + """ + An error that may be raised during the collection of packagers the manager is assigned to do. 
+ """ + + pass + + +class MLRunPackagePackingError(MLRunPackageError): + """ + An error that may be raised during a `mlrun.Packager.pack` method. + """ + + pass + + +class MLRunPackageUnpackingError(MLRunPackageError): + """ + An error that may be raised during a `mlrun.Packager.unpack` method. + """ + + pass diff --git a/mlrun/package/packager.py b/mlrun/package/packager.py new file mode 100644 index 000000000000..e6001fe89c42 --- /dev/null +++ b/mlrun/package/packager.py @@ -0,0 +1,298 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from abc import ABC, ABCMeta, abstractmethod +from pathlib import Path +from typing import Any, List, Tuple, Type, Union + +from mlrun.artifacts import Artifact +from mlrun.datastore import DataItem + +from .utils import TypeHintUtils + + +# TODO: When 3.7 is no longer supported, add "Packager" as reference type hint to cls (cls: Type["Packager"]) and other. +class _PackagerMeta(ABCMeta): + """ + Metaclass for `Packager` to override type class methods. + """ + + def __lt__(cls, other) -> bool: + """ + A less than implementation to compare by priority in order to be able to sort the packagers by it. + + :param other: The compared packager. + + :return: True if priority is lower (means better) and False otherwise. 
+ """ + return cls.PRIORITY < other.PRIORITY + + def __repr__(cls) -> str: + """ + Get the string representation of a packager in the following format: + (type=, artifact_types=[], priority=) + + :return: The string representation of e packager. + """ + # Get the packager info into variables: + packager_name = cls.__name__ + handled_type = ( + ( + # Types have __name__ attribute but typing's types do not. + cls.PACKABLE_OBJECT_TYPE.__name__ + if hasattr(cls.PACKABLE_OBJECT_TYPE, "__name__") + else str(cls.PACKABLE_OBJECT_TYPE) + ) + if cls.PACKABLE_OBJECT_TYPE is not ... + else "Any" + ) + supported_artifact_types = cls.get_supported_artifact_types() + + # Return the string representation in the format noted above: + return ( + f"{packager_name}(packable_type={handled_type}, artifact_types={supported_artifact_types}, " + f"priority={cls.PRIORITY})" + ) + + +class Packager(ABC, metaclass=_PackagerMeta): + """ + The abstract base class for a packager. A packager is a static class that have two main duties: + + 1. Packing - get an object that was returned from a function and log it to MLRun. The user can specify packing + configurations to the packager using log hints. The packed object can be an artifact or a result. + 2. Unpacking - get a ``mlrun.DataItem`` (an input to a MLRun function) and parse it to the desired hinted type. The + packager is using the instructions it noted itself when originally packing the object. + + The Packager has one class variable and five class methods that must be implemented: + + * ``PACKABLE_OBJECT_TYPE`` - A class variable to specify the object type this packager handles. Used for the + ``is_packable`` and ``repr`` methods. An ellipses (`...`) means any type. + * ``PRIORITY`` - The priority of this packager among the rest of the packagers. Should be an integer between 1-10 + where 1 is the highest priority and 10 is the lowest. If not set, a default priority of 5 is set for MLRun + builtin packagers and 3 for user custom packagers. 
+ * ``get_default_packing_artifact_type`` - A class method to get the default artifact type for packing an object + when it is not provided by the user. + * ``get_default_unpacking_artifact_type`` - A class method to get the default artifact type for unpacking a data + item when it is not representing a package, but a simple url or an old / manually logged artifact + * ``get_supported_artifact_types`` - A class method to get the supported artifact types this packager can pack an + object as. Used for the ``is_packable`` and `repr` methods. + * ``pack`` - A class method to pack a returned object using the provided log hint configurations while noting itself + instructions for how to unpack it once needed (only relevant of packed artifacts as results do not need + unpacking). + * ``unpack`` - A class method to unpack a MLRun ``DataItem``, parsing it to its desired hinted type using the + instructions noted while originally packing it. + + The class methods ``is_packable`` and ``is_unpackable`` are implemented with the following basic logic: + + * ``is_packable`` - a class method to know whether to use this packager to pack an object by its + type and artifact type, compares the object's type with the ``PACKABLE_OBJECT_TYPE`` and checks the artifact type + is in the returned supported artifacts list from ``get_supported_artifact_types``. + * ``is_unpackable`` - a class method to know whether to use this packager to unpack a data item by the user noted + type hint and optionally stored artifact type in the data item (in case it was packaged before), matches the + ``PACKABLE_OBJECT_TYPE`` to the type hint given (same logic as IDE matchups, meaning subclasses considered as + unpackable) and checks if the artifact type is in the returned supported artifacts list from + ``get_supported_artifact_types``. + + Preferably, each packager should handle a single type of object. 
+ + Linking Artifacts (extra data) + ------------------------------ + + In order to link between packages (using the extra data or metrics spec attributes of an artifact), you should use + the key as if it exists and as value ellipses (...). The manager will link all packages once it is done packing. + + For example, given extra data keys in the log hint as `extra_data`, setting them to an artifact should be:: + + artifact = Artifact(key="my_artifact") + artifact.spec.extra_data = {key: ... for key in extra_data} + + Clearing Outputs + ---------------- + + Some of the packagers may produce files and temporary directories that should be deleted once done with logging the + artifact. The packager can mark paths of files and directories to delete after logging using the class method + ``future_clear``. + + For example, in the following packager's ``pack`` method we can write a text file, create an Artifact and then mark + the text file to be deleted once the artifact is logged:: + + with open("./some_file.txt", "w") as file: + file.write("Pack me") + artifact = Artifact(key="my_artifact") + cls.future_clear(path="./some_file.txt") + return artifact, None + """ + + # The type of object this packager can pack and unpack: + PACKABLE_OBJECT_TYPE: Type = ... + + # The priority of this packager in the packagers collection of the manager (lower is better) + PRIORITY = ... + + # List of all paths to be deleted by the manager of this packager post logging the packages: + _CLEARING_PATH_LIST: List[str] = [] + + @classmethod + @abstractmethod + def get_default_packing_artifact_type(cls, obj: Any) -> str: + """ + Get the default artifact type used for packing. The method will be used when an object is sent for packing + without an artifact type noted by the user. + + :param obj: The about to be packed object. + + :return: The default artifact type. 
+ """ + pass + + @classmethod + @abstractmethod + def get_default_unpacking_artifact_type(cls, data_item: DataItem) -> str: + """ + Get the default artifact type used for unpacking a data item holding an object of this packager. The method will + be used when a data item is sent for unpacking without it being a package, but a simple url or an old / manually + logged artifact. + + :param data_item: The about to be unpacked data item. + + :return: The default artifact type. + """ + pass + + @classmethod + @abstractmethod + def get_supported_artifact_types(cls) -> List[str]: + """ + Get all the supported artifact types on this packager. + + :return: A list of all the supported artifact types. + """ + pass + + @classmethod + @abstractmethod + def pack( + cls, obj: Any, artifact_type: str = None, configurations: dict = None + ) -> Union[Tuple[Artifact, dict], dict]: + """ + Pack an object as the given artifact type using the provided configurations. + + :param obj: The object to pack. + :param artifact_type: Artifact type to log to MLRun. + :param configurations: Log hints configurations to pass to the packing method. + + :return: If the packed object is an artifact, a tuple of the packed artifact and unpacking instructions + dictionary. If the packed object is a result, a dictionary containing the result key and value. + """ + pass + + @classmethod + @abstractmethod + def unpack( + cls, + data_item: DataItem, + artifact_type: str = None, + instructions: dict = None, + ) -> Any: + """ + Unpack the data item's artifact by the provided type using the given instructions. + + :param data_item: The data input to unpack. + :param artifact_type: The artifact type to unpack the data item as. + :param instructions: Additional instructions noted in the package to pass to the unpacking method. + + :return: The unpacked data item's object. 
+ """ + pass + + @classmethod + def is_packable(cls, obj: Any, artifact_type: str = None) -> bool: + """ + Check if this packager can pack an object of the provided type as the provided artifact type. + + The default implementation check if the packable object type of this packager is equal to the given object's + type. If it does match, it will look for the artifact type in the list returned from + `get_supported_artifact_types`. + + :param obj: The object to pack. + :param artifact_type: The artifact type to log the object as. + + :return: True if packable and False otherwise. + """ + # Get the object's type: + object_type = type(obj) + + # Validate the object type (ellipses means any type): + if ( + cls.PACKABLE_OBJECT_TYPE is not ... + and object_type != cls.PACKABLE_OBJECT_TYPE + ): + return False + + # Validate the artifact type (if given): + if artifact_type and artifact_type not in cls.get_supported_artifact_types(): + return False + + return True + + @classmethod + def is_unpackable( + cls, data_item: DataItem, type_hint: Type, artifact_type: str = None + ) -> bool: + """ + Check if this packager can unpack an input according to the user given type hint and the provided artifact type. + + The default implementation tries to match the packable object type of this packager to the given type hint, if + it does match, it will look for the artifact type in the list returned from `get_supported_artifact_types`. + + :param data_item: The input data item to check if unpackable. + :param type_hint: The type hint of the input to unpack. + :param artifact_type: The artifact type to unpack the object as. + + :return: True if unpackable and False otherwise. 
+ """ + # Check type (ellipses means any type): + if cls.PACKABLE_OBJECT_TYPE is not ...: + if not TypeHintUtils.is_matching( + object_type=cls.PACKABLE_OBJECT_TYPE, + type_hint=type_hint, + reduce_type_hint=False, + ): + return False + + # Check the artifact type: + if artifact_type and artifact_type not in cls.get_supported_artifact_types(): + return False + + # Unpackable: + return True + + @classmethod + def add_future_clearing_path(cls, path: Union[str, Path]): + """ + Mark a path to be cleared by this packager's manager post logging the packaged artifacts. + + :param path: The path to clear. + """ + cls._CLEARING_PATH_LIST.append(str(path)) + + @classmethod + def get_future_clearing_path_list(cls) -> List[str]: + """ + Get the packager's future clearing path list. + + :return: The clearing path list. + """ + return cls._CLEARING_PATH_LIST diff --git a/mlrun/package/packagers/__init__.py b/mlrun/package/packagers/__init__.py new file mode 100644 index 000000000000..5cdf7bf6df62 --- /dev/null +++ b/mlrun/package/packagers/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx +from .default_packager import DefaultPackager +from .numpy_packagers import NumPySupportedFormat diff --git a/mlrun/package/packagers/default_packager.py b/mlrun/package/packagers/default_packager.py new file mode 100644 index 000000000000..b2b37d795427 --- /dev/null +++ b/mlrun/package/packagers/default_packager.py @@ -0,0 +1,422 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import inspect +from types import MethodType +from typing import Any, List, Tuple, Type, Union + +from mlrun.artifacts import Artifact +from mlrun.datastore import DataItem +from mlrun.utils import logger + +from ..errors import MLRunPackagePackingError, MLRunPackageUnpackingError +from ..packager import Packager +from ..utils import DEFAULT_PICKLE_MODULE, ArtifactType, Pickler, TypeHintUtils + + +class DefaultPackager(Packager): + """ + A default packager that handles all types and pack them as pickle files. + + The default packager implements all the required methods and have a default logic that should be satisfying most + use cases. In order to work with this class, you shouldn't override the abstract class methods, but follow the + guidelines below: + + * The class variable ``PACKABLE_OBJECT_TYPE``: The type of object this packager can pack and unpack (used in the + ``is_packable`` method). 
+ * The class variable ``PACK_SUBCLASSES``: A flag that indicates whether to pack all subclasses of the + ``PACKABLE_OBJECT_TYPE` (used in the ``is_packable`` method). Default is False. + * The class variable ``DEFAULT_PACKING_ARTIFACT_TYPE``: The default artifact type to pack as. It is being returned + from the method ``get_default_packing_artifact_type`` + * The class variable ``DEFAULT_UNPACKING_ARTIFACT_TYPE``: The default artifact type to unpack from. It is being + returned from the method ``get_default_unpacking_artifact_type``. + * The abstract class method ``pack``: The method is implemented to get the object and send it to the relevant + packing method by the artifact type given using the following naming: "pack_". (if artifact type + was not provided, the default one will be used). For example: if the artifact type is "object" then the class + method ``pack_object`` must be implemented. The signature of each pack class method must be:: + + @classmethod + def pack_x(cls, obj: Any, ...) -> Union[Tuple[Artifact, dict], dict]: + pass + + Where 'x' is the artifact type, 'obj' is the object to pack, ... are additional custom log hint configurations and + the returning values are the packed artifact and the instructions for unpacking it, or in case of result, the + dictionary of the result with its key and value. The log hint configurations are sent by the user and shouldn't be + mandatory, meaning they should have a default value (otherwise, the user will have to add them to every log hint). + * The abstract class method ``unpack``: The method is implemented to get a ``DataItem`` and send it to the relevant + unpacking method by the artifact type using the following naming: "unpack_" (if artifact type was + not provided, the default one will be used). For example: if the artifact type stored within the ``DataItem`` is + "object" then the class method ``unpack_object`` must be implemented. 
The signature of each unpack class method + must be:: + + @classmethod + def unpack_x(cls, data_item: mlrun.DataItem, ...) -> Any: + pass + + Where 'x' is the artifact type, 'data_item' is the artifact's data item to unpack, ... are the instructions that + were originally returned from ``pack_x`` (Each instruction must be optional (have a default value) to support + objects from this type that were not packaged but customly logged) and the returning value is the unpacked + object. + * The abstract class method ``is_packable``: The method is implemented to validate the object type and artifact type + automatically by the following rules: + + * Object type validation: Checking if the object type given match to the variable ``PACKABLE_OBJECT_TYPE`` with + respect to the ``PACK_SUBCLASSES`` class variable. + * Artifact type validation: Checking if the artifact type given is in the list returned from + ``get_supported_artifact_types``. + + * The abstract class method ``is_unpackable``: The method is left as implemented in ``Packager``. + * The abstract class method ``get_supported_artifact_types``: The method is implemented to look for all + pack + unpack class methods implemented to collect the supported artifact types. If ``PackagerX`` has ``pack_y``, + ``unpack_y`` and ``pack_z``, ``unpack_z`` that means the artifact types supported are 'y' and 'z'. + * The abstract class method ``get_default_packing_artifact_type``: The method is implemented to return the new class + variable ``DEFAULT_PACKING_ARTIFACT_TYPE``. You may still override the method if the default artifact type you + need may change according to the object that's about to be packed. + * The abstract class method ``get_default_unpacking_artifact_type``: The method is implemented to return the new + class variable ``DEFAULT_UNPACKING_ARTIFACT_TYPE``. You may still override the method if the default artifact type + you need may change according to the data item that's about to be unpacked. 
+ + Important to remember (from the ``Packager`` docstring): + + * Linking artifacts ("extra data"): In order to link between packages (using the extra data or metrics spec + attributes of an artifact), you should use the key as if it exists and as value ellipses (...). The manager will + link all packages once it is done packing. + + For example, given extra data keys in the log hint as `extra_data`, setting them to an artifact should be:: + + artifact = Artifact(key="my_artifact") + artifact.spec.extra_data = {key: ... for key in extra_data} + + * Clearing outputs: Some packagers may produce files and temporary directories that should be deleted once done with + logging the artifact. The packager can mark paths of files and directories to delete after logging using the class + method ``future_clear``. + + For example, in the following packager's ``pack`` method we can write a text file, create an Artifact and then + mark the text file to be deleted once the artifact is logged:: + + with open("./some_file.txt", "w") as file: + file.write("Pack me") + artifact = Artifact(key="my_artifact") + cls.future_clear(path="./some_file.txt") + return artifact, None + """ + + # The type of object this packager can pack and unpack: + PACKABLE_OBJECT_TYPE: Type = ... + # A flag for indicating whether to pack all subclasses of the `PACKABLE_OBJECT_TYPE` as well: + PACK_SUBCLASSES = False + # The default artifact type to pack as: + DEFAULT_PACKING_ARTIFACT_TYPE = ArtifactType.OBJECT + # The default artifact type to unpack from: + DEFAULT_UNPACKING_ARTIFACT_TYPE = ArtifactType.OBJECT + + @classmethod + def get_default_packing_artifact_type(cls, obj: Any) -> str: + """ + Get the default artifact type for packing an object of this packager. + + :param obj: The about to be packed object. + + :return: The default artifact type. 
+ """ + return cls.DEFAULT_PACKING_ARTIFACT_TYPE + + @classmethod + def get_default_unpacking_artifact_type(cls, data_item: DataItem) -> str: + """ + Get the default artifact type used for unpacking a data item holding an object of this packager. The method will + be used when a data item is sent for unpacking without it being a package, but a simple url or an old / manually + logged artifact. + + :param data_item: The about to be unpacked data item. + + :return: The default artifact type. + """ + return cls.DEFAULT_UNPACKING_ARTIFACT_TYPE + + @classmethod + def get_supported_artifact_types(cls) -> List[str]: + """ + Get all the supported artifact types on this packager. + + :return: A list of all the supported artifact types. + """ + # We look for pack + unpack method couples so there won't be a scenario where an object can be packed but not + # unpacked. Result has no unpacking so we add it separately. + return [ + key[len("pack_") :] + for key in dir(cls) + if key.startswith("pack_") and f"unpack_{key[len('pack_'):]}" in dir(cls) + ] + ["result"] + + @classmethod + def pack( + cls, + obj: Any, + artifact_type: str = None, + configurations: dict = None, + ) -> Union[Tuple[Artifact, dict], dict]: + """ + Pack an object as the given artifact type using the provided configurations. + + :param obj: The object to pack. + :param artifact_type: Artifact type to log to MLRun. If passing `None`, the default artifact type will be used. + :param configurations: Log hints configurations to pass to the packing method. + + :return: If the packed object is an artifact, a tuple of the packed artifact and unpacking instructions + dictionary. If the packed object is a result, a dictionary containing the result key and value. 
+ """ + # Get default artifact type in case it was not provided: + if artifact_type is None: + artifact_type = cls.get_default_packing_artifact_type(obj=obj) + + # Set empty dictionary in case no configurations were given: + configurations = configurations or {} + + # Get the packing method according to the artifact type: + pack_method = getattr(cls, f"pack_{artifact_type}") + + # Validate correct configurations were passed: + cls._validate_method_arguments( + method=pack_method, + arguments=configurations, + is_packing=True, + ) + + # Call the packing method and return the package: + return pack_method(obj, **configurations) + + @classmethod + def unpack( + cls, + data_item: DataItem, + artifact_type: str = None, + instructions: dict = None, + ) -> Any: + """ + Unpack the data item's artifact by the provided type using the given instructions. + + :param data_item: The data input to unpack. + :param artifact_type: The artifact type to unpack the data item as. If passing `None`, the default artifact type + will be used. + :param instructions: Additional instructions noted in the package to pass to the unpacking method. + + :return: The unpacked data item's object. + + :raise MLRunPackageUnpackingError: In case the packager could not unpack the data item. 
+ """ + # Get default artifact type in case it was not provided: + if artifact_type is None: + artifact_type = cls.get_default_unpacking_artifact_type(data_item=data_item) + + # Set empty dictionary in case no instructions were given: + instructions = instructions or {} + + # Get the unpacking method according to the artifact type: + unpack_method = getattr(cls, f"unpack_{artifact_type}") + + # Validate correct instructions were passed: + cls._validate_method_arguments( + method=unpack_method, + arguments=instructions, + is_packing=False, + ) + + # Call the unpacking method and return the object: + return unpack_method(data_item, **instructions) + + @classmethod + def is_packable(cls, obj: Any, artifact_type: str = None) -> bool: + """ + Check if this packager can pack an object of the provided type as the provided artifact type. + + The method is implemented to validate the object's type and artifact type by checking if the object type given + match to the variable ``PACKABLE_OBJECT_TYPE`` with respect to the ``PACK_SUBCLASSES`` class variable. If it + does, it will check if the artifact type given is in the list returned from ``get_supported_artifact_types``. + + :param obj: The object to pack. + :param artifact_type: The artifact type to log the object as. + + :return: True if packable and False otherwise. 
+ """ + # Get the object's type: + object_type = type(obj) + + # Check type (ellipses means any type): + if cls.PACKABLE_OBJECT_TYPE is not ...: + if not TypeHintUtils.is_matching( + object_type=object_type, + type_hint=cls.PACKABLE_OBJECT_TYPE, + include_subclasses=cls.PACK_SUBCLASSES, + reduce_type_hint=False, + ): + return False + + # Check the artifact type: + if ( + artifact_type is not None + and artifact_type not in cls.get_supported_artifact_types() + ): + return False + + # Packable: + return True + + @classmethod + def pack_object( + cls, + obj: Any, + key: str, + pickle_module_name: str = DEFAULT_PICKLE_MODULE, + ) -> Tuple[Artifact, dict]: + """ + Pack a python object, pickling it into a pkl file and store it in an artifact. + + :param obj: The object to pack and log. + :param key: The artifact's key. + :param pickle_module_name: The pickle module name to use for serializing the object. + + :return: The artifacts and it's pickling instructions. + """ + # Pickle the object to file: + pickle_path, instructions = Pickler.pickle( + obj=obj, pickle_module_name=pickle_module_name + ) + + # Initialize an artifact to the pkl file: + artifact = Artifact(key=key, src_path=pickle_path) + + # Add the pickle path to the clearing list: + cls.add_future_clearing_path(path=pickle_path) + + return artifact, instructions + + @classmethod + def pack_result(cls, obj: Any, key: str) -> dict: + """ + Pack an object as a result. + + :param obj: The object to pack and log. + :param key: The result's key. + + :return: The result dictionary. + """ + return {key: obj} + + @classmethod + def unpack_object( + cls, + data_item: DataItem, + pickle_module_name: str = DEFAULT_PICKLE_MODULE, + object_module_name: str = None, + python_version: str = None, + pickle_module_version: str = None, + object_module_version: str = None, + ) -> Any: + """ + Unpack the data item's object, unpickle it using the instructions and return. 
+ + Warnings of mismatching python and module versions between the original pickling interpreter and this one may be + raised. + + :param data_item: The data item holding the pkl file. + :param pickle_module_name: Module to use for unpickling the object. + :param object_module_name: The original object's module. Used to verify the current interpreter object module + version match the pickled object version before unpickling the object. + :param python_version: The python version in which the original object was pickled. Used to verify the + current interpreter python version match the pickled object version before + unpickling the object. + :param pickle_module_version: The pickle module version. Used to verify the current interpreter module version + match the one who pickled the object before unpickling it. + :param object_module_version: The original object's module version to match to the interpreter's module version. + + :return: The un-pickled python object. + """ + # Get the pkl file to local directory: + pickle_path = data_item.local() + + # Add the pickle path to the clearing list: + cls.add_future_clearing_path(path=pickle_path) + + # Unpickle and return: + return Pickler.unpickle( + pickle_path=pickle_path, + pickle_module_name=pickle_module_name, + object_module_name=object_module_name, + python_version=python_version, + pickle_module_version=pickle_module_version, + object_module_version=object_module_version, + ) + + @classmethod + def _validate_method_arguments( + cls, method: MethodType, arguments: dict, is_packing: bool + ): + """ + Validate keyword arguments to pass to a method. Used for validating log hint configurations for packing methods + and instructions for unpacking methods. + + :param method: The method to validate the arguments for. + :param arguments: Keyword arguments to validate. + :param is_packing: Flag to know if the arguments came from packing or unpacking, to raise the correct exception + if validation failed. 
+ + :raise MLRunPackagePackingError: If there are missing configurations in the log hint. + :raise MLRunPackageUnpackingError: If there are missing instructions in the artifact's spec. + """ + # Get the possible and mandatory (arguments that has no default value) arguments from the functions: + possible_arguments = inspect.signature(method).parameters + mandatory_arguments = [ + name + for name, parameter in possible_arguments.items() + # If default value is `empty` it is mandatory: + if parameter.default is inspect.Parameter.empty + # Ignore the *args and **kwargs parameters: + and parameter.kind + not in [inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL] + ] + mandatory_arguments.remove("obj" if is_packing else "data_item") + + # Validate there are no missing arguments (only mandatory ones): + missing_arguments = [ + mandatory_argument + for mandatory_argument in mandatory_arguments + if mandatory_argument not in arguments + ] + if missing_arguments: + if is_packing: + raise MLRunPackagePackingError( + f"The packager '{cls.__name__}' could not pack the package due to missing configurations: " + f"{', '.join(missing_arguments)}. Add the missing arguments to the log hint of this object in " + f"order to pack it. Make sure you pass a dictionary log hint and not a string in order to pass " + f"configurations in the log hint." + ) + raise MLRunPackageUnpackingError( + f"The packager '{cls.__name__}' could not unpack the package due to missing instructions: " + f"{', '.join(missing_arguments)}. Missing instructions are likely due to an update in the packager's " + f"code that not support the old implementation. This backward compatibility should not occur. To " + f"overcome it, try to edit the instructions in the artifact's spec to enable unpacking it again." 
+ ) + + # Validate all given arguments are correct: + incorrect_arguments = [ + argument for argument in arguments if argument not in possible_arguments + ] + if incorrect_arguments: + arguments_type = "configurations" if is_packing else "instructions" + logger.warn( + f"Unexpected {arguments_type} given for {cls.__name__}: {', '.join(incorrect_arguments)}. " + f"Possible {arguments_type} are: {', '.join(possible_arguments.keys())}. The packager will try to " + f"continue by ignoring the incorrect arguments." + ) diff --git a/mlrun/package/packagers/numpy_packagers.py b/mlrun/package/packagers/numpy_packagers.py new file mode 100644 index 000000000000..12c2ed3f7a7d --- /dev/null +++ b/mlrun/package/packagers/numpy_packagers.py @@ -0,0 +1,612 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os +import pathlib +import tempfile +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Tuple, Union + +import numpy as np +import pandas as pd + +from mlrun.artifacts import Artifact, DatasetArtifact +from mlrun.datastore import DataItem +from mlrun.errors import MLRunInvalidArgumentError + +from ..utils import ArtifactType, SupportedFormat +from .default_packager import DefaultPackager + +# Type for collection of numpy arrays (list / dict of arrays): +NumPyArrayCollectionType = Union[List[np.ndarray], Dict[str, np.ndarray]] + + +class _Formatter(ABC): + """ + An abstract class for a numpy formatter - supporting saving and loading arrays to and from specific file type. + """ + + @classmethod + @abstractmethod + def save( + cls, + obj: Union[np.ndarray, NumPyArrayCollectionType], + file_path: str, + **save_kwargs: dict, + ): + """ + Save the given array to the file path given. + + :param obj: The numpy array to save. + :param file_path: The file to save to. + :param save_kwargs: Additional keyword arguments to pass to the relevant save function of numpy. + """ + pass + + @classmethod + @abstractmethod + def load( + cls, file_path: str, **load_kwargs: dict + ) -> Union[np.ndarray, NumPyArrayCollectionType]: + """ + Load the array from the given file path. + + :param file_path: The file to load the array from. + :param load_kwargs: Additional keyword arguments to pass to the relevant load function of numpy. + + :return: The loaded array. + """ + pass + + +class _TXTFormatter(_Formatter): + """ + A static class for managing numpy txt files. + """ + + @classmethod + def save(cls, obj: np.ndarray, file_path: str, **save_kwargs: dict): + """ + Save the given array to the file path given. + + :param obj: The numpy array to save. + :param file_path: The file to save to. + :param save_kwargs: Additional keyword arguments to pass to the relevant save function of numpy. + + :raise MLRunInvalidArgumentError: If the array is above 2D. 
+ """ + if len(obj.shape) > 2: + raise MLRunInvalidArgumentError( + f"Cannot save the given array to file. Only 1D and 2D arrays can be saved to text files but the given " + f"array is {len(obj.shape)}D (shape of {obj.shape}). Please use 'npy' format instead." + ) + np.savetxt(file_path, obj, **save_kwargs) + + @classmethod + def load(cls, file_path: str, **load_kwargs: dict) -> np.ndarray: + """ + Load the array from the given 'txt' file path. + + :param file_path: The file to load the array from. + :param load_kwargs: Additional keyword arguments to pass to the relevant load function of numpy. + + :return: The loaded array. + """ + return np.loadtxt(file_path, **load_kwargs) + + +class _CSVFormatter(_TXTFormatter): + """ + A static class for managing numpy csv files. + """ + + @classmethod + def save(cls, obj: np.ndarray, file_path: str, **save_kwargs: dict): + """ + Save the given array to the file path given. + + :param obj: The numpy array to save. + :param file_path: The file to save to. + :param save_kwargs: Additional keyword arguments to pass to the relevant save function of numpy. + + :raise MLRunInvalidArgumentError: If the array is above 2D. + """ + super().save(obj=obj, file_path=file_path, **{"delimiter": ",", **save_kwargs}) + + @classmethod + def load(cls, file_path: str, **load_kwargs: dict) -> np.ndarray: + """ + Load the array from the given 'txt' file path. + + :param file_path: The file to load the array from. + :param load_kwargs: Additional keyword arguments to pass to the relevant load function of numpy. + + :return: The loaded array. + """ + return super().load(file_path=file_path, **{"delimiter": ",", **load_kwargs}) + + +class _NPYFormatter(_Formatter): + """ + A static class for managing numpy npy files. + """ + + @classmethod + def save(cls, obj: np.ndarray, file_path: str, **save_kwargs: dict): + """ + Save the given array to the file path given. + + :param obj: The numpy array to save. + :param file_path: The file to save to. 
+ :param save_kwargs: Additional keyword arguments to pass to the relevant save function of numpy. + """ + np.save(file_path, obj, **save_kwargs) + + @classmethod + def load(cls, file_path: str, **load_kwargs: dict) -> np.ndarray: + """ + Load the array from the given 'npy' file path. + + :param file_path: The file to load the array from. + :param load_kwargs: Additional keyword arguments to pass to the relevant load function of numpy. + + :return: The loaded array. + """ + return np.load(file_path, **load_kwargs) + + +class _NPZFormatter(_Formatter): + """ + A static class for managing numpy npz files. + """ + + @classmethod + def save( + cls, + obj: NumPyArrayCollectionType, + file_path: str, + is_compressed: bool = False, + **save_kwargs: dict, + ): + """ + Save the given array to the file path given. + + :param obj: The numpy array to save. + :param file_path: The file to save to. + :param is_compressed: Whether to save it as a compressed npz file. + :param save_kwargs: Additional keyword arguments to pass to the relevant save function of numpy. + """ + save_function = np.savez_compressed if is_compressed else np.savez + if isinstance(obj, list): + save_function(file_path, *obj) + else: + save_function(file_path, **obj) + + @classmethod + def load(cls, file_path: str, **load_kwargs: dict) -> Dict[str, np.ndarray]: + """ + Load the arrays from the given 'npz' file path. + + :param file_path: The file to load the array from. + :param load_kwargs: Additional keyword arguments to pass to the relevant load function of numpy. + + :return: The loaded arrays as a mapping (dictionary) of type `np.lib.npyio.NpzFile`. + """ + return np.load(file_path, **load_kwargs) + + +class NumPySupportedFormat(SupportedFormat[_Formatter]): + """ + Library of numpy formats (file extensions) supported by the NumPy packagers. 
+ """ + + NPY = "npy" + NPZ = "npz" + TXT = "txt" + GZ = "gz" + CSV = "csv" + + _FORMAT_HANDLERS_MAP = { + NPY: _NPYFormatter, + NPZ: _NPZFormatter, + TXT: _TXTFormatter, + GZ: _TXTFormatter, # 'gz' format handled the same as 'txt'. + CSV: _CSVFormatter, + } + + @classmethod + def get_single_array_formats(cls) -> List[str]: + """ + Get the supported formats for saving one numpy array. + + :return: A list of all the supported formats for saving one numpy array. + """ + return [cls.NPY, cls.TXT, cls.GZ, cls.CSV] + + @classmethod + def get_multi_array_formats(cls) -> List[str]: + """ + Get the supported formats for saving a collection (multiple) numpy arrays - e.g. list of arrays or dictionary of + arrays. + + :return: A list of all the supported formats for saving multiple numpy arrays. + """ + return [cls.NPZ] + + +# Default file formats for numpy arrays file artifacts: +DEFAULT_NUMPY_ARRAY_FORMAT = NumPySupportedFormat.NPY +DEFAULT_NUMPPY_ARRAY_COLLECTION_FORMAT = NumPySupportedFormat.NPZ + + +class NumPyNDArrayPackager(DefaultPackager): + """ + ``numpy.ndarray`` packager. + """ + + PACKABLE_OBJECT_TYPE = np.ndarray + + # The size of an array to be stored as a result, rather than a file in the `get_default_packing_artifact_type` + # method: + _ARRAY_SIZE_AS_RESULT = 10 + + @classmethod + def get_default_packing_artifact_type(cls, obj: np.ndarray) -> str: + """ + Get the default artifact type. Will be a result if the array size is less than 10, otherwise file. + + :param obj: The about to be packed array. + + :return: The default artifact type. + """ + if obj.size < cls._ARRAY_SIZE_AS_RESULT: + return ArtifactType.RESULT + return ArtifactType.FILE + + @classmethod + def get_default_unpacking_artifact_type(cls, data_item: DataItem) -> str: + """ + Get the default artifact type used for unpacking. Returns dataset if the data item represents a + `DatasetArtifact` and otherwise, file. + + :param data_item: The about to be unpacked data item. 
+ + :return: The default artifact type. + """ + is_artifact = data_item.get_artifact_type() + if is_artifact and is_artifact == "datasets": + return ArtifactType.DATASET + return ArtifactType.FILE + + @classmethod + def pack_result(cls, obj: np.ndarray, key: str) -> dict: + """ + Pack an array as a result. + + :param obj: The array to pack and log. + :param key: The result's key. + + :return: The result dictionary. + """ + # If the array is a number (size of 1), then we'll lok it as a single number. Otherwise, log as a list result: + if obj.size == 1: + obj = obj.item() + else: + obj = obj.tolist() + + return super().pack_result(obj=obj, key=key) + + @classmethod + def pack_file( + cls, + obj: np.ndarray, + key: str, + file_format: str = DEFAULT_NUMPY_ARRAY_FORMAT, + **save_kwargs, + ) -> Tuple[Artifact, dict]: + """ + Pack an array as a file by the given format. + + :param obj: The aray to pack. + :param key: The key to use for the artifact. + :param file_format: The file format to save as. Default is npy. + :param save_kwargs: Additional keyword arguments to pass to the numpy save functions. + + :return: The packed artifact and instructions. + """ + # Save to file: + formatter = NumPySupportedFormat.get_format_handler(fmt=file_format) + temp_directory = pathlib.Path(tempfile.mkdtemp()) + cls.add_future_clearing_path(path=temp_directory) + file_path = temp_directory / f"{key}.{file_format}" + formatter.save(obj=obj, file_path=str(file_path), **save_kwargs) + + # Create the artifact and instructions: + artifact = Artifact(key=key, src_path=os.path.abspath(file_path)) + instructions = {"file_format": file_format} + + return artifact, instructions + + @classmethod + def pack_dataset( + cls, + obj: np.ndarray, + key: str, + file_format: str = "", + ) -> Tuple[Artifact, dict]: + """ + Pack an array as a dataset. + + :param obj: The aray to pack. + :param key: The key to use for the artifact. + :param file_format: The file format to save as. Default is parquet. 
+ + :return: The packed artifact and instructions. + + :raise MLRunInvalidArgumentError: IF the shape of the array is not 1D / 2D. + """ + # Validate it's a 2D array: + if len(obj.shape) > 2: + raise MLRunInvalidArgumentError( + f"Cannot log the given numpy array as a dataset. Only 2D arrays can be saved as dataset, but the array " + f"is {len(obj.shape)}D (shape of {obj.shape}). Please specify to log it as a 'file' instead ('npy' " + f"format) or as an 'object' (pickle)." + ) + + # Cast to a `pd.DataFrame`: + data_frame = pd.DataFrame(data=obj) + + # Create the artifact: + artifact = DatasetArtifact(key=key, df=data_frame, format=file_format) + + return artifact, {} + + @classmethod + def unpack_file(cls, data_item: DataItem, file_format: str = None) -> np.ndarray: + """ + Unpack a numppy array from file. + + :param data_item: The data item to unpack. + :param file_format: The file format to use for reading the array. Default is None - will be read by the file + extension. + + :return: The unpacked array. + """ + # Get the file: + file_path = data_item.local() + cls.add_future_clearing_path(path=file_path) + + # Get the archive format by the file extension if needed: + if file_format is None: + file_format = NumPySupportedFormat.match_format(path=file_path) + if ( + file_format is None + or file_format in NumPySupportedFormat.get_multi_array_formats() + ): + raise MLRunInvalidArgumentError( + f"File format of {data_item.key} ('{''.join(pathlib.Path(file_path).suffixes)}') is not supported. " + f"Supported formats are: {' '.join(NumPySupportedFormat.get_single_array_formats())}" + ) + + # Read the object: + formatter = NumPySupportedFormat.get_format_handler(fmt=file_format) + obj = formatter.load(file_path=file_path) + + return obj + + @classmethod + def unpack_dataset(cls, data_item: DataItem) -> np.ndarray: + """ + Unpack a numppy array from a dataset artifact. + + :param data_item: The data item to unpack. + + :return: The unpacked array. 
+ """ + # Get the artifact's data frame: + data_frame = data_item.as_df() + + # Cast the data frame to a `np.ndarray` (1D arrays are returned as a 2D array with shape of 1xn, so we use + # squeeze to decrease the redundant dimension): + array = data_frame.to_numpy().squeeze() + + return array + + +class _NumPyNDArrayCollectionPackager(DefaultPackager): + """ + A base packager for builtin python dictionaries and lists of numpy arrays as they share common artifact and file + types. + """ + + DEFAULT_PACKING_ARTIFACT_TYPE = ArtifactType.FILE + DEFAULT_UNPACKING_ARTIFACT_TYPE = ArtifactType.FILE + PRIORITY = 4 + + @classmethod + def pack_file( + cls, + obj: NumPyArrayCollectionType, + key: str, + file_format: str = DEFAULT_NUMPPY_ARRAY_COLLECTION_FORMAT, + **save_kwargs, + ) -> Tuple[Artifact, dict]: + """ + Pack an array collection as a file by the given format. + + :param obj: The aray collection to pack. + :param key: The key to use for the artifact. + :param file_format: The file format to save as. Default is npy. + :param save_kwargs: Additional keyword arguments to pass to the numpy save functions. + + :return: The packed artifact and instructions. + """ + # Save to file: + formatter = NumPySupportedFormat.get_format_handler(fmt=file_format) + temp_directory = pathlib.Path(tempfile.mkdtemp()) + cls.add_future_clearing_path(path=temp_directory) + file_path = temp_directory / f"{key}.{file_format}" + formatter.save(obj=obj, file_path=str(file_path), **save_kwargs) + + # Create the artifact and instructions: + artifact = Artifact(key=key, src_path=os.path.abspath(file_path)) + + return artifact, {"file_format": file_format} + + @classmethod + def unpack_file( + cls, data_item: DataItem, file_format: str = None + ) -> Dict[str, np.ndarray]: + """ + Unpack a numppy array collection from file. + + :param data_item: The data item to unpack. + :param file_format: The file format to use for reading the array collection. 
Default is None - will be read by + the file extension. + + :return: The unpacked array collection. + """ + # Get the file: + file_path = data_item.local() + cls.add_future_clearing_path(path=file_path) + + # Get the archive format by the file extension if needed: + if file_format is None: + file_format = NumPySupportedFormat.match_format(path=file_path) + if ( + file_format is None + or file_format in NumPySupportedFormat.get_single_array_formats() + ): + raise MLRunInvalidArgumentError( + f"File format of {data_item.key} ('{''.join(pathlib.Path(file_path).suffixes)}') is not supported. " + f"Supported formats are: {' '.join(NumPySupportedFormat.get_multi_array_formats())}" + ) + + # Read the object: + formatter = NumPySupportedFormat.get_format_handler(fmt=file_format) + obj = formatter.load(file_path=file_path) + + return obj + + +class NumPyNDArrayDictPackager(_NumPyNDArrayCollectionPackager): + """ + ``dict[str, numpy.ndarray]`` packager. + """ + + PACKABLE_OBJECT_TYPE = Dict[str, np.ndarray] + + @classmethod + def is_packable(cls, obj: Any, artifact_type: str = None) -> bool: + """ + Check if the object provided is a dictionary of numpy arrays. + + :param obj: The object to pack. + :param artifact_type: The artifact type to log the object as. + + :return: True if packable and False otherwise. + """ + if not ( + isinstance(obj, dict) + and all( + isinstance(key, str) and isinstance(value, np.ndarray) + for key, value in obj.items() + ) + ): + return False + if artifact_type and artifact_type not in cls.get_supported_artifact_types(): + return False + return True + + @classmethod + def pack_result(cls, obj: Dict[str, np.ndarray], key: str) -> dict: + """ + Pack an array dictionary as a result. + + :param obj: The array to pack and log. + :param key: The result's key. + + :return: The result dictionary. 
+ """ + return { + key: { + array_key: array_value.tolist() + for array_key, array_value in obj.items() + } + } + + @classmethod + def unpack_file( + cls, data_item: DataItem, file_format: str = None + ) -> Dict[str, np.ndarray]: + # Load the object: + obj = super().unpack_file(data_item=data_item, file_format=file_format) + + # The returned object is a mapping of type NpzFile, so we cast it to a dictionary: + return {key: array for key, array in obj.items()} + + +class NumPyNDArrayListPackager(_NumPyNDArrayCollectionPackager): + """ + ``list[numpy.ndarray]`` packager. + """ + + PACKABLE_OBJECT_TYPE = List[np.ndarray] + + @classmethod + def is_packable(cls, obj: Any, artifact_type: str = None) -> bool: + """ + Check if the object provided is a list of numpy arrays. + + :param obj: The object to pack. + :param artifact_type: The artifact type to log the object as. + + :return: True if packable and False otherwise. + """ + if not ( + isinstance(obj, list) + and all(isinstance(value, np.ndarray) for value in obj) + ): + return False + if artifact_type and artifact_type not in cls.get_supported_artifact_types(): + return False + return True + + @classmethod + def pack_result(cls, obj: List[np.ndarray], key: str) -> dict: + return {key: [array.tolist() for array in obj]} + + @classmethod + def unpack_file( + cls, data_item: DataItem, file_format: str = None + ) -> List[np.ndarray]: + # Load the object: + obj = super().unpack_file(data_item=data_item, file_format=file_format) + + # The returned object is a mapping of type NpzFile, so we cast it to a list: + return list(obj.values()) + + +class NumPyNumberPackager(DefaultPackager): + """ + ``numpy.number`` packager. It is also used for all `number` inheriting numpy objects (`float32`, uint8, etc). + """ + + PACKABLE_OBJECT_TYPE = np.number + DEFAULT_PACKING_ARTIFACT_TYPE = ArtifactType.RESULT + PACK_SUBCLASSES = True # To include all dtypes ('float32', 'uint8', ...) 
+ + @classmethod + def pack_result(cls, obj: np.number, key: str) -> dict: + return super().pack_result(obj=obj.item(), key=key) diff --git a/mlrun/package/packagers/pandas_packagers.py b/mlrun/package/packagers/pandas_packagers.py new file mode 100644 index 000000000000..6e974de8381b --- /dev/null +++ b/mlrun/package/packagers/pandas_packagers.py @@ -0,0 +1,968 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import importlib +import os +import pathlib +import tempfile +from abc import ABC, abstractmethod +from typing import Any, List, Tuple, Union + +import pandas as pd + +from mlrun.artifacts import Artifact, DatasetArtifact +from mlrun.datastore import DataItem +from mlrun.errors import MLRunInvalidArgumentError + +from ..utils import ArtifactType, SupportedFormat +from .default_packager import DefaultPackager + + +class _Formatter(ABC): + """ + An abstract class for a pandas formatter - supporting saving and loading dataframes to and from specific file type. + """ + + @classmethod + @abstractmethod + def to( + cls, obj: pd.DataFrame, file_path: str, flatten: bool = True, **to_kwargs + ) -> dict: + """ + Save the given dataframe to the file path given. + + :param obj: The dataframe to save. + :param file_path: The file to save to. + :param flatten: Whether to flatten the dataframe before saving. 
For some formats it is mandatory to enable + flattening, otherwise saving and loading the dataframe will cause unexpected behavior + especially in case it is multi-level or multi-index. Default to True. + :param to_kwargs: Additional keyword arguments to pass to the relevant `to_x` function. + + :return A dictionary of keyword arguments for reading the dataframe from file. + """ + pass + + @classmethod + @abstractmethod + def read( + cls, file_path: str, unflatten_kwargs: dict = None, **read_kwargs + ) -> pd.DataFrame: + """ + Read the dataframe from the given file path. + + :param file_path: The file to read the dataframe from. + :param unflatten_kwargs: Unflatten keyword arguments for unflattening the read dataframe. + :param read_kwargs: Additional keyword arguments to pass to the relevant read function of pandas. + + :return: The loaded dataframe. + """ + pass + + @staticmethod + def _flatten_dataframe(dataframe: pd.DataFrame) -> Tuple[pd.DataFrame, dict]: + """ + Flatten the dataframe: moving all indexes to be columns at the start (from column 0) and lowering the columns + levels to 1, renaming them from tuples. All columns and index info is stored so it can be unflatten later on. + + :param dataframe: The dataframe to flatten. + + :return: The flat dataframe. 
+ """ + # Save columns info: + columns = list(dataframe.columns) + if isinstance(dataframe.columns, pd.MultiIndex): + columns = [list(column_tuple) for column_tuple in columns] + columns_levels = list(dataframe.columns.names) + + # Save index info: + index_levels = list(dataframe.index.names) + + # Turn multi-index columns into single columns: + if len(columns_levels) > 1: + # We turn the column tuple into a string to eliminate parsing issues during savings to text formats: + dataframe.columns = pd.Index( + "-".join(column_tuple) for column_tuple in columns + ) + + # Rename indexes in case they appear in the columns so it won't get overriden when the index reset: + dataframe.index.set_names( + names=[ + name + if name is not None and name not in dataframe.columns + else f"INDEX_{name}_{i}" + for i, name in enumerate(dataframe.index.names) + ], + inplace=True, + ) + + # Reset the index, moving the current index to a column: + dataframe.reset_index(inplace=True) + + return dataframe, { + "columns": columns, + "columns_levels": columns_levels, + "index_levels": index_levels, + } + + @staticmethod + def _unflatten_dataframe( + dataframe: pd.DataFrame, + columns: list, + columns_levels: list, + index_levels: list, + ) -> pd.DataFrame: + """ + Unflatten the dataframe, moving the indexes from the columns and resuming the columns levels and names. + + :param dataframe: The dataframe to unflatten. + :param columns: The original list of columns. + :param columns_levels: The original columns levels names. + :param index_levels: The original index levels names. + + :return: The un-flatted dataframe. 
+ """ + # Move back index from columns: + dataframe.set_index( + keys=list(dataframe.columns[: len(index_levels)]), inplace=True + ) + dataframe.index.set_names(names=index_levels, inplace=True) + + # Set the columns back in case they were multi-leveled: + if len(columns_levels) > 1: + dataframe.columns = pd.MultiIndex.from_tuples( + tuples=columns, names=columns_levels + ) + else: + dataframe.columns.set_names(names=columns_levels, inplace=True) + + return dataframe + + +class _ParquetFormatter(_Formatter): + """ + A static class for managing pandas parquet files. + """ + + @classmethod + def to( + cls, obj: pd.DataFrame, file_path: str, flatten: bool = True, **to_kwargs + ) -> dict: + """ + Save the given dataframe to the parquet file path given. + + :param obj: The dataframe to save. + :param file_path: The file to save to. + :param flatten: Ignored for parquet format. + :param to_kwargs: Additional keyword arguments to pass to the `to_parquet` function. + + :return A dictionary of keyword arguments for reading the dataframe from file. + """ + obj.to_parquet(path=file_path, **to_kwargs) + return {} + + @classmethod + def read( + cls, file_path: str, unflatten_kwargs: dict = None, **read_kwargs + ) -> pd.DataFrame: + """ + Read the dataframe from the given parquet file path. + + :param file_path: The file to read the dataframe from. + :param unflatten_kwargs: Ignored for parquet format. + :param read_kwargs: Additional keyword arguments to pass to the `read_parquet` function. + + :return: The loaded dataframe. + """ + return pd.read_parquet(path=file_path, **read_kwargs) + + +class _CSVFormatter(_Formatter): + """ + A static class for managing pandas csv files. + """ + + @classmethod + def to( + cls, obj: pd.DataFrame, file_path: str, flatten: bool = True, **to_kwargs + ) -> dict: + """ + Save the given dataframe to the csv file path given. + + :param obj: The dataframe to save. + :param file_path: The file to save to. 
+ :param flatten: Whether to flatten the dataframe before saving. For some formats it is mandatory to enable + flattening, otherwise saving and loading the dataframe will cause unexpected behavior + especially in case it is multi-level or multi-index. Default to True. + :param to_kwargs: Additional keyword arguments to pass to the `to_csv` function. + + :return A dictionary of keyword arguments for reading the dataframe from file. + """ + # Flatten the dataframe (this format have problems saving multi-level dataframes): + instructions = {} + if flatten: + obj, unflatten_kwargs = cls._flatten_dataframe(dataframe=obj) + instructions["unflatten_kwargs"] = unflatten_kwargs + + # Write to csv: + obj.to_csv(path_or_buf=file_path, **to_kwargs) + + return instructions + + @classmethod + def read( + cls, file_path: str, unflatten_kwargs: dict = None, **read_kwargs + ) -> pd.DataFrame: + """ + Read the dataframe from the given csv file path. + + :param file_path: The file to read the dataframe from. + :param unflatten_kwargs: Unflatten keyword arguments for unflattening the read dataframe. + :param read_kwargs: Additional keyword arguments to pass to the `read_csv` function. + + :return: The loaded dataframe. + """ + # Read the csv: + obj = pd.read_csv(filepath_or_buffer=file_path, **read_kwargs) + + # Check if it was flattened in packing: + if unflatten_kwargs is not None: + # Remove the default index (joined with reset index): + if obj.columns[0] == "Unnamed: 0": + obj.drop(columns=["Unnamed: 0"], inplace=True) + # Unflatten the dataframe: + obj = cls._unflatten_dataframe(dataframe=obj, **unflatten_kwargs) + + return obj + + +class _H5Formatter(_Formatter): + """ + A static class for managing pandas h5 files. + """ + + @classmethod + def to( + cls, obj: pd.DataFrame, file_path: str, flatten: bool = True, **to_kwargs + ) -> dict: + """ + Save the given dataframe to the h5 file path given. + + :param obj: The dataframe to save. + :param file_path: The file to save to. 
+ :param flatten: Ignored for h5 format. + :param to_kwargs: Additional keyword arguments to pass to the `to_hdf` function. + + :return A dictionary of keyword arguments for reading the dataframe from file. + """ + # If user didn't provide a key for the dataframe, use default key 'table': + key = to_kwargs.pop("key", "table") + + # Write to h5: + obj.to_hdf(path_or_buf=file_path, key=key, **to_kwargs) + + return {"key": key} + + @classmethod + def read( + cls, file_path: str, unflatten_kwargs: dict = None, **read_kwargs + ) -> pd.DataFrame: + """ + Read the dataframe from the given h5 file path. + + :param file_path: The file to read the dataframe from. + :param unflatten_kwargs: Ignored for h5 format. + :param read_kwargs: Additional keyword arguments to pass to the `read_hdf` function. + + :return: The loaded dataframe. + """ + return pd.read_hdf(path_or_buf=file_path, **read_kwargs) + + +class _XMLFormatter(_Formatter): + """ + A static class for managing pandas xml files. + """ + + @classmethod + def to( + cls, obj: pd.DataFrame, file_path: str, flatten: bool = True, **to_kwargs + ) -> dict: + """ + Save the given dataframe to the xml file path given. + + :param obj: The dataframe to save. + :param file_path: The file to save to. + :param flatten: Whether to flatten the dataframe before saving. For some formats it is mandatory to enable + flattening, otherwise saving and loading the dataframe will cause unexpected behavior + especially in case it is multi-level or multi-index. Default to True. + :param to_kwargs: Additional keyword arguments to pass to the `to_xml` function. + + :return A dictionary of keyword arguments for reading the dataframe from file. 
+ """ + # Get the parser (if not provided, try to use `lxml`, otherwise `etree`): + parser = to_kwargs.pop("parser", None) + if parser is None: + try: + importlib.import_module("lxml") + parser = "lxml" + except ModuleNotFoundError: + parser = "etree" + instructions = {"parser": parser} + + # Flatten the dataframe (this format have problems saving multi-level dataframes): + if flatten: + obj, unflatten_kwargs = cls._flatten_dataframe(dataframe=obj) + instructions["unflatten_kwargs"] = unflatten_kwargs + + # Write to xml: + obj.to_xml(path_or_buffer=file_path, parser="etree", **to_kwargs) + + return instructions + + @classmethod + def read( + cls, file_path: str, unflatten_kwargs: dict = None, **read_kwargs + ) -> pd.DataFrame: + """ + Read the dataframe from the given xml file path. + + :param file_path: The file to read the dataframe from. + :param unflatten_kwargs: Unflatten keyword arguments for unflattening the read dataframe. + :param read_kwargs: Additional keyword arguments to pass to the `read_xml` function. + + :return: The loaded dataframe. + """ + # Read the xml: + obj = pd.read_xml(path_or_buffer=file_path, **read_kwargs) + + # Check if it was flattened in packing: + if unflatten_kwargs is not None: + # Remove the default index (joined with reset index): + if obj.columns[0] == "index": + obj.drop(columns=["index"], inplace=True) + # Unflatten the dataframe: + obj = cls._unflatten_dataframe(dataframe=obj, **unflatten_kwargs) + + return obj + + +class _XLSXFormatter(_Formatter): + """ + A static class for managing pandas xlsx files. + """ + + @classmethod + def to( + cls, obj: pd.DataFrame, file_path: str, flatten: bool = True, **to_kwargs + ) -> dict: + """ + Save the given dataframe to the xlsx file path given. + + :param obj: The dataframe to save. + :param file_path: The file to save to. + :param flatten: Whether to flatten the dataframe before saving. 
For some formats it is mandatory to enable + flattening, otherwise saving and loading the dataframe will cause unexpected behavior + especially in case it is multi-level or multi-index. Default to True. + :param to_kwargs: Additional keyword arguments to pass to the `to_excel` function. + """ + # Get the engine to pass when unpacked: + instructions = {"engine": to_kwargs.get("engine", None)} + + # Flatten the dataframe (this format have problems saving multi-level dataframes): + if flatten: + obj, unflatten_kwargs = cls._flatten_dataframe(dataframe=obj) + instructions["unflatten_kwargs"] = unflatten_kwargs + + # Write to xlsx: + obj.to_excel(excel_writer=file_path, **to_kwargs) + + return instructions + + @classmethod + def read( + cls, file_path: str, unflatten_kwargs: dict = None, **read_kwargs + ) -> pd.DataFrame: + """ + Read the dataframe from the given xlsx file path. + + :param file_path: The file to read the dataframe from. + :param unflatten_kwargs: Unflatten keyword arguments for unflattening the read dataframe. + :param read_kwargs: Additional keyword arguments to pass to the `read_excel` function. + + :return: The loaded dataframe. + """ + # Read the xlsx: + obj = pd.read_excel(io=file_path, **read_kwargs) + + # Check if it was flattened in packing: + if unflatten_kwargs is not None: + # Remove the default index (joined with reset index): + if obj.columns[0] == "Unnamed: 0": + obj.drop(columns=["Unnamed: 0"], inplace=True) + # Unflatten the dataframe: + obj = cls._unflatten_dataframe(dataframe=obj, **unflatten_kwargs) + + return obj + + +class _HTMLFormatter(_Formatter): + """ + A static class for managing pandas html files. + """ + + @classmethod + def to( + cls, obj: pd.DataFrame, file_path: str, flatten: bool = True, **to_kwargs + ) -> dict: + """ + Save the given dataframe to the html file path given. + + :param obj: The dataframe to save. + :param file_path: The file to save to. + :param flatten: Whether to flatten the dataframe before saving. 
For some formats it is mandatory to enable + flattening, otherwise saving and loading the dataframe will cause unexpected behavior + especially in case it is multi-level or multi-index. Default to True. + :param to_kwargs: Additional keyword arguments to pass to the `to_html` function. + + :return A dictionary of keyword arguments for reading the dataframe from file. + """ + # Flatten the dataframe (this format have problems saving multi-level dataframes): + instructions = {} + if flatten: + obj, unflatten_kwargs = cls._flatten_dataframe(dataframe=obj) + instructions["unflatten_kwargs"] = unflatten_kwargs + + # Write to html: + obj.to_html(buf=file_path, **to_kwargs) + return instructions + + @classmethod + def read( + cls, file_path: str, unflatten_kwargs: dict = None, **read_kwargs + ) -> pd.DataFrame: + """ + Read dataframes from the given html file path. + + :param file_path: The file to read the dataframe from. + :param unflatten_kwargs: Unflatten keyword arguments for unflattening the read dataframe. + :param read_kwargs: Additional keyword arguments to pass to the `read_html` function. + + :return: The loaded dataframe. + """ + # Read the html: + obj = pd.read_html(io=file_path, **read_kwargs)[0] + + # Check if it was flattened in packing: + if unflatten_kwargs is not None: + # Remove the default index (joined with reset index): + if obj.columns[0] == "Unnamed: 0": + obj.drop(columns=["Unnamed: 0"], inplace=True) + # Unflatten the dataframe: + obj = cls._unflatten_dataframe(dataframe=obj, **unflatten_kwargs) + + return obj + + +class _JSONFormatter(_Formatter): + """ + A static class for managing pandas json files. + """ + + @classmethod + def to( + cls, obj: pd.DataFrame, file_path: str, flatten: bool = True, **to_kwargs + ) -> dict: + """ + Save the given dataframe to the json file path given. + + :param obj: The dataframe to save. + :param file_path: The file to save to. + :param flatten: Whether to flatten the dataframe before saving. 
For some formats it is mandatory to enable + flattening, otherwise saving and loading the dataframe will cause unexpected behavior + especially in case it is multi-level or multi-index. Default to True. + :param to_kwargs: Additional keyword arguments to pass to the `to_json` function. + + :return A dictionary of keyword arguments for reading the dataframe from file. + """ + # Get the orient to pass when unpacked: + instructions = {"orient": to_kwargs.get("orient", None)} + + # Flatten the dataframe (this format have problems saving multi-level dataframes): + if flatten: + obj, unflatten_kwargs = cls._flatten_dataframe(dataframe=obj) + instructions["unflatten_kwargs"] = unflatten_kwargs + + # Write to json: + obj.to_json(path_or_buf=file_path, **to_kwargs) + + return instructions + + @classmethod + def read( + cls, file_path: str, unflatten_kwargs: dict = None, **read_kwargs + ) -> pd.DataFrame: + """ + Read dataframes from the given json file path. + + :param file_path: The file to read the dataframe from. + :param unflatten_kwargs: Unflatten keyword arguments for unflattening the read dataframe. + :param read_kwargs: Additional keyword arguments to pass to the `read_json` function. + + :return: The loaded dataframe. + """ + # Read the json: + obj = pd.read_json(path_or_buf=file_path, **read_kwargs) + + # Check if it was flattened in packing: + if unflatten_kwargs is not None: + obj = cls._unflatten_dataframe(dataframe=obj, **unflatten_kwargs) + + return obj + + +class _FeatherFormatter(_Formatter): + """ + A static class for managing pandas feather files. + """ + + @classmethod + def to( + cls, obj: pd.DataFrame, file_path: str, flatten: bool = True, **to_kwargs + ) -> dict: + """ + Save the given dataframe to the feather file path given. + + :param obj: The dataframe to save. + :param file_path: The file to save to. + :param flatten: Whether to flatten the dataframe before saving. 
For some formats it is mandatory to enable + flattening, otherwise saving and loading the dataframe will cause unexpected behavior + especially in case it is multi-level or multi-index. Default to True. + :param to_kwargs: Additional keyword arguments to pass to the `to_feather` function. + + :return A dictionary of keyword arguments for reading the dataframe from file. + """ + # Flatten the dataframe (this format have problems saving multi-level dataframes): + instructions = {} + if flatten: + obj, unflatten_kwargs = cls._flatten_dataframe(dataframe=obj) + instructions["unflatten_kwargs"] = unflatten_kwargs + + # Write to feather: + obj.to_feather(path=file_path, **to_kwargs) + + return instructions + + @classmethod + def read( + cls, file_path: str, unflatten_kwargs: dict = None, **read_kwargs + ) -> pd.DataFrame: + """ + Read dataframes from the given feather file path. + + :param file_path: The file to read the dataframe from. + :param unflatten_kwargs: Unflatten keyword arguments for unflattening the read dataframe. + :param read_kwargs: Additional keyword arguments to pass to the `read_feather` function. + + :return: The loaded dataframe. + """ + # Read the feather: + obj = pd.read_feather(path=file_path, **read_kwargs) + + # Check if it was flattened in packing: + if unflatten_kwargs is not None: + obj = cls._unflatten_dataframe(dataframe=obj, **unflatten_kwargs) + + return obj + + +class _ORCFormatter(_Formatter): + """ + A static class for managing pandas orc files. + """ + + @classmethod + def to( + cls, obj: pd.DataFrame, file_path: str, flatten: bool = True, **to_kwargs + ) -> dict: + """ + Save the given dataframe to the orc file path given. + + :param obj: The dataframe to save. + :param file_path: The file to save to. + :param flatten: Whether to flatten the dataframe before saving. 
                          For some formats it is mandatory to enable
                          flattening, otherwise saving and loading the dataframe will cause unexpected behavior
                          especially in case it is multi-level or multi-index. Default to True.
        :param to_kwargs: Additional keyword arguments to pass to the `to_orc` function.

        :return: A dictionary of keyword arguments for reading the dataframe from file.
        """
        # Flatten the dataframe (this format has problems saving multi-level dataframes):
        instructions = {}
        if flatten:
            obj, unflatten_kwargs = cls._flatten_dataframe(dataframe=obj)
            instructions["unflatten_kwargs"] = unflatten_kwargs

        # Write to orc (comment previously said "feather" — copy-paste leftover):
        obj.to_orc(path=file_path, **to_kwargs)

        return instructions

    @classmethod
    def read(
        cls, file_path: str, unflatten_kwargs: dict = None, **read_kwargs
    ) -> pd.DataFrame:
        """
        Read dataframes from the given orc file path.

        :param file_path:        The file to read the dataframe from.
        :param unflatten_kwargs: Unflatten keyword arguments for unflattening the read dataframe.
        :param read_kwargs:      Additional keyword arguments to pass to the `read_orc` function.

        :return: The loaded dataframe.
        """
        # Read the orc (comment previously said "feather" — copy-paste leftover):
        obj = pd.read_orc(path=file_path, **read_kwargs)

        # Check if it was flattened in packing:
        if unflatten_kwargs is not None:
            obj = cls._unflatten_dataframe(dataframe=obj, **unflatten_kwargs)

        return obj


class PandasSupportedFormat(SupportedFormat[_Formatter]):
    """
    Library of Pandas formats (file extensions) supported by the Pandas packagers.
+ """ + + PARQUET = "parquet" + CSV = "csv" + H5 = "h5" + XML = "xml" + XLSX = "xlsx" + HTML = "html" + JSON = "json" + FEATHER = "feather" + ORC = "orc" + + _FORMAT_HANDLERS_MAP = { + PARQUET: _ParquetFormatter, + CSV: _CSVFormatter, + H5: _H5Formatter, + XML: _XMLFormatter, + XLSX: _XLSXFormatter, + HTML: _HTMLFormatter, + JSON: _JSONFormatter, + FEATHER: _FeatherFormatter, + ORC: _ORCFormatter, + } + + +# Default file formats for pandas DataFrame and Series file artifacts: +DEFAULT_PANDAS_FORMAT = PandasSupportedFormat.PARQUET +NON_STRING_COLUMN_NAMES_DEFAULT_PANDAS_FORMAT = PandasSupportedFormat.CSV + + +class PandasDataFramePackager(DefaultPackager): + """ + ``pd.DataFrame`` packager. + """ + + PACKABLE_OBJECT_TYPE = pd.DataFrame + DEFAULT_PACKING_ARTIFACT_TYPE = ArtifactType.DATASET + + @classmethod + def get_default_unpacking_artifact_type(cls, data_item: DataItem) -> str: + """ + Get the default artifact type used for unpacking. Returns dataset if the data item represents a + `DatasetArtifact` and otherwise, file. + + :param data_item: The about to be unpacked data item. + + :return: The default artifact type. + """ + is_artifact = data_item.get_artifact_type() + if is_artifact and is_artifact == "datasets": + return ArtifactType.DATASET + return ArtifactType.FILE + + @classmethod + def pack_result(cls, obj: pd.DataFrame, key: str) -> dict: + """ + Pack a dataframe as a result. + + :param obj: The dataframe to pack and log. + :param key: The result's key. + + :return: The result dictionary. 
+ """ + # Parse to dictionary according to the indexes in the dataframe: + if len(obj.index.names) > 1: + # Multiple indexes: + orient = "split" + elif obj.index.name is not None: + # Not a default index (user would likely want to keep it): + orient = "dict" + else: + # Default index can be ignored: + orient = "list" + + # Cast to dictionary: + dataframe_dictionary = obj.to_dict(orient=orient) + + # Prepare the result (casting tuples to lists): + dataframe_dictionary = PandasDataFramePackager._prepare_result( + obj=dataframe_dictionary + ) + + return super().pack_result(obj=dataframe_dictionary, key=key) + + @classmethod + def pack_file( + cls, + obj: pd.DataFrame, + key: str, + file_format: str = None, + flatten: bool = True, + **to_kwargs, + ) -> Tuple[Artifact, dict]: + """ + Pack a dataframe as a file by the given format. + + :param obj: The series to pack. + :param key: The key to use for the artifact. + :param file_format: The file format to save as. Default is parquet or csv (depends on the column names as + parquet cannot be used for non string column names). + :param flatten: Whether to flatten the dataframe before saving. For some formats it is mandatory to enable + flattening, otherwise saving and loading the dataframe will cause unexpected behavior + especially in case it is multi-level or multi-index. Default to True. + :param to_kwargs: Additional keyword arguments to pass to the pandas `to_x` functions. + + :return: The packed artifact and instructions. 
+ """ + # Set default file format if not given: + if file_format is None: + file_format = ( + DEFAULT_PANDAS_FORMAT + if all(isinstance(name, str) for name in obj.columns) + else NON_STRING_COLUMN_NAMES_DEFAULT_PANDAS_FORMAT + ) + + # Save to file: + formatter = PandasSupportedFormat.get_format_handler(fmt=file_format) + temp_directory = pathlib.Path(tempfile.mkdtemp()) + cls.add_future_clearing_path(path=temp_directory) + file_path = temp_directory / f"{key}.{file_format}" + read_kwargs = formatter.to( + obj=obj, file_path=str(file_path), flatten=flatten, **to_kwargs + ) + + # Create the artifact and instructions: + artifact = Artifact(key=key, src_path=os.path.abspath(file_path)) + + return artifact, {"file_format": file_format, "read_kwargs": read_kwargs} + + @classmethod + def pack_dataset(cls, obj: pd.DataFrame, key: str, file_format: str = "parquet"): + """ + Pack a pandas dataframe as a dataset. + + :param obj: The dataframe to pack. + :param key: The key to use for the artifact. + :param file_format: The file format to save as. Default is parquet. + + :return: The packed artifact and instructions. + """ + return DatasetArtifact(key=key, df=obj, format=file_format), {} + + @classmethod + def unpack_file( + cls, + data_item: DataItem, + file_format: str = None, + read_kwargs: dict = None, + ) -> pd.DataFrame: + """ + Unpack a pandas dataframe from file. + + :param data_item: The data item to unpack. + :param file_format: The file format to use for reading the series. Default is None - will be read by the file + extension. + :param read_kwargs: Keyword arguments to pass to the read of the formatter. + + :return: The unpacked series. 
+ """ + # Get the file: + file_path = data_item.local() + cls.add_future_clearing_path(path=file_path) + + # Get the archive format by the file extension if needed: + if file_format is None: + file_format = PandasSupportedFormat.match_format(path=file_path) + if file_format is None: + raise MLRunInvalidArgumentError( + f"File format of {data_item.key} ('{''.join(pathlib.Path(file_path).suffixes)}') is not supported. " + f"Supported formats are: {' '.join(PandasSupportedFormat.get_all_formats())}" + ) + + # Read the object: + formatter = PandasSupportedFormat.get_format_handler(fmt=file_format) + if read_kwargs is None: + read_kwargs = {} + return formatter.read(file_path=file_path, **read_kwargs) + + @classmethod + def unpack_dataset(cls, data_item: DataItem): + """ + Unpack a padnas dataframe from a dataset artifact. + + :param data_item: The data item to unpack. + + :return: The unpacked dataframe. + """ + return data_item.as_df() + + @staticmethod + def _prepare_result(obj: Union[list, dict, tuple]) -> Any: + """ + A dataframe can be logged as a result when it being cast to a dictionary. If the dataframe has multiple indexes, + pandas store them as a tuple, which is not json serializable, so we cast them into lists. + + :param obj: The dataframe dictionary (or list and tuple as it is recursive). + + :return: Prepared result. + """ + if isinstance(obj, dict): + for key, value in obj.items(): + obj[ + PandasDataFramePackager._prepare_result(obj=key) + ] = PandasDataFramePackager._prepare_result(obj=value) + elif isinstance(obj, list): + for i, value in enumerate(obj): + obj[i] = PandasDataFramePackager._prepare_result(obj=value) + elif isinstance(obj, tuple): + obj = [PandasDataFramePackager._prepare_result(obj=value) for value in obj] + return obj + + +class PandasSeriesPackager(PandasDataFramePackager): + """ + ``pd.Series`` packager. 
+ """ + + PACKABLE_OBJECT_TYPE = pd.Series + DEFAULT_PACKING_ARTIFACT_TYPE = ArtifactType.FILE + + @classmethod + def get_supported_artifact_types(cls) -> List[str]: + """ + Get all the supported artifact types on this packager. It will be the same as `PandasDataFramePackager` but + without the 'dataset' artifact type support. + + :return: A list of all the supported artifact types. + """ + supported_artifacts = super().get_supported_artifact_types() + supported_artifacts.remove("dataset") + return supported_artifacts + + @classmethod + def pack_result(cls, obj: pd.Series, key: str) -> dict: + """ + Pack a series as a result. + + :param obj: The series to pack and log. + :param key: The result's key. + + :return: The result dictionary. + """ + return super().pack_result(obj=pd.DataFrame(obj), key=key) + + @classmethod + def pack_file( + cls, + obj: pd.Series, + key: str, + file_format: str = None, + flatten: bool = True, + **to_kwargs, + ) -> Tuple[Artifact, dict]: + """ + Pack a series as a file by the given format. + + :param obj: The series to pack. + :param key: The key to use for the artifact. + :param file_format: The file format to save as. Default is parquet or csv (depends on the column names as + parquet cannot be used for non string column names). + :param flatten: Whether to flatten the dataframe before saving. For some formats it is mandatory to enable + flattening, otherwise saving and loading the dataframe will cause unexpected behavior + especially in case it is multi-level or multi-index. Default to True. + :param to_kwargs: Additional keyword arguments to pass to the pandas `to_x` functions. + + :return: The packed artifact and instructions. 
+ """ + # Get the series column name: + column_name = obj.name + + # Cast to dataframe and call the parent `pack_file`: + artifact, instructions = super().pack_file( + obj=pd.DataFrame(obj), + key=key, + file_format=file_format, + flatten=flatten, + **to_kwargs, + ) + + # Return the artifact with the updated instructions: + return artifact, {**instructions, "column_name": column_name} + + @classmethod + def unpack_file( + cls, + data_item: DataItem, + file_format: str = None, + read_kwargs: dict = None, + column_name: Union[str, int] = None, + ) -> pd.Series: + """ + Unpack a pandas series from file. + + :param data_item: The data item to unpack. + :param file_format: The file format to use for reading the series. Default is None - will be read by the file + extension. + :param read_kwargs: Keyword arguments to pass to the read of the formatter. + :param column_name: The name of the series column. + + :return: The unpacked series. + """ + # Read the object: + obj = super().unpack_file( + data_item=data_item, + file_format=file_format, + read_kwargs=read_kwargs, + ) + + # Cast the dataframe into a series: + if len(obj.columns) != 1: + raise MLRunInvalidArgumentError( + f"The data item received is of a `pandas.DataFrame` with more than one column: " + f"{', '.join(obj.columns)}. Hence it cannot be turned into a `pandas.Series`." 
+ ) + obj = obj[obj.columns[0]] + + # Edit the column name (if `read_kwargs` is not None we can be sure it is a packed file artifact, so the column + # name, even if None, should be set to restore the object as it was): + if read_kwargs is not None: + obj.name = column_name + + return obj diff --git a/mlrun/package/packagers/python_standard_library_packagers.py b/mlrun/package/packagers/python_standard_library_packagers.py new file mode 100644 index 000000000000..fae400ad8311 --- /dev/null +++ b/mlrun/package/packagers/python_standard_library_packagers.py @@ -0,0 +1,616 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import pathlib +import tempfile +from typing import Tuple, Union + +from mlrun.artifacts import Artifact +from mlrun.datastore import DataItem +from mlrun.errors import MLRunInvalidArgumentError + +from ..utils import ( + DEFAULT_ARCHIVE_FORMAT, + DEFAULT_STRUCT_FILE_FORMAT, + ArchiveSupportedFormat, + ArtifactType, + StructFileSupportedFormat, +) +from .default_packager import DefaultPackager + +# ---------------------------------------------------------------------------------------------------------------------- +# builtins packagers: +# ---------------------------------------------------------------------------------------------------------------------- + + +class IntPackager(DefaultPackager): + """ + ``builtins.int`` packager. 
+ """ + + PACKABLE_OBJECT_TYPE = int + DEFAULT_PACKING_ARTIFACT_TYPE = ArtifactType.RESULT + + +class FloatPackager(DefaultPackager): + """ + ``builtins.float`` packager. + """ + + PACKABLE_OBJECT_TYPE = float + DEFAULT_PACKING_ARTIFACT_TYPE = ArtifactType.RESULT + + +class BoolPackager(DefaultPackager): + """ + ``builtins.bool`` packager. + """ + + PACKABLE_OBJECT_TYPE = bool + DEFAULT_PACKING_ARTIFACT_TYPE = ArtifactType.RESULT + + +class StrPackager(DefaultPackager): + """ + ``builtins.str`` packager. + """ + + PACKABLE_OBJECT_TYPE = str + DEFAULT_PACKING_ARTIFACT_TYPE = ArtifactType.RESULT + DEFAULT_UNPACKING_ARTIFACT_TYPE = ArtifactType.PATH + + @classmethod + def pack_path( + cls, obj: str, key: str, archive_format: str = DEFAULT_ARCHIVE_FORMAT + ) -> Tuple[Artifact, dict]: + """ + Pack a path string value content (pack the file or directory in that path). + + :param obj: The string path value to pack. + :param key: The key to use for the artifact. + :param archive_format: The archive format to use in case the path is of a directory. Default is zip. + + :return: The packed artifact and instructions. + """ + # TODO: Add a configuration like `archive_file: bool = False` to enable archiving a single file to shrink it in + # size. In that case the `is_directory` instruction will make it so when an archive is received, if its + # a directory, when exporting it a directory path should be returned. And, if its a file, a path to the + # single file exported should be returned. 
+ # Verify the path is of an existing file: + if not os.path.exists(obj): + raise MLRunInvalidArgumentError(f"The given path do not exist: '{obj}'") + + # Proceed by path type (file or directory): + if os.path.isfile(obj): + # Create the artifact: + artifact = Artifact(key=key, src_path=os.path.abspath(obj)) + instructions = {"is_directory": False} + elif os.path.isdir(obj): + # Archive the directory: + output_path = tempfile.mkdtemp() + archiver = ArchiveSupportedFormat.get_format_handler(fmt=archive_format) + archive_path = archiver.create_archive( + directory_path=obj, output_path=output_path + ) + # Create the artifact: + artifact = Artifact(key=key, src_path=archive_path) + instructions = {"archive_format": archive_format, "is_directory": True} + else: + raise MLRunInvalidArgumentError( + f"The given path is not a file nor a directory: '{obj}'" + ) + + return artifact, instructions + + @classmethod + def unpack_path( + cls, data_item: DataItem, is_directory: bool = False, archive_format: str = None + ) -> str: + """ + Unpack a data item representing a path string. If the path is of a file, the file is downloaded to a local + temporary directory and its path is returned. If the path is of a directory, the archive is extracted and the + directory path extracted is returned. + + :param data_item: The data item to unpack. + :param is_directory: Whether the path should be treated as a file or a directory. Files (even archives like + zip) won't be extracted. + :param archive_format: The archive format to use in case the path is of a directory. Default is None - will be + read by the archive file extension. + + :return: The unpacked string. + """ + # Get the file to a local temporary directory: + path = data_item.local() + + # Mark the downloaded file for future clear: + cls.add_future_clearing_path(path=path) + + # If it's not a directory, return the file path. 
Otherwise, it should be extracted according to the archive + # format: + if not is_directory: + return path + + # Get the archive format by the file extension: + if archive_format is None: + archive_format = ArchiveSupportedFormat.match_format(path=path) + if archive_format is None: + raise MLRunInvalidArgumentError( + f"Archive format of {data_item.key} ('{''.join(pathlib.Path(path).suffixes)}') is not supported. " + f"Supported formats are: {' '.join(ArchiveSupportedFormat.get_all_formats())}" + ) + + # Extract the archive: + archiver = ArchiveSupportedFormat.get_format_handler(fmt=archive_format) + directory_path = archiver.extract_archive( + archive_path=path, output_path=os.path.dirname(path) + ) + + # Mark the extracted content for future clear: + cls.add_future_clearing_path(path=directory_path) + + # Return the extracted directory path: + return directory_path + + +class _BuiltinCollectionPackager(DefaultPackager): + """ + A base packager for builtin python dictionaries and lists as they share common artifact and file types. + """ + + DEFAULT_PACKING_ARTIFACT_TYPE = ArtifactType.RESULT + DEFAULT_UNPACKING_ARTIFACT_TYPE = ArtifactType.FILE + + @classmethod + def pack_file( + cls, + obj: Union[dict, list], + key: str, + file_format: str = DEFAULT_STRUCT_FILE_FORMAT, + ) -> Tuple[Artifact, dict]: + """ + Pack a builtin collection as a file by the given format. + + :param obj: The builtin collection to pack. + :param key: The key to use for the artifact. + :param file_format: The file format to save as. Default is json. + + :return: The packed artifact and instructions. 
+ """ + # Write to file: + formatter = StructFileSupportedFormat.get_format_handler(fmt=file_format) + temp_directory = pathlib.Path(tempfile.mkdtemp()) + cls.add_future_clearing_path(path=temp_directory) + file_path = temp_directory / f"{key}.{file_format}" + formatter.write(obj=obj, file_path=str(file_path)) + + # Create the artifact and instructions: + artifact = Artifact(key=key, src_path=os.path.abspath(file_path)) + instructions = {"file_format": file_format} + + return artifact, instructions + + @classmethod + def unpack_file( + cls, data_item: DataItem, file_format: str = None + ) -> Union[dict, list]: + """ + Unpack a builtin collection from file. + + :param data_item: The data item to unpack. + :param file_format: The file format to use for reading the builtin collection. Default is None - will be read by + the file extension. + + :return: The unpacked builtin collection. + """ + # Get the file: + file_path = data_item.local() + cls.add_future_clearing_path(path=file_path) + + # Get the archive format by the file extension if needed: + if file_format is None: + file_format = StructFileSupportedFormat.match_format(path=file_path) + if file_format is None: + raise MLRunInvalidArgumentError( + f"File format of {data_item.key} ('{''.join(pathlib.Path(file_path).suffixes)}') is not supported. " + f"Supported formats are: {' '.join(StructFileSupportedFormat.get_all_formats())}" + ) + + # Read the object: + formatter = StructFileSupportedFormat.get_format_handler(fmt=file_format) + obj = formatter.read(file_path=file_path) + + return obj + + +class DictPackager(_BuiltinCollectionPackager): + """ + ``builtins.dict`` packager. + """ + + PACKABLE_OBJECT_TYPE = dict + + @classmethod + def unpack_file(cls, data_item: DataItem, file_format: str = None) -> dict: + """ + Unpack a dictionary from file. + + :param data_item: The data item to unpack. + :param file_format: The file format to use for reading the dictionary. 
Default is None - will be read by the + file extension. + + :return: The unpacked dictionary. + """ + # Unpack the object: + obj = super().unpack_file(data_item=data_item, file_format=file_format) + + # Check if needed to cast from list: + if isinstance(obj, list): + return {index: element for index, element in enumerate(obj)} + return obj + + +class ListPackager(_BuiltinCollectionPackager): + """ + ``builtins.list`` packager. + """ + + PACKABLE_OBJECT_TYPE = list + + @classmethod + def unpack_file(cls, data_item: DataItem, file_format: str = None) -> list: + """ + Unpack a list from file. + + :param data_item: The data item to unpack. + :param file_format: The file format to use for reading the list. Default is None - will be read by the file + extension. + + :return: The unpacked list. + """ + # Unpack the object: + obj = super().unpack_file(data_item=data_item, file_format=file_format) + + # Check if needed to cast from dict: + if isinstance(obj, dict): + return list(obj.values()) + return obj + + +class TuplePackager(ListPackager): + """ + ``builtins.tuple`` packager. + + Notice: a ``tuple`` returned from a function is usually treated as multiple returned objects, and so MLRun will try + to pack each of them separately and not as a single tuple. For example:: + + def example_func_1(): + return 10, [1, 2, 3], "Hello MLRun" + + Will be returned as a ``tuple`` of 3 items: `(10, [1, 2, 3], "Hello MLRun")` but the items will be packaged + separately one by one and not as a single ``tuple``. + + In order to pack tuples (not recommended), use the configuration:: + + mlrun.mlconf.packagers.pack_tuple = True + + Or more correctly, cast your returned tuple to a ``list`` like so:: + + def example_func_2(): + my_tuple = (2, 4) + return list(my_tuple) + """ + + PACKABLE_OBJECT_TYPE = tuple + + @classmethod + def pack_result(cls, obj: tuple, key: str) -> dict: + """ + Pack a tuple as a result. + + :param obj: The tuple to pack and log. + :param key: The result's key. 
+ + :return: The result dictionary. + """ + return super().pack_result(obj=list(obj), key=key) + + @classmethod + def pack_file( + cls, obj: tuple, key: str, file_format: str = DEFAULT_STRUCT_FILE_FORMAT + ) -> Tuple[Artifact, dict]: + """ + Pack a tuple as a file by the given format. + + :param obj: The tuple to pack. + :param key: The key to use for the artifact. + :param file_format: The file format to save as. Default is json. + + :return: The packed artifact and instructions. + """ + return super().pack_file(obj=list(obj), key=key, file_format=file_format) + + @classmethod + def unpack_file(cls, data_item: DataItem, file_format: str = None) -> tuple: + """ + Unpack a tuple from file. + + :param data_item: The data item to unpack. + :param file_format: The file format to use for reading the tuple. Default is None - will be read by the file + extension. + + :return: The unpacked tuple. + """ + return tuple(super().unpack_file(data_item=data_item, file_format=file_format)) + + +class SetPackager(ListPackager): + """ + ``builtins.set`` packager. + """ + + PACKABLE_OBJECT_TYPE = set + + @classmethod + def pack_result(cls, obj: set, key: str) -> dict: + """ + Pack a set as a result. + + :param obj: The set to pack and log. + :param key: The result's key. + + :return: The result dictionary. + """ + return super().pack_result(obj=list(obj), key=key) + + @classmethod + def pack_file( + cls, obj: set, key: str, file_format: str = DEFAULT_STRUCT_FILE_FORMAT + ) -> Tuple[Artifact, dict]: + """ + Pack a set as a file by the given format. + + :param obj: The set to pack. + :param key: The key to use for the artifact. + :param file_format: The file format to save as. Default is json. + + :return: The packed artifact and instructions. + """ + return super().pack_file(obj=list(obj), key=key, file_format=file_format) + + @classmethod + def unpack_file(cls, data_item: DataItem, file_format: str = None) -> set: + """ + Unpack a set from file. 
+ + :param data_item: The data item to unpack. + :param file_format: The file format to use for reading the set. Default is None - will be read by the file + extension. + + :return: The unpacked set. + """ + return set(super().unpack_file(data_item=data_item, file_format=file_format)) + + +class FrozensetPackager(SetPackager): + """ + ``builtins.frozenset`` packager. + """ + + PACKABLE_OBJECT_TYPE = frozenset + + @classmethod + def pack_file( + cls, obj: frozenset, key: str, file_format: str = DEFAULT_STRUCT_FILE_FORMAT + ) -> Tuple[Artifact, dict]: + """ + Pack a frozenset as a file by the given format. + + :param obj: The frozenset to pack. + :param key: The key to use for the artifact. + :param file_format: The file format to save as. Default is json. + + :return: The packed artifact and instructions. + """ + return super().pack_file(obj=set(obj), key=key, file_format=file_format) + + @classmethod + def unpack_file(cls, data_item: DataItem, file_format: str = None) -> frozenset: + """ + Unpack a frozenset from file. + + :param data_item: The data item to unpack. + :param file_format: The file format to use for reading the frozenset. Default is None - will be read by the file + extension. + + :return: The unpacked frozenset. + """ + return frozenset( + super().unpack_file(data_item=data_item, file_format=file_format) + ) + + +class BytesPackager(ListPackager): + """ + ``builtins.bytes`` packager. + """ + + PACKABLE_OBJECT_TYPE = bytes + + @classmethod + def pack_result(cls, obj: bytes, key: str) -> dict: + """ + Pack bytes as a result. + + :param obj: The bytearray to pack and log. + :param key: The result's key. + + :return: The result dictionary. + """ + return {key: obj} + + @classmethod + def pack_file( + cls, obj: bytes, key: str, file_format: str = DEFAULT_STRUCT_FILE_FORMAT + ) -> Tuple[Artifact, dict]: + """ + Pack a bytes as a file by the given format. + + :param obj: The bytes to pack. + :param key: The key to use for the artifact. 
+ :param file_format: The file format to save as. Default is json. + + :return: The packed artifact and instructions. + """ + return super().pack_file(obj=list(obj), key=key, file_format=file_format) + + @classmethod + def unpack_file(cls, data_item: DataItem, file_format: str = None) -> bytes: + """ + Unpack a bytes from file. + + :param data_item: The data item to unpack. + :param file_format: The file format to use for reading the bytes. Default is None - will be read by the file + extension. + + :return: The unpacked bytes. + """ + return bytes(super().unpack_file(data_item=data_item, file_format=file_format)) + + +class BytearrayPackager(BytesPackager): + """ + ``builtins.bytearray`` packager. + """ + + PACKABLE_OBJECT_TYPE = bytearray + + @classmethod + def pack_result(cls, obj: bytearray, key: str) -> dict: + """ + Pack a bytearray as a result. + + :param obj: The bytearray to pack and log. + :param key: The result's key. + + :return: The result dictionary. + """ + return {key: bytes(obj)} + + @classmethod + def pack_file( + cls, obj: bytearray, key: str, file_format: str = DEFAULT_STRUCT_FILE_FORMAT + ) -> Tuple[Artifact, dict]: + """ + Pack a bytearray as a file by the given format. + + :param obj: The bytearray to pack. + :param key: The key to use for the artifact. + :param file_format: The file format to save as. Default is json. + + :return: The packed artifact and instructions. + """ + return super().pack_file(obj=bytes(obj), key=key, file_format=file_format) + + @classmethod + def unpack_file(cls, data_item: DataItem, file_format: str = None) -> bytearray: + """ + Unpack a bytearray from file. + + :param data_item: The data item to unpack. + :param file_format: The file format to use for reading the bytearray. Default is None - will be read by the file + extension. + + :return: The unpacked bytearray. 
+ """ + return bytearray( + super().unpack_file(data_item=data_item, file_format=file_format) + ) + + +# ---------------------------------------------------------------------------------------------------------------------- +# pathlib packagers: +# ---------------------------------------------------------------------------------------------------------------------- + + +class PathPackager(StrPackager): + """ + ``pathlib.Path`` packager. It is also used for all `Path` inheriting pathlib objects (`PosixPath` and + `WindowsPath`). + """ + + PACKABLE_OBJECT_TYPE = pathlib.Path + PACK_SUBCLASSES = True + DEFAULT_PACKING_ARTIFACT_TYPE = "path" + + @classmethod + def pack_result(cls, obj: pathlib.Path, key: str) -> dict: + """ + Pack the `Path` as a string result. + + :param obj: The `Path` to pack. + :param key: The key to use in the results dictionary. + + :return: The packed result. + """ + return super().pack_result(obj=str(obj), key=key) + + @classmethod + def pack_path( + cls, obj: pathlib.Path, key: str, archive_format: str = DEFAULT_ARCHIVE_FORMAT + ) -> Tuple[Artifact, dict]: + """ + Pack a `Path` value (pack the file or directory in that path). + + :param obj: The `Path` to pack. + :param key: The key to use for the artifact. + :param archive_format: The archive format to use in case the path is of a directory. Default is zip. + + :return: The packed artifact and instructions. + """ + return super().pack_path(obj=str(obj), key=key, archive_format=archive_format) + + @classmethod + def unpack_path( + cls, data_item: DataItem, is_directory: bool = False, archive_format: str = None + ) -> pathlib.Path: + """ + Unpack a data item representing a `Path`. If the path is of a file, the file is downloaded to a local + temporary directory and its path is returned. If the path is of a directory, the archive is extracted and the + directory path extracted is returned. + + :param data_item: The data item to unpack. 
+ :param is_directory: Whether the path should be treated as a file or a directory. Files (even archives like + zip) won't be extracted. + :param archive_format: The archive format to use in case the path is of a directory. Default is None - will be + read by the archive file extension. + + :return: The unpacked `Path`. + """ + return pathlib.Path( + super().unpack_path( + data_item=data_item, + is_directory=is_directory, + archive_format=archive_format, + ) + ) + + +# ---------------------------------------------------------------------------------------------------------------------- +# TODO: collection packagers: +# ---------------------------------------------------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------------------------------------------------- +# TODO: datetime packagers: +# ---------------------------------------------------------------------------------------------------------------------- diff --git a/mlrun/package/packagers_manager.py b/mlrun/package/packagers_manager.py new file mode 100644 index 000000000000..f3fe6b0a1148 --- /dev/null +++ b/mlrun/package/packagers_manager.py @@ -0,0 +1,781 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import importlib +import inspect +import os +import shutil +import traceback +from typing import Any, Dict, List, Tuple, Type, Union + +from mlrun.artifacts import Artifact +from mlrun.datastore import DataItem, store_manager +from mlrun.errors import MLRunInvalidArgumentError +from mlrun.utils import logger + +from .errors import ( + MLRunPackageCollectionError, + MLRunPackagePackingError, + MLRunPackageUnpackingError, +) +from .packager import Packager +from .packagers.default_packager import DefaultPackager +from .utils import LogHintKey, TypeHintUtils + + +class PackagersManager: + """ + A packager manager is holding the project's packagers and sending them objects to pack and data items to unpack. + + It prepares the instructions / log hint configurations and then looks for the first packager who fits the task. + That's why when the manager collects its packagers, it first collects builtin MLRun packagers and only then the + user's custom packagers, this way user's custom packagers will have higher priority. + """ + + def __init__(self, default_packager: Type[Packager] = None): + """ + Initialize a packagers manager. + + :param default_packager: The default packager should be a packager that fits to all types. It will be the first + packager in the manager's packagers (meaning it will be used at lowest priority) and it + should be found fitting when all packagers managed by the manager do not fit an + object or data item. Default to ``mlrun.DefaultPackager``. 
+        """
+        # Set the default packager:
+        self._default_packager = default_packager or DefaultPackager
+
+        # Initialize the packagers list (with the default packager in it):
+        self._packagers: List[Type[Packager]] = []
+
+        # Set an artifacts list and results dictionary to collect all packed objects (will be used later to write extra
+        # data if noted by the user using the log hint key "extra_data")
+        self._artifacts: List[Artifact] = []
+        self._results = {}
+
+    @property
+    def artifacts(self) -> List[Artifact]:
+        """
+        Get the artifacts that were packed by the manager.
+
+        :return: A list of artifacts.
+        """
+        return self._artifacts
+
+    @property
+    def results(self) -> dict:
+        """
+        Get the results that were packed by the manager.
+
+        :return: A results dictionary.
+        """
+        return self._results
+
+    def collect_packagers(
+        self, packagers: List[Union[Type, str]], default_priority: int = 5
+    ):
+        """
+        Collect the provided packagers. Packagers passed as module paths will be imported and validated to be of type
+        `Packager`. If needed to import all packagers from a module, use the module path with a "*" at the end (packager
+        with a name that starts with a '_' won't be collected).
+
+        Notice: Only packagers that are declared in the module will be collected (packagers imported in the module scope
+        won't be collected). For example::
+
+            from mlrun import Packager
+            from x import XPackager
+
+            class YPackager(Packager):
+                pass
+
+        Only "YPackager" will be collected as it is declared in the module, but not "XPackager" which is only imported.
+
+        :param packagers:        List of packagers to add.
+        :param default_priority: The priority to assign to a collected packager that does not set its ``PRIORITY``
+                                 class variable. Default: 5.
+
+        :raise MLRunPackageCollectionError: In case the packager could not be collected.
+ """ + # Collect the packagers: + for packager in packagers: + # If it's a string, it's the module path of the class, so we import it: + if isinstance(packager, str): + # TODO: For supporting Hub packagers, if the string is a hub url, then look in the labels for the + # packagers to import and import the function as a module. + # Import the module: + module_name, class_name = self._split_module_path(module_path=packager) + try: + module = importlib.import_module(module_name) + except ModuleNotFoundError as module_not_found_error: + raise MLRunPackageCollectionError( + f"The packager '{class_name}' could not be collected from the module '{module_name}' as it " + f"cannot be imported: {module_not_found_error}" + ) from module_not_found_error + # Check if needed to import all packagers from the given module: + if class_name == "*": + # Get all the packagers from the module and collect them (this time they will be sent as `Packager` + # types to the method): + self.collect_packagers( + packagers=[ + member + for _, member in inspect.getmembers( + module, + lambda m: ( + # Validate it is declared in the module: + hasattr(m, "__module__") + and m.__module__ == module.__name__ + # Validate it is a `Packager`: + and isinstance(m, type) + and issubclass(m, Packager) + # Validate it is not a "protected" `Packager`: + and not m.__name__.startswith("_") + ), + ) + ] + ) + # Collected from the previous call, continue to the next packager in the list: + continue + # Import the packager and continue like as if it was given as a type: + try: + packager = getattr(module, class_name) + except AttributeError as attribute_error: + raise MLRunPackageCollectionError( + f"The packager '{class_name}' could not be collected as it does not exist in the module " + f"'{module.__name__}': {attribute_error}" + ) from attribute_error + # Validate the class given is a `Packager` type: + if not issubclass(packager, Packager): + raise MLRunPackageCollectionError( + f"The packager 
'{packager.__name__}' could not be collected as it is not a `mlrun.Packager`."
+                )
+            # Set default priority in case it is not set in the packager's class:
+            if packager.PRIORITY is ...:
+                packager.PRIORITY = default_priority
+            # Collect the packager (putting it first in the list for highest priority):
+            self._packagers.insert(0, packager)
+            # For debugging, we'll print the collected packager:
+            logger.debug(
+                f"The packagers manager collected the packager: {str(packager)}"
+            )
+
+        # Sort the packagers:
+        self._packagers.sort()
+
+    def pack(
+        self, obj: Any, log_hint: Dict[str, str]
+    ) -> Union[Artifact, dict, None, List[Union[Artifact, dict, None]]]:
+        """
+        Pack an object using one of the manager's packagers. A `dict` ("**") or `list` ("*") unpacking syntax in the
+        log hint key will pack the objects within them in separate packages.
+
+        :param obj:      The object to pack as an artifact.
+        :param log_hint: The log hint to use.
+
+        :return: The packaged artifact or result. None is returned if there was a problem while packing the object. If
+                 a prefix of dict or list unpacking was provided in the log hint key, a list of all the arbitrary number
+                 of packaged objects will be returned.
+
+        :raise MLRunInvalidArgumentError: If the key in the log hint is noting to log an arbitrary amount of artifacts
+                                          but the object type does not match the "*" or "**" used in the key.
+        :raise MLRunPackagePackingError: If there was an error during the packing.
+ """ + # Get the key to see if needed to pack arbitrary number of objects via list or dict prefixes: + log_hint_key = log_hint[LogHintKey.KEY] + if log_hint_key.startswith("**"): + # A dictionary unpacking prefix was given, validate the object is a dictionary and prepare the objects to + # pack with their keys: + if not isinstance(obj, dict): + raise MLRunInvalidArgumentError( + f"The log hint key '{log_hint_key}' has a dictionary unpacking prefix ('**') to log arbitrary " + f"number of objects within the dictionary, but a dictionary was not provided, the given object is " + f"of type '{self._get_type_name(type(obj))}'. The object is ignored, to log it, please remove the " + f"'**' prefix from the key." + ) + objects_to_pack = { + f"{log_hint_key[len('**'):]}{dict_key}": dict_obj + for dict_key, dict_obj in obj.items() + } + elif log_hint_key.startswith("*"): + # An iterable unpacking prefix was given, validate the object is iterable and prepare the objects to pack + # with their keys: + is_iterable = True + try: + for _ in obj: + break + except TypeError: + is_iterable = False + if not is_iterable: + raise MLRunInvalidArgumentError( + f"The log hint key '{log_hint_key}' has an iterable unpacking prefix ('*') to log arbitrary number " + f"of objects within it (like a `list` or `set`), but an iterable object was not provided, the " + f"given object is of type '{self._get_type_name(type(obj))}'. The object is ignored, to log it, " + f"please remove the '*' prefix from the key." 
+ ) + objects_to_pack = { + f"{log_hint_key[len('*'):]}{i}": obj_i for i, obj_i in enumerate(obj) + } + else: + # A single object is required to be packaged: + objects_to_pack = {log_hint_key: obj} + + # Go over the collected keys and objects and pack them: + packages = [] + for key, per_key_obj in objects_to_pack.items(): + # Edit the key in the log hint: + per_key_log_hint = log_hint.copy() + per_key_log_hint[LogHintKey.KEY] = key + # Pack and collect the package: + try: + packages.append(self._pack(obj=per_key_obj, log_hint=per_key_log_hint)) + except Exception as exception: + raise MLRunPackagePackingError( + f"An exception was raised during the packing of '{per_key_log_hint}': {exception}" + ) from exception + + # If multiple packages were packed, return a list, otherwise return the single package: + return packages if len(packages) > 1 else packages[0] + + def unpack(self, data_item: DataItem, type_hint: Type) -> Any: + """ + Unpack an object using one of the manager's packagers. The data item can be unpacked in two options: + + * As a package: If the data item contains a package and the type hint provided is equal to the object + type noted in the package. Or, if it's a package and a type hint was not provided. + * As a data item: If the data item is not a package or the type hint provided is not equal to the one noted in + the package. + + If the type hint is a `mlrun.DataItem` then it won't be unpacked. + + Notice: It is not recommended to use a different packager than the one who originally packed the object to + unpack it. A warning will be shown in that case. + + :param data_item: The data item holding the package. + :param type_hint: The type hint to parse the data item as. + + :return: The unpacked object parsed as type hinted. 
+ """ + # Check if `DataItem` is hinted - meaning the user can expect a data item and do not want to unpack it: + if TypeHintUtils.is_matching(object_type=DataItem, type_hint=type_hint): + return data_item + + # Set variables to hold the manager notes and packager instructions: + artifact_key = None + packaging_instructions = None + + # Try to get the notes and instructions (can be found only in artifacts but data item may be a simple path/url): + if data_item.get_artifact_type(): + # Get the artifact object in the data item: + artifact, _ = store_manager.get_store_artifact(url=data_item.artifact_url) + # Get the key from the artifact's metadata and instructions from the artifact's spec: + artifact_key = artifact.metadata.key + packaging_instructions = artifact.spec.unpackaging_instructions + + # Unpack: + try: + if packaging_instructions: + # The data item is a package and the object type is equal or part of the type hint (part of is in case + # of a `typing.Union` for example): + return self._unpack_package( + data_item=data_item, + artifact_key=artifact_key, + packaging_instructions=packaging_instructions, + type_hint=type_hint, + ) + # The data item is not a package or the object type is not equal or part of the type hint: + return self._unpack_data_item( + data_item=data_item, + type_hint=type_hint, + ) + except Exception as exception: + raise MLRunPackageUnpackingError( + f"An exception was raised during the unpacking of '{data_item.key}': {exception}" + ) from exception + + def link_packages( + self, + additional_artifacts: List[Artifact], + additional_results: dict, + ): + """ + Link packages between each other according to the provided extra data and metrics spec keys. A future link is + marked with ellipses (...). If no link was found, None will be used and a warning will get printed. + + :param additional_artifacts: Additional artifacts to link (should come from a `mlrun.MLClientCtx`). 
+ :param additional_results: Additional results to link (should come from a `mlrun.MLClientCtx`). + """ + # Join the manager's artifacts and results with the additional ones to look for a link in all of them: + joined_artifacts = [*additional_artifacts, *self.artifacts] + joined_results = {**additional_results, **self.results} + + # Go over the artifacts and link: + for artifact in self.artifacts: + # Go over the extra data keys: + for key in artifact.spec.extra_data: + # Future link is marked with ellipses (...): + if artifact.spec.extra_data[key] is ...: + # Look for an artifact or result with this key to link it: + extra_data = self._look_for_extra_data( + key=key, artifacts=joined_artifacts, results=joined_results + ) + # Print a warning if a link is missing: + if extra_data is None: + logger.warn( + f"Could not find {key} to link as extra data for {artifact.key}." + ) + # Link it (None will be used in case it was not found): + artifact.spec.extra_data[key] = extra_data + # Go over the metrics keys if available (`ModelArtifactSpec` has a metrics property that may be waiting for + # values from logged results): + if hasattr(artifact.spec, "metrics"): + for key in artifact.spec.metrics: + # Future link is marked with ellipses (...): + if artifact.spec.metrics[key] is ...: + # Link it (None will be used in case it was not found): + artifact.spec.metrics[key] = joined_results.get(key, None) + + def clear_packagers_outputs(self): + """ + Clear the outputs of all packagers. This method should be called at the end of the run after logging all + artifacts as some will require uploading the files that will be deleted in this method. 
+ """ + for packager in self._get_packagers_with_default_packager(): + for path in packager.get_future_clearing_path_list(): + if not os.path.exists(path): + continue + if os.path.isdir(path): + shutil.rmtree(path) + else: + os.remove(path) + + class _InstructionsNotesKey: + """ + Library of keys for the packager instructions to be added to the packed artifact's spec. + """ + + PACKAGER_NAME = "packager_name" + OBJECT_TYPE = "object_type" + ARTIFACT_TYPE = "artifact_type" + INSTRUCTIONS = "instructions" + + def _get_packagers_with_default_packager(self) -> List[Type[Packager]]: + """ + Get the full list of packagers - the collected packagers and the default packager (located at last place in the + list - the lowest priority). + + :return: A list of the manager's packagers with the default packager. + """ + return [*self._packagers, self._default_packager] + + def _get_packager_by_name(self, name: str) -> Union[Type[Packager], None]: + """ + Look for a packager with the given name and return it. + + If a packager was not found None will be returned. + + :param name: The name of the packager to get. + + :return: The found packager or None if it wasn't found. + """ + # Look for a packager by exact name: + for packager in self._get_packagers_with_default_packager(): + if packager.__name__ == name: + return packager + + # No packager was found: + logger.warn(f"The packager '{name}' was not found.") + return None + + def _get_packager_for_packing( + self, + obj: Any, + artifact_type: str = None, + ) -> Union[Type[Packager], None]: + """ + Look for a packager that can pack the provided object as the provided artifact type. + + If a packager was not found None will be returned. + + :param obj: The object to pack. + :param artifact_type: The artifact type the packager to get should pack / unpack as. + + :return: The found packager or None if it wasn't found. 
+        """
+        # Look for a packager for the combination of object and artifact type:
+        for packager in self._packagers:
+            if packager.is_packable(obj=obj, artifact_type=artifact_type):
+                return packager
+
+        # No packager was found:
+        return None
+
+    def _get_packager_for_unpacking(
+        self,
+        data_item: Any,
+        type_hint: type,
+        artifact_type: str = None,
+    ) -> Union[Type[Packager], None]:
+        """
+        Look for a packager that can unpack the data item of the given type hint as the provided artifact type.
+
+        If a packager was not found None will be returned.
+
+        :param data_item:     The data item to unpack.
+        :param type_hint:     The type hint the packager to get should handle.
+        :param artifact_type: The artifact type the packager to get should pack / unpack as.
+
+        :return: The found packager or None if it wasn't found.
+        """
+        # Look for a packager for the combination of object type and artifact type:
+        for packager in self._packagers:
+            if packager.is_unpackable(
+                data_item=data_item, type_hint=type_hint, artifact_type=artifact_type
+            ):
+                return packager
+
+        # No packager was found:
+        return None
+
+    def _pack(self, obj: Any, log_hint: dict) -> Union[Artifact, dict, None]:
+        """
+        Pack an object using one of the manager's packagers.
+
+        :param obj:      The object to pack as an artifact.
+        :param log_hint: The log hint to use.
+
+        :return: The packaged artifact or result. None is returned if there was a problem while packing the object. 
+        """
+        # Get the artifact type (if user didn't pass any, the packager will use its configured default):
+        artifact_type = log_hint.pop(LogHintKey.ARTIFACT_TYPE, None)
+
+        # Get a packager:
+        packager = self._get_packager_for_packing(obj=obj, artifact_type=artifact_type)
+        if packager is None:
+            if self._default_packager.is_packable(obj=obj, artifact_type=artifact_type):
+                logger.info(
+                    f"Using the default packager to pack the object '{log_hint[LogHintKey.KEY]}'"
+                )
+                packager = self._default_packager
+            else:
+                raise MLRunPackagePackingError(
+                    f"No packager was found for the combination of "
+                    f"'object_type={self._get_type_name(typ=type(obj))}' and 'artifact_type={artifact_type}'."
+                )
+
+        # Use the packager to pack the object:
+        packed_object = packager.pack(
+            obj=obj, artifact_type=artifact_type, configurations=log_hint
+        )
+
+        # If the packed object is a result, return it as is:
+        if isinstance(packed_object, dict):
+            # Collect the result and return:
+            self._results.update(packed_object)
+            return packed_object
+
+        # It is an artifact, continue with the packaging:
+        artifact, instructions = packed_object
+
+        # Prepare the manager's unpackaging instructions notes:
+        unpackaging_instructions = {
+            self._InstructionsNotesKey.PACKAGER_NAME: packager.__name__,
+            self._InstructionsNotesKey.OBJECT_TYPE: self._get_type_name(typ=type(obj)),
+            self._InstructionsNotesKey.ARTIFACT_TYPE: (
+                artifact_type
+                if artifact_type
+                else packager.get_default_packing_artifact_type(obj=obj)
+            ),
+            self._InstructionsNotesKey.INSTRUCTIONS: instructions,
+        }
+
+        # Set the instructions in the artifact's spec:
+        artifact.spec.unpackaging_instructions = unpackaging_instructions
+
+        # Collect the artifact and return:
+        self._artifacts.append(artifact)
+        return artifact
+
+    def _unpack_package(
+        self,
+        data_item: DataItem,
+        artifact_key: str,
+        packaging_instructions: dict,
+        type_hint: type,
+    ) -> Any:
+        """
+        Unpack a data item as a package using the given notes. 
+ + :param data_item: The data item to unpack. + :param artifact_key: The artifact's key (used only to raise a meaningful error message in case of an + error). + :param packaging_instructions: The manager's noted instructions. + :param type_hint: The user's type hint. + + :return: The unpacked object. + + :raise MLRunPackageUnpackingError: If there is no packager with the given name. + """ + # Extract the packaging instructions: + packager_name = packaging_instructions[self._InstructionsNotesKey.PACKAGER_NAME] + try: + # For validation, we'll try to get the type of the original packaged object. The original object type might + # not be available for 3 reasons: + # 1. The user is trying to parse the data item to a different type than the one it was packaged - meaning it + # is ok to be missing, the method will call `unpack_data_item` down the road. + # 2. The interpreter does not have the required module to unpack this object meaning it will not have the + # original packager as well, so it will try to use another package before raising an error. + # 3. An edge case where the user declared the class at the MLRun function itself. Read the long warning to + # understand more. + self._get_type_from_name( + type_name=packaging_instructions[self._InstructionsNotesKey.OBJECT_TYPE] + ) + except ModuleNotFoundError: + logger.warn( + f"Could not import the original type " + f"('{packaging_instructions[self._InstructionsNotesKey.OBJECT_TYPE]}') of the input artifact " + f"'{artifact_key}' due to a `ModuleNotFoundError`.\n" + f"Note: If you wish to parse the input to a different type (which is not recommended) you may ignore " + f"this warning. Otherwise, make sure the interpreter has the required module to import the type.\n" + f"If it does, you probably implemented the class at the same file of your MLRun function, making " + f"Python collect it twice: one from the object's own Packager class and another from the function " + f"code. 
When MLRun is converting code to a MLRun function, it counts on it to be able to be imported " + f"as a stand alone file. If other classes (like the packager who imports it) require objects declared " + f"in this file, it is no longer stand alone. For example:\n\n" + f"" + f"Let us look at a file '/src/my_module/my_file.py':" + f"\tclass MyClass:\n" + f"\t\tpass\n\n" + f"\tclass MyClassPackager(Packager):\n" + f"\t\tPACKABLE_OBJECT_TYPE = MyClass\n\n" + f"" + f"The packager of this class will have the class variable `PACKABLE_OBJECT_TYPE=MyClass` where " + f"`MyClass`'s module is `src.my_module.my_file.MyClass` because it is being collected from the repo " + f"downloaded with the project.\n" + f"But, if creating a MLRun function of '/src/my_module/my_file.py', then 'my_file.py' will be imported " + f"as a stand alone module, making the same class to be imported twice: one time as `my_file.MyClass` " + f"from the stand alone function, and another from the packager who has the correct full module path: " + f"`src.my_module.my_file.MyClass`. This will cause both classes, although the same, to be not equal " + f"and the first one to be not even importable outside the scope of 'my_file.py' - yielding this " + f"warning." + ) + artifact_type = packaging_instructions[self._InstructionsNotesKey.ARTIFACT_TYPE] + instructions = ( + packaging_instructions[self._InstructionsNotesKey.INSTRUCTIONS] or {} + ) + + # Get the original packager by its name: + packager = self._get_packager_by_name(name=packager_name) + + # Check if the original packager can be used (the user do not count on parsing to a different type): + unpack_as_package = False + if packager is None: + # The original packager was not found, the user either did not add the custom packager or perhaps wants + # to unpack the data item as a different type than the original one. 
We will warn and continue to unpack as + # a non-package data item: + logger.warn( + f"{artifact_key} was originally packaged by a packager of type '{packager_name}' but it " + f"was not found. Custom packagers should be added to the project running the function " + f"using the `add_custom_packager` method and make sure the function was set in the project " + f"with the attribute 'with_repo=True`.\n" + f"MLRun will try to unpack according to the provided type hint in code." + ) + elif type_hint is None: + # User count on the type noted in the package, so we unpack it as is: + unpack_as_package = True + else: + # A type hint is provided, check if the type hint is packable by the packager: + type_hints = {type_hint} + while not unpack_as_package and len(type_hints) > 0: + # Check for each hint (one match is enough): + for hint in type_hints: + if packager.is_unpackable( + data_item=data_item, type_hint=hint, artifact_type=artifact_type + ): + unpack_as_package = True + break + if not unpack_as_package: + # Reduce the hints and continue: + type_hints = TypeHintUtils.reduce_type_hint(type_hint=type_hints) + if not unpack_as_package: + # They are not equal, so we can't count on the original packager noted on the package as the user + # require different type, so we unpack as data item: + logger.warn( + f"{artifact_key} was originally packaged by '{packager_name}' but the type hint given to " + f"unpack it as '{type_hint}' is not supported by it. MLRun will try to look for a matching " + f"packager to the type hint instead. Note: it is not recommended to parse an object from type to " + f"type using the unpacking mechanism of packagers as unknown behavior might happen." 
+ ) + + # Unpack: + if unpack_as_package: + return packager.unpack( + data_item=data_item, + artifact_type=artifact_type, + instructions=instructions, + ) + return self._unpack_data_item(data_item=data_item, type_hint=type_hint) + + def _unpack_data_item(self, data_item: DataItem, type_hint: Type): + """ + Unpack a data item to the desired hinted type. In case the type hint includes multiple types (like in case of + `typing.Union`), the manager will go over the types, reduce them while looking for the first packager that + successfully unpack the data item. + + :param data_item: The data item to unpack. + :param type_hint: The type hint to unpack it to. + + :return: The unpacked object. + + :raise MLRunPackageUnpackingError: If there is no packager that supports the provided type hint. + """ + # Prepare a list of a packager and exception string for all the failures in case there was no fitting packager: + found_packagers: List[Tuple[Type[Packager], str]] = [] + + # Try to unpack as one of the possible types in the type hint: + possible_type_hints = {type_hint} + while len(possible_type_hints) > 0: + for hint in possible_type_hints: + # Get the packager by the given type: + packager = self._get_packager_for_unpacking( + data_item=data_item, type_hint=hint + ) + if packager is None: + # No packager was found that supports this hinted type: + continue + # Unpack: + try: + return packager.unpack( + data_item=data_item, + instructions={}, + ) + except Exception as exception: + # Could not unpack as the reduced type hint, collect the exception and go to the next one: + exception_string = "".join( + traceback.format_exception( + etype=type(exception), + value=exception, + tb=exception.__traceback__, + ) + ) + found_packagers.append((packager, exception_string)) + # Reduce the type hint list and continue: + possible_type_hints = TypeHintUtils.reduce_type_hint( + type_hint=possible_type_hints + ) + + # Check the default packager: + logger.info( + f"Trying to use the default 
packager to unpack the data item '{data_item.key}'" + ) + try: + return self._default_packager.unpack( + data_item=data_item, + artifact_type=None, + instructions={}, + ) + except Exception as exception: + exception_string = "".join( + traceback.format_exception( + etype=type(exception), + value=exception, + tb=exception.__traceback__, + ) + ) + found_packagers.append((self._default_packager, exception_string)) + + # The method did not return until this point, raise an error: + raise MLRunPackageUnpackingError( + f"Could not unpack data item with the hinted type '{type_hint}'. The following packagers were tried to " + f"be used to unpack it but raised the exceptions joined:\n\n" + + "\n".join( + [ + f"Found packager: '{packager}'\nException: {exception}\n" + for packager, exception in found_packagers + ] + ) + ) + + @staticmethod + def _look_for_extra_data( + key: str, + artifacts: List[Artifact], + results: dict, + ) -> Union[Artifact, str, int, float, None]: + """ + Look for an extra data item (artifact or result) by given key. If not found, None is returned. + + :param key: Key to look for. + :param artifacts: Artifacts to look in. + :param results: Results to look in. + + :return: The artifact or result with the same key or None if not found. + """ + # Look in the artifacts: + for artifact in artifacts: + if key == artifact.key: + return artifact + + # Look in the results: + return results.get(key, None) + + @staticmethod + def _split_module_path(module_path: str) -> Tuple[str, str]: + """ + Split a module path to the module name and the class name. Notice inner classes are not supported. + + :param module_path: The module path to split. + + :return: A tuple of strings of the module name and the class name. + """ + # Set the main script module in case there is no module to be found: + if "." 
not in module_path: + module_path = f"__main__.{module_path}" + + # Split and return: + module_name, class_name = module_path.rsplit(".", 1) + return module_name, class_name + + @staticmethod + def _get_type_name(typ: Type) -> str: + """ + Get an object type full name - its module path. For example, the name of a pandas data frame will be "DataFrame" + but its full name (module path) is: "pandas.core.frame.DataFrame". + + Notice: Type hints are not an object type. They are as their name suggests, only hints. As such, typing hints + should not be given to this function (they do not have '__name__' and '__qualname__' attributes for example). + + :param typ: The object's type to get its full name. + + :return: The object's type full name. + """ + # Get the module name: + module_name = typ.__module__ if hasattr(typ, "__module__") else "" + + # Get the type's (class) name + class_name = typ.__qualname__ if hasattr(typ, "__qualname__") else typ.__name__ + + return f"{module_name}.{class_name}" if module_name else class_name + + @staticmethod + def _get_type_from_name(type_name: str) -> Type: + """ + Get the type object out of the given module path. The module must be a full module path (for example: + "pandas.DataFrame" and not "DataFrame") otherwise it assumes to be from the local run module - __main__. + + :param type_name: The type full name (module path) string. + + :return: The represented type as imported from its module. + """ + module_name, class_name = PackagersManager._split_module_path( + module_path=type_name + ) + module = importlib.import_module(module_name) + return getattr(module, class_name) diff --git a/mlrun/package/utils/__init__.py b/mlrun/package/utils/__init__.py new file mode 100644 index 000000000000..93e6e97e0d69 --- /dev/null +++ b/mlrun/package/utils/__init__.py @@ -0,0 +1,53 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx + +from ._archiver import ArchiveSupportedFormat +from ._formatter import StructFileSupportedFormat +from ._pickler import Pickler +from ._supported_format import SupportedFormat +from .log_hint_utils import LogHintKey, LogHintUtils +from .type_hint_utils import TypeHintUtils + +# The default pickle module to use for pickling objects: +DEFAULT_PICKLE_MODULE = "cloudpickle" +# The default archive format to use for archiving directories: +DEFAULT_ARCHIVE_FORMAT = ArchiveSupportedFormat.ZIP +# The default struct file format to use for savings python struct objects (dictionaries and lists): +DEFAULT_STRUCT_FILE_FORMAT = StructFileSupportedFormat.JSON + + +class ArtifactType: + """ + Possible artifact types to pack objects as and log using a `mlrun.Packager`. + """ + + OBJECT = "object" + PATH = "path" + FILE = "file" + DATASET = "dataset" + MODEL = "model" + PLOT = "plot" + RESULT = "result" + + +class DatasetFileFormat: + """ + All file format for logging objects as `DatasetArtifact`. + """ + + CSV = "csv" + PARQUET = "parquet" diff --git a/mlrun/package/utils/_archiver.py b/mlrun/package/utils/_archiver.py new file mode 100644 index 000000000000..d2c49b596924 --- /dev/null +++ b/mlrun/package/utils/_archiver.py @@ -0,0 +1,226 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import tarfile +import zipfile +from abc import ABC, abstractmethod +from pathlib import Path + +from ._supported_format import SupportedFormat + + +class _Archiver(ABC): + """ + An abstract base class for an archiver - a class to manage archives of multiple files. + """ + + @classmethod + @abstractmethod + def create_archive(cls, directory_path: str, output_path: str) -> str: + """ + Create an archive of all the contents in the given directory and save it to an archive file named as the + directory in the provided output path. + + :param directory_path: The directory with the files to archive. + :param output_path: The output path to store the created archive file. + + :return: The created archive path. + """ + pass + + @classmethod + @abstractmethod + def extract_archive(cls, archive_path: str, output_path: str) -> str: + """ + Extract the given archive to a directory named as the archive file (without the extension) located in the + provided output path. + + :param archive_path: The archive file to extract its contents. + :param output_path: The output path to extract the directory of the archive to. + + :return: The extracted contents directory path. + """ + pass + + +class _ZipArchiver(_Archiver): + """ + A static class for managing zip archives. + """ + + @classmethod + def create_archive(cls, directory_path: str, output_path: str) -> str: + """ + Create an archive of all the contents in the given directory and save it to an archive file named as the + directory in the provided output path. 
+ + :param directory_path: The directory with the files to archive. + :param output_path: The output path to store the created archive file. + + :return: The created archive path. + """ + # Convert to `pathlib.Path` objects: + directory_path = Path(directory_path) + output_path = Path(output_path) + + # Construct the archive file path: + archive_path = output_path / f"{directory_path.stem}.zip" + + # Archive: + with zipfile.ZipFile(archive_path, "w") as zip_file: + for path in directory_path.rglob("*"): + zip_file.write(filename=path, arcname=path.relative_to(directory_path)) + + return str(archive_path) + + @classmethod + def extract_archive(cls, archive_path: str, output_path: str) -> str: + """ + Extract the given archive to a directory named as the archive file (without the extension) located in the + provided output path. + + :param archive_path: The archive file to extract its contents. + :param output_path: The output path to extract the directory of the archive to. + + :return: The extracted contents directory path. + """ + # Convert to `pathlib.Path` objects: + archive_path = Path(archive_path) + output_path = Path(output_path) + + # Create the directory path: + directory_path = output_path / archive_path.stem + os.makedirs(directory_path) + + # Extract: + with zipfile.ZipFile(archive_path, "r") as zip_file: + zip_file.extractall(directory_path) + + return str(directory_path) + + +class _TarArchiver(_Archiver): + """ + A static class for managing tar archives. + """ + + # Inner class variable to note how to open a `TarFile` object for reading and writing: + _MODE_STRING = "" + + @classmethod + def create_archive(cls, directory_path: str, output_path: str) -> str: + """ + Create an archive of all the contents in the given directory and save it to an archive file named as the + directory in the provided output path. + + :param directory_path: The directory with the files to archive. + :param output_path: The output path to store the created archive file. 
+ + :return: The created archive path. + """ + # Convert to `pathlib.Path` objects: + directory_path = Path(directory_path) + output_path = Path(output_path) + + # Construct the archive file path: + archive_file_extension = ( + "tar" if cls._MODE_STRING == "" else f"tar.{cls._MODE_STRING}" + ) + archive_path = output_path / f"{directory_path.stem}.{archive_file_extension}" + + # Archive: + with tarfile.open(archive_path, f"w:{cls._MODE_STRING}") as tar_file: + for path in directory_path.rglob("*"): + tar_file.add(name=path, arcname=path.relative_to(directory_path)) + + return str(archive_path) + + @classmethod + def extract_archive(cls, archive_path: str, output_path: str) -> str: + """ + Extract the given archive to a directory named as the archive file (without the extension) located in the + provided output path. + + :param archive_path: The archive file to extract its contents. + :param output_path: The output path to extract the directory of the archive to. + + :return: The extracted contents directory path. + """ + # Convert to `pathlib.Path` objects: + archive_path = Path(archive_path) + output_path = Path(output_path) + + # Get the archive file name (can be constructed of multiple extensions like tar.gz so `Path.stem` won't work): + archive_file_name = archive_path + while archive_file_name.with_suffix(suffix="") != archive_file_name: + archive_file_name = archive_file_name.with_suffix(suffix="") + archive_file_name = archive_file_name.stem + + # Create the directory path: + directory_path = output_path / archive_file_name + os.makedirs(directory_path) + + # Extract: + with tarfile.open(archive_path, f"r:{cls._MODE_STRING}") as tar_file: + tar_file.extractall(directory_path) + + return str(directory_path) + + +class _TarGZArchiver(_TarArchiver): + """ + A static class for managing tar.gz archives. 
+ """ + + # Inner class variable to note how to open a `TarFile` object for reading and writing: + _MODE_STRING = "gz" + + +class _TarBZ2Archiver(_TarArchiver): + """ + A static class for managing tar.bz2 archives. + """ + + # Inner class variable to note how to open a `TarFile` object for reading and writing: + _MODE_STRING = "bz2" + + +class _TarXZArchiver(_TarArchiver): + """ + A static class for managing tar.gz archives. + """ + + # Inner class variable to note how to open a `TarFile` object for reading and writing: + _MODE_STRING = "xz" + + +class ArchiveSupportedFormat(SupportedFormat[_Archiver]): + """ + Library of archive formats (file extensions) supported by some builtin MLRun packagers. + """ + + ZIP = "zip" + TAR = "tar" + TAR_GZ = "tar.gz" + TAR_BZ2 = "tar.bz2" + TAR_XZ = "tar.xz" + + _FORMAT_HANDLERS_MAP = { + ZIP: _ZipArchiver, + TAR: _TarArchiver, + TAR_GZ: _TarGZArchiver, + TAR_BZ2: _TarBZ2Archiver, + TAR_XZ: _TarXZArchiver, + } diff --git a/mlrun/package/utils/_formatter.py b/mlrun/package/utils/_formatter.py new file mode 100644 index 000000000000..51a27eaecc5b --- /dev/null +++ b/mlrun/package/utils/_formatter.py @@ -0,0 +1,211 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import ast +import json +from abc import ABC, abstractmethod +from typing import Any, Union + +import yaml + +from ._supported_format import SupportedFormat + + +class _Formatter(ABC): + """ + An abstract base class for a formatter - a class to format python structures into and from files. + """ + + @classmethod + @abstractmethod + def write(cls, obj: Any, file_path: str, **dump_kwargs: dict): + """ + Write the object to a file. The object must be serializable according to the used format. + + :param obj: The object to write. + :param file_path: The file path to write to. + :param dump_kwargs: Additional keyword arguments to pass to the dump method of the formatter in use. + """ + pass + + @classmethod + @abstractmethod + def read(cls, file_path: str) -> Any: + """ + Read an object from the file given. + + :param file_path: The file to read the object from. + + :return: The read object. + """ + pass + + +class _JSONFormatter(_Formatter): + """ + A static class for managing json files. + """ + + # A set of default configurations to pass to the dump function: + DEFAULT_DUMP_KWARGS = {"indent": 4} + + @classmethod + def write(cls, obj: Union[list, dict], file_path: str, **dump_kwargs: dict): + """ + Write the object to a json file. The object must be serializable according to the json format. + + :param obj: The object to write. + :param file_path: The file path to write to. + :param dump_kwargs: Additional keyword arguments to pass to the `json.dump` method of the formatter in use. + """ + dump_kwargs = dump_kwargs or cls.DEFAULT_DUMP_KWARGS + with open(file_path, "w") as file: + json.dump(obj, file, **dump_kwargs) + + @classmethod + def read(cls, file_path: str) -> Union[list, dict]: + """ + Read an object from the json file given. + + :param file_path: The json file to read the object from. + + :return: The read object. 
+ """ + with open(file_path, "r") as file: + obj = json.load(file) + return obj + + +class _JSONLFormatter(_Formatter): + """ + A static class for managing jsonl files. + """ + + @classmethod + def write(cls, obj: Union[list, dict], file_path: str, **dump_kwargs: dict): + """ + Write the object to a jsonl file. The object must be serializable according to the json format. + + :param obj: The object to write. + :param file_path: The file path to write to. + :param dump_kwargs: Additional keyword arguments to pass to the `json.dumps` method of the formatter in use. + """ + if isinstance(obj, dict): + obj = [obj] + + with open(file_path, "w") as file: + for line in obj: + file.write(json.dumps(obj=line, **dump_kwargs) + "\n") + + @classmethod + def read(cls, file_path: str) -> Union[list, dict]: + """ + Read an object from the jsonl file given. + + :param file_path: The jsonl file to read the object from. + + :return: The read object. + """ + with open(file_path, "r") as file: + lines = file.readlines() + + obj = [] + for line in lines: + obj.append(json.loads(s=line)) + + return obj[0] if len(obj) == 1 else obj + + +class _YAMLFormatter(_Formatter): + """ + A static class for managing yaml files. + """ + + # A set of default configurations to pass to the dump function: + DEFAULT_DUMP_KWARGS = {"default_flow_style": False, "indent": 4} + + @classmethod + def write(cls, obj: Union[list, dict], file_path: str, **dump_kwargs: dict): + """ + Write the object to a yaml file. The object must be serializable according to the yaml format. + + :param obj: The object to write. + :param file_path: The file path to write to. + :param dump_kwargs: Additional keyword arguments to pass to the `yaml.dump` method of the formatter in use. 
+ """ + dump_kwargs = dump_kwargs or cls.DEFAULT_DUMP_KWARGS + with open(file_path, "w") as file: + yaml.dump(obj, file, **dump_kwargs) + + @classmethod + def read(cls, file_path: str) -> Union[list, dict]: + """ + Read an object from the yaml file given. + + :param file_path: The yaml file to read the object from. + + :return: The read object. + """ + with open(file_path, "r") as file: + obj = yaml.safe_load(file) + return obj + + +class _TXTFormatter(_Formatter): + """ + A static class for managing txt files. + """ + + @classmethod + def write(cls, obj: Any, file_path: str, **dump_kwargs: dict): + """ + Write the object to a text file. The object must be serializable according to python's ast module. + + :param obj: The object to write. + :param file_path: The file path to write to. + :param dump_kwargs: Ignored. + """ + with open(file_path, "w") as file: + file.write(str(obj)) + + @classmethod + def read(cls, file_path: str) -> Any: + """ + Read an object from the yaml file given. + + :param file_path: The yaml file to read the object from. + + :return: The read object. + """ + with open(file_path, "r") as file: + obj = ast.literal_eval(file.read()) + return obj + + +class StructFileSupportedFormat(SupportedFormat[_Formatter]): + """ + Library of struct formats (file extensions) supported by some builtin MLRun packagers. + """ + + JSON = "json" + JSONL = "jsonl" + YAML = "yaml" + TXT = "txt" + + _FORMAT_HANDLERS_MAP = { + JSON: _JSONFormatter, + JSONL: _JSONLFormatter, + YAML: _YAMLFormatter, + TXT: _TXTFormatter, + } diff --git a/mlrun/package/utils/_pickler.py b/mlrun/package/utils/_pickler.py new file mode 100644 index 000000000000..00cce706e27e --- /dev/null +++ b/mlrun/package/utils/_pickler.py @@ -0,0 +1,234 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import importlib +import os +import sys +import tempfile +import warnings +from types import ModuleType +from typing import Any, Dict, Tuple, Union + +from mlrun.errors import MLRunInvalidArgumentError +from mlrun.utils import logger + + +class Pickler: + """ + A static class to pickle objects with multiple modules while capturing the environment of the pickled object. The + pickler will raise warnings in case the object is un-pickled in a mismatching environment (different modules + and / or python versions) + """ + + @staticmethod + def pickle( + obj: Any, pickle_module_name: str, output_path: str = None + ) -> Tuple[str, Dict[str, Union[str, None]]]: + """ + Pickle an object using the given module. The pickled object will be saved to file to the given output path. + + :param obj: The object to pickle. + :param pickle_module_name: The pickle module to use. For example: "pickle", "joblib", "cloudpickle". + :param output_path: The output path to save the 'pkl' file to. If not provided, the pickle will be saved + to a temporary directory. The user is responsible to clean the temporary directory. + + :return: A tuple of the path of the 'pkl' file and the instructions the pickler noted. 
+ """ + # Get the pickle module: + pickle_module = importlib.import_module(pickle_module_name) + Pickler._validate_pickle_module(pickle_module=pickle_module) + pickle_module_version = Pickler._get_module_version( + module_name=pickle_module_name + ) + + # Get the object's module (module name can be extracted usually from the object's class): + object_module_name = ( + obj.__module__.split(".")[0] + if hasattr(obj, "__module__") + else type(obj).__module__.split(".")[0] + ) + object_module_version = Pickler._get_module_version( + module_name=object_module_name + ) + + # Get the python version: + python_version = Pickler._get_python_version() + + # Construct the pickler labels dictionary (versions may not be available): + instructions = { + "object_module_name": object_module_name, + "pickle_module_name": pickle_module_name, + "python_version": python_version, + } + if object_module_version is not None: + instructions["object_module_version"] = object_module_version + if pickle_module_version is not None: + instructions["pickle_module_version"] = pickle_module_version + + # Generate a temporary output path if not provided: + if output_path is None: + output_path = os.path.join(tempfile.mkdtemp(), "obj.pkl") + + # Pickle the object to file: + with open(output_path, "wb") as pkl_file: + pickle_module.dump(obj, pkl_file) + + return output_path, instructions + + @staticmethod + def unpickle( + pickle_path: str, + pickle_module_name: str, + object_module_name: str = None, + python_version: str = None, + pickle_module_version: str = None, + object_module_version: str = None, + ) -> Any: + """ + Unpickle an object using the given instructions. Warnings may be raised in case any of the versions are + mismatching (only if provided - not None). + + :param pickle_path: Path to the 'pkl' file to un-pickle. + :param pickle_module_name: Module to use for unpickling the object. + :param object_module_name: The original object's module. 
Used to verify the current interpreter object module + version match the pickled object version before unpickling the object. + :param python_version: The python version in which the original object was pickled. Used to verify the + current interpreter python version match the pickled object version before + unpickling the object. + :param pickle_module_version: The pickle module version. Used to verify the current interpreter module version + match the one who pickled the object before unpickling it. + :param object_module_version: The original object's module version to match to the interpreter's module version. + + :return: The un-pickled object. + """ + # Check the python version against the pickled object: + if python_version is not None: + current_python_version = Pickler._get_python_version() + if python_version != current_python_version: + logger.warn( + f"MLRun is trying to load an object that was pickled on python version " + f"'{python_version}' but the current python version is '{current_python_version}'. " + f"When using pickle, it is recommended to save and load an object on the same python version to " + f"reduce unexpected errors." + ) + + # Get the pickle module: + pickle_module = importlib.import_module(pickle_module_name) + Pickler._validate_pickle_module(pickle_module=pickle_module) + + # Check the pickle module against the pickled object (only if the version is given): + if pickle_module_version is not None: + current_pickle_module_version = Pickler._get_module_version( + module_name=pickle_module_name + ) + if pickle_module_version != current_pickle_module_version: + logger.warn( + f"MLRun is trying to load an object that was pickled using " + f"{pickle_module_name} version {pickle_module_version} but the current module version is " + f"'{current_pickle_module_version}'. " + f"When using pickle, it is recommended to save and load an " + f"object using the same pickling module version to reduce unexpected errors." 
+ ) + + # Check the object module against the pickled object (only if the version is given): + if object_module_version is not None and object_module_name is not None: + current_object_module_version = Pickler._get_module_version( + module_name=object_module_name + ) + if object_module_version != current_object_module_version: + logger.warn( + f"MLRun is trying to load an object from module {object_module_name} version " + f"{object_module_version} but the current module version is '{current_object_module_version}'. " + f"When using pickle, it is recommended to save and load an object using " + f"the same exact module version to reduce unexpected errors." + ) + + # Load the object from the pickle file: + with open(pickle_path, "rb") as pickle_file: + obj = pickle_module.load(pickle_file) + + return obj + + @staticmethod + def _validate_pickle_module(pickle_module: ModuleType): + """ + Validate the pickle module to use have a `dump` and `load` functions so the Pickler can use it. + + :param pickle_module: The pickle module tot validate. + + :raise MLRunInvalidArgumentError: If the pickle module is not valid. + """ + for function_name in ["dump", "load"]: + if not hasattr(pickle_module, function_name): + raise MLRunInvalidArgumentError( + f"A pickle module is expected to have a `{function_name}` function but the provided module " + f"{pickle_module.__name__} does not have it." + ) + + @staticmethod + def _get_module_version(module_name: str) -> Union[str, None]: + """ + Get a module's version. Most updated modules have versions but some don't. In case the version could not be + read, None is returned. + + :param module_name: The module's name to get its version. + + :return: The module's version if found and None otherwise. + """ + # First we'll try to get the module version from `importlib`: + try: + # Since Python 3.8, `version` is part of `importlib.metadata`. Before 3.8, we'll use the module + # `importlib_metadata` to get `version`. 
+ if ( + sys.version_info[1] > 7 + ): # TODO: Remove once Python 3.7 is not supported. + from importlib.metadata import version + else: + from importlib_metadata import version + + return version(module_name) + except (ModuleNotFoundError, importlib.metadata.PackageNotFoundError): + # User won't necessarily have the `importlib_metadata` module, so we will ignore it by catching + # `ModuleNotFoundError`. `PackageNotFoundError` is ignored as well as this is raised when `version` could + # not find the package related to the module. + pass + + # Secondly, if importlib could not get the version (most likely 'importlib_metadata' is not installed), we'll + # try to use `pkg_resources` to get the version (the version will be found only if the package name is equal to + # the module name. For example, if the module name is 'x' then the way we installed the package must be + # 'pip install x'): + import pkg_resources + + with warnings.catch_warnings(): + # If a module's package is not found, a `PkgResourcesDeprecationWarning` warning will be raised and then + # `DistributionNotFound` exception will be raised, so we ignore them both: + warnings.filterwarnings( + "ignore", category=pkg_resources.PkgResourcesDeprecationWarning + ) + try: + return pkg_resources.get_distribution(module_name).version + except pkg_resources.DistributionNotFound: + pass + + # The version could not be found. + return None + + @staticmethod + def _get_python_version() -> str: + """ + Get the current running python's version. + + :return: The python version string. 
+ """ + return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" diff --git a/mlrun/package/utils/_supported_format.py b/mlrun/package/utils/_supported_format.py new file mode 100644 index 000000000000..d9e30d1d5290 --- /dev/null +++ b/mlrun/package/utils/_supported_format.py @@ -0,0 +1,71 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from abc import ABC +from typing import Dict, Generic, List, Type, TypeVar, Union + +# A generic type for a supported format handler class type: +FileHandlerType = TypeVar("FileHandlerType") + + +class SupportedFormat(ABC, Generic[FileHandlerType]): + """ + Library of supported formats by some builtin MLRun packagers. + """ + + # Add here the all the supported formats in ALL CAPS and their value as a string: + ... + + # The map to use in the method `get_format_handler`. A dictionary of string key to a class type to handle that + # format. New supported formats and handlers should be added to it: + _FORMAT_HANDLERS_MAP: Dict[str, Type[FileHandlerType]] = {} + + @classmethod + def get_all_formats(cls) -> List[str]: + """ + Get all supported formats. + + :return: A list of all the supported formats. 
+ """ + return [ + value + for key, value in cls.__dict__.items() + if isinstance(value, str) and not key.startswith("_") + ] + + @classmethod + def get_format_handler(cls, fmt: str) -> Type[FileHandlerType]: + """ + Get the format handler to the provided format (file extension): + + :param fmt: The file extension to get the corresponding handler. + + :return: The handler class. + """ + return cls._FORMAT_HANDLERS_MAP[fmt] + + @classmethod + def match_format(cls, path: str) -> Union[str, None]: + """ + Try to match one of the available formats this class holds to a given path. + + :param path: The path to match the format to. + + :return: The matched format if found and None otherwise. + """ + formats = cls.get_all_formats() + for fmt in formats: + if path.endswith(f".{fmt}"): + return fmt + return None diff --git a/mlrun/package/utils/log_hint_utils.py b/mlrun/package/utils/log_hint_utils.py new file mode 100644 index 000000000000..03344962985a --- /dev/null +++ b/mlrun/package/utils/log_hint_utils.py @@ -0,0 +1,93 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import typing + +from mlrun.errors import MLRunInvalidArgumentError + + +class LogHintKey: + """ + Known keys for a log hint to have. + """ + + KEY = "key" + ARTIFACT_TYPE = "artifact_type" + EXTRA_DATA = "extra_data" + METRICS = "metrics" + + +class LogHintUtils: + """ + Static class for utilities functions to process log hints. 
+ """ + + @staticmethod + def parse_log_hint( + log_hint: typing.Union[typing.Dict[str, str], str, None] + ) -> typing.Union[typing.Dict[str, str], None]: + """ + Parse a given log hint from string to a logging configuration dictionary. The string will be read as the + artifact key ('key' in the dictionary) and if the string have a single colon, the following structure is + assumed: " : ". + + If a logging configuration dictionary is received, it will be validated to have a key field. + + None will be returned as None. + + :param log_hint: The log hint to parse. + + :return: The hinted logging configuration. + + :raise MLRunInvalidArgumentError: In case the log hint is not following the string structure or the dictionary + is missing the key field. + """ + # Check for None value: + if log_hint is None: + return None + + # If the log hint was provided as a string, construct a dictionary out of it: + if isinstance(log_hint, str): + # Check if only key is given: + if ":" not in log_hint: + log_hint = {LogHintKey.KEY: log_hint} + # Check for valid " : " pattern: + else: + if log_hint.count(":") > 1: + raise MLRunInvalidArgumentError( + f"Incorrect log hint pattern. Log hints can have only a single ':' in them to specify the " + f"desired artifact type the returned value will be logged as: " + f"' : ', but given: {log_hint}" + ) + # Split into key and type: + key, artifact_type = log_hint.replace(" ", "").split(":") + if artifact_type == "": + raise MLRunInvalidArgumentError( + f"Incorrect log hint pattern. 
The ':' in a log hint should specify the desired artifact type " + f"the returned value will be logged as in the following pattern: " + f"' : ', but no artifact type was given: {log_hint}" + ) + log_hint = { + LogHintKey.KEY: key, + LogHintKey.ARTIFACT_TYPE: artifact_type, + } + + # Validate the log hint dictionary has the mandatory key: + if LogHintKey.KEY not in log_hint: + raise MLRunInvalidArgumentError( + f"A log hint dictionary must include the 'key' - the artifact key (it's name). The following log hint " + f"is missing the key: {log_hint}." + ) + + return log_hint diff --git a/mlrun/package/utils/type_hint_utils.py b/mlrun/package/utils/type_hint_utils.py new file mode 100644 index 000000000000..1d517e92483a --- /dev/null +++ b/mlrun/package/utils/type_hint_utils.py @@ -0,0 +1,298 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import builtins +import importlib +import itertools +import re +import sys +import typing + +from mlrun.errors import MLRunInvalidArgumentError +from mlrun.utils import logger + + +class TypeHintUtils: + """ + Static class for utilities functions to process type hints. + """ + + @staticmethod + def is_typing_type(type_hint: type) -> bool: + """ + Check whether a given type is a type hint from one of the modules `typing` and `types`. The function will return + True for generic type aliases also, meaning Python 3.9's new hinting feature that includes hinting like + `list[int]` instead of `typing.List[int]`. 
+ + :param type_hint: The type to check. + + :return: True if the type hint from `typing` / `types` and False otherwise. + """ + # A type hint should be one of the based typing classes, meaning it will have "typing" as its module. Some + # typing classes are considered a type (like `TypeVar`) so we check their type as well. The only case "types" + # will be a module is for generic aliases like `list[int]`. + return (type_hint.__module__ == "typing") or ( + type(type_hint).__module__ in ["typing", "types"] + ) + + @staticmethod + def parse_type_hint(type_hint: typing.Union[type, str]) -> type: + """ + Parse a given type hint from string to its actual hinted type class object. The string must be one of the + following: + + * Python builtin type - for example: `tuple`, `list`, `set`, `dict` and `bytearray`. + * Full module import path. An alias (if `import pandas as pd is used`, the type hint cannot be `pd.DataFrame`) + is not allowed. + + The type class on its own (like `DataFrame`) cannot be used as the scope of this function is not the same as the + handler itself, hence modules and objects that were imported in the handler's scope are not available. This is + the same reason import aliases cannot be used as well. + + If the provided type hint is not a string, it will simply be returned as is. + + :param type_hint: The type hint to parse. + + :return: The hinted type. + + :raise MLRunInvalidArgumentError: In case the type hint is not following the 2 options mentioned above. + """ + if not isinstance(type_hint, str): + return type_hint + + # Validate the type hint is a valid module path: + if not bool( + re.fullmatch( + r"([a-zA-Z_][a-zA-Z0-9_]*\.)*[a-zA-Z_][a-zA-Z0-9_]*", type_hint + ) + ): + raise MLRunInvalidArgumentError( + f"Invalid type hint. An input type hint must be a valid python class name or its module import path. " + f"For example: 'list', 'pandas.DataFrame', 'numpy.ndarray', 'sklearn.linear_model.LinearRegression'. 
" + f"Type hint given: '{type_hint}'." + ) + + # Look for a builtin type (rest of the builtin types like `int`, `str`, `float` should be treated as results, + # hence not given as an input to an MLRun function, but as a parameter): + builtin_types = { + builtin_name: builtin_type + for builtin_name, builtin_type in builtins.__dict__.items() + if isinstance(builtin_type, type) + } + if type_hint in builtin_types: + return builtin_types[type_hint] + + # If it's not a builtin, its should have a full module path, meaning at least one '.' to separate the module and + # the class. If it doesn't, we will try to get the class from the main module: + if "." not in type_hint: + logger.warn( + f"The type hint string given '{type_hint}' is not a `builtins` python type. MLRun will try to look for " + f"it in the `__main__` module instead." + ) + try: + return TypeHintUtils.parse_type_hint(type_hint=f"__main__.{type_hint}") + except MLRunInvalidArgumentError: + raise MLRunInvalidArgumentError( + f"MLRun tried to get the type hint '{type_hint}' but it can't as it is not a valid builtin Python " + f"type (one of `list`, `dict`, `str`, `int`, etc.) nor a locally declared type (from the " + f"`__main__` module). Pay attention using only the type as string is not allowed as the handler's " + f"scope is different than MLRun's. To properly give a type hint as string, please specify the full " + f"module path without aliases. For example: do not use `DataFrame` or `pd.DataFrame`, use " + f"`pandas.DataFrame`." + ) + + # Import the module to receive the hinted type: + try: + # Get the module path and the type class (If we'll wish to support inner classes, the `rsplit` won't work): + module_path, type_hint = type_hint.rsplit(".", 1) + # Replace alias if needed (alias assumed to be imported already, hence we look in globals): + # For example: + # If in handler scope there was `import A.B.C as abc` and user gave a type hint "abc.Something" then: + # `module_path[0]` will be equal to "abc". 
Then, because it is an alias, it will appear in the globals, so + # we'll replace the alias with the full module name in order to import the module. + module_path = module_path.split(".") + if module_path[0] in globals(): + module_path[0] = globals()[module_path[0]].__name__ + module_path = ".".join(module_path) + # Import the module: + module = importlib.import_module(module_path) + # Get the class type from the module: + type_hint = getattr(module, type_hint) + except ModuleNotFoundError as module_not_found_error: + # May be raised from `importlib.import_module` in case the module does not exist. + raise MLRunInvalidArgumentError( + f"MLRun tried to get the type hint '{type_hint}' but the module '{module_path}' cannot be imported. " + f"Keep in mind that using alias in the module path (meaning: import module as alias) is not allowed. " + f"If the module path is correct, please make sure the module package is installed in the python " + f"interpreter." + ) from module_not_found_error + except AttributeError as attribute_error: + # May be raised from `getattr(module, type_hint)` in case the class type cannot be imported directly from + # the imported module. + raise MLRunInvalidArgumentError( + f"MLRun tried to get the type hint '{type_hint}' from the module '{module.__name__}' but it seems it " + f"doesn't exist. Make sure the class can be imported from the module with the exact module path you " + f"passed. Notice inner classes (a class inside of a class) are not supported." + ) from attribute_error + + return type_hint + + @staticmethod + def is_matching( + object_type: type, + type_hint: typing.Union[type, typing.Set[type]], + include_subclasses: bool = True, + reduce_type_hint: bool = True, + ) -> bool: + """ + Check if the given object type match the given hint. + + :param object_type: The object type to match with the type hint. + :param type_hint: The hint to match with. Can be given as a set resulted from a reduced hint. 
+ :param include_subclasses: Whether to mark a subclass as valid match. Default to True. + :param reduce_type_hint: Whether to reduce the type hint to match with its reduced hints. + + :return: True if the object type match the type hint and False otherwise. + """ + # Wrap in a set if provided a single type hint: + type_hint = {type_hint} if not isinstance(type_hint, set) else type_hint + + # Try to match the object type to one of the hints: + while len(type_hint) > 0: + for hint in type_hint: + # Subclass check can be made only on actual object types (not typing module types): + if ( + not TypeHintUtils.is_typing_type(type_hint=object_type) + and not TypeHintUtils.is_typing_type(type_hint=hint) + and include_subclasses + and issubclass(object_type, hint) + ): + return True + if object_type == hint: + return True + # See if needed to reduce, if not end on first iteration: + if not reduce_type_hint: + break + type_hint = TypeHintUtils.reduce_type_hint(type_hint=type_hint) + return False + + @staticmethod + def reduce_type_hint( + type_hint: typing.Union[type, typing.Set[type]], + ) -> typing.Set[type]: + """ + Reduce a type hint (or a set of type hints) using the `_reduce_type_hint` function. + + :param type_hint: The type hint to reduce. + + :return: The reduced type hints set or an empty set if the type hint could not be reduced. + """ + # Wrap in a set if provided a single type hint: + type_hints = {type_hint} if not isinstance(type_hint, set) else type_hint + + # Iterate over the type hints and reduce each one: + return set( + itertools.chain( + *[ + TypeHintUtils._reduce_type_hint(type_hint=type_hint) + for type_hint in type_hints + ] + ) + ) + + @staticmethod + def _reduce_type_hint(type_hint: type) -> typing.List[type]: + """ + Reduce a type hint. If the type hint is a `typing` module, it will be reduced to its original hinted types. 
For + example: `typing.Union[int, float, typing.List[int]]` will return `[int, float, List[int]]` and + `typing.List[int]` will return `[list]`. Regular type hints - Python object types cannot be reduced as they are + already a core type. + + If a type hint cannot be reduced, an empty list will be returned. + + :param type_hint: The type hint to reduce. + + :return: The reduced type hint as list of hinted types or an empty list if the type hint could not be reduced. + """ + # TODO: Remove when we'll no longer support Python 3.7: + if sys.version_info[1] < 8: + return [] + + # If it's not a typing type (meaning it's an actual object type) then we can't reduce it further: + if not TypeHintUtils.is_typing_type(type_hint=type_hint): + return [] + + # If it's a type var, take its constraints (e.g. A = TypeVar("A", int, str) meaning an object of type A should + # be an integer or a string). If it doesn't have constraints, return an empty list: + if isinstance(type_hint, typing.TypeVar): + if len(type_hint.__constraints__) == 0: + return [] + return list(type_hint.__constraints__) + + # If it's a forward reference, we will try to import the reference: + if isinstance(type_hint, typing.ForwardRef): + try: + # ForwardRef is initialized with the string type it represents and optionally a module path, so we + # construct a full module path and try to parse it: + arg = type_hint.__forward_arg__ + if type_hint.__forward_module__: + arg = f"{type_hint.__forward_module__}.{arg}" + return [TypeHintUtils.parse_type_hint(type_hint=arg)] + except MLRunInvalidArgumentError: # May be raised from `TypeHintUtils.parse_type_hint` + logger.warn( + f"Could not reduce the type hint '{type_hint}' as it is a forward reference to a class without " + f"it's full module path. To enable importing forward references, please provide the full module " + f"path to them. For example: use `ForwardRef('pandas.DataFrame')` instead of " + f"`ForwardRef('DataFrame')`." 
+ ) + return [] + + # Get the origin of the typing type. An origin is the subscripted typing type (origin of Union[str, int] is + # Union). The origin can be one of Callable, Tuple, Union, Literal, Final, ClassVar, Annotated or the actual + # type alias (e.g. origin of List[int] is list): + origin = typing.get_origin(type_hint) + + # If the typing type has no origin (e.g. None is returned), we cannot reduce it, so we return an empty list: + if origin is None: + return [] + + # If the origin is a type of one of `builtins`, `contextlib` or `collections` (for example: List's origin is + # list) then we can be sure there is nothing to reduce as it's a regular type: + if not TypeHintUtils.is_typing_type(type_hint=origin): + return [origin] + + # Get the type's subscriptions - arguments, in order to reduce it to them (we know for sure there are arguments, + # otherwise origin would have been None): + args = typing.get_args(type_hint) + + # Return the reduced type as its arguments according to the origin: + if origin is typing.Callable: + # A callable cannot be reduced to its arguments, so we'll return the origin - Callable: + return [typing.Callable] + if origin is typing.Literal: + # Literal arguments are not types, but values. 
So we'll take the types of the values as the reduced type: + return [type(arg) for arg in args] + if origin is typing.Union: + # A union is reduced to its arguments: + return list(args) + if origin is typing.Annotated: + # Annotated is used to describe (add metadata to) a type, so we take the first argument (the type the + # metadata is being added to): + return [args[0]] + if origin is typing.Final or origin is typing.ClassVar: + # Both Final and ClassVar takes only one argument - the type: + return [args[0]] + + # For Generic types we return an empty list: + return [] diff --git a/mlrun/platforms/iguazio.py b/mlrun/platforms/iguazio.py index c16d2698e96f..062122f9afa7 100644 --- a/mlrun/platforms/iguazio.py +++ b/mlrun/platforms/iguazio.py @@ -326,7 +326,6 @@ def v3io_to_vol(name, remote="~/", access_key="", user="", secret=None): if secret: secret = {"name": secret} - # vol = client.V1Volume(name=name, flex_volume=client.V1FlexVolumeSource('v3io/fuse', options=opts)) vol = { "flexVolume": client.V1FlexVolumeSource( "v3io/fuse", options=opts, secret_ref=secret @@ -403,6 +402,37 @@ def dump_record(rec): ) +class HTTPOutputStream: + """HTTP output source that usually used for CE mode and debugging process""" + + def __init__(self, stream_path: str): + self._stream_path = stream_path + + def push(self, data): + def dump_record(rec): + if isinstance(rec, bytes): + return rec + + if not isinstance(rec, str): + rec = dict_to_json(rec) + + return rec.encode("UTF-8") + + if not isinstance(data, list): + data = [data] + + for record in data: + + # Convert the new record to the required format + serialized_record = dump_record(record) + response = requests.post(self._stream_path, data=serialized_record) + if not response: + raise mlrun.errors.MLRunInvalidArgumentError( + f"API call failed push a new record through {self._stream_path}" + f"status {response.status_code}: {response.reason}" + ) + + class KafkaOutputStream: def __init__( self, @@ -650,6 +680,8 @@ def 
parse_path(url, suffix="/"): ) endpoint = f"{prefix}://{parsed_url.netloc}" else: + # no netloc is mainly when using v3io (v3io:///) and expecting the url to be resolved automatically from env or + # config endpoint = None return endpoint, parsed_url.path.strip("/") + suffix diff --git a/mlrun/projects/operations.py b/mlrun/projects/operations.py index e77d2f11571f..cebfc32bbbc7 100644 --- a/mlrun/projects/operations.py +++ b/mlrun/projects/operations.py @@ -70,7 +70,7 @@ def run_function( selector: str = None, project_object=None, auto_build: bool = None, - schedule: Union[str, mlrun.api.schemas.ScheduleCronTrigger] = None, + schedule: Union[str, mlrun.common.schemas.ScheduleCronTrigger] = None, artifact_path: str = None, notifications: List[mlrun.model.Notification] = None, returns: Optional[List[Union[str, Dict[str, str]]]] = None, @@ -92,16 +92,16 @@ def run_function( LABELS = "is_error" MODEL_CLASS = "sklearn.ensemble.RandomForestClassifier" DATA_PATH = "s3://bigdata/data.parquet" - function = mlrun.import_function("hub://auto_trainer") + function = mlrun.import_function("hub://auto-trainer") run1 = run_function(function, params={"label_columns": LABELS, "model_class": MODEL_CLASS}, inputs={"dataset": DATA_PATH}) example (use with project):: - # create a project with two functions (local and from marketplace) + # create a project with two functions (local and from hub) project = mlrun.new_project(project_name, "./proj) project.set_function("mycode.py", "myfunc", image="mlrun/mlrun") - project.set_function("hub://auto_trainer", "train") + project.set_function("hub://auto-trainer", "train") # run functions (refer to them by name) run1 = run_function("myfunc", params={"x": 7}) @@ -112,7 +112,7 @@ def run_function( @dsl.pipeline(name="test pipeline", description="test") def my_pipe(url=""): - run1 = run_function("loaddata", params={"url": url}) + run1 = run_function("loaddata", params={"url": url}, outputs=["data"]) run2 = run_function("train", 
params={"label_columns": LABELS, "model_class": MODEL_CLASS}, inputs={"dataset": run1.outputs["data"]}) @@ -138,7 +138,7 @@ def my_pipe(url=""): :param verbose: add verbose prints/logs :param project_object: override the project object to use, will default to the project set in the runtime context. :param auto_build: when set to True and the function require build it will be built on the first - function run, use only if you dont plan on changing the build config between runs + function run, use only if you do not plan on changing the build config between runs :param schedule: ScheduleCronTrigger class instance or a standard crontab expression string (which will be converted to the class using its `from_crontab` constructor), see this link for help: @@ -236,6 +236,7 @@ def build_function( commands: list = None, secret_name=None, requirements: Union[str, List[str]] = None, + requirements_file: str = None, mlrun_version_specifier=None, builder_env: dict = None, project_object=None, @@ -250,7 +251,8 @@ def build_function( :param base_image: base image name/path (commands and source code will be added to it) :param commands: list of docker build (RUN) commands e.g. ['pip install pandas'] :param secret_name: k8s secret for accessing the docker registry - :param requirements: list of python packages or pip requirements file path, defaults to None + :param requirements: list of python packages, defaults to None + :param requirements_file: pip requirements file path, defaults to None :param mlrun_version_specifier: which mlrun package version to include (if not current) :param builder_env: Kaniko builder pod env vars dict (for config/credentials) e.g. 
builder_env={"GIT_TOKEN": token}, does not work yet in KFP @@ -269,7 +271,7 @@ def build_function( if overwrite_build_params: function.spec.build.commands = None if requirements: - function.with_requirements(requirements) + function.with_requirements(requirements, requirements_file) if commands: function.with_commands(commands) return function.deploy_step( diff --git a/mlrun/projects/pipelines.py b/mlrun/projects/pipelines.py index 8d426bcbdf2e..22f19106e114 100644 --- a/mlrun/projects/pipelines.py +++ b/mlrun/projects/pipelines.py @@ -27,7 +27,7 @@ from kfp.compiler import compiler import mlrun -import mlrun.api.schemas +import mlrun.common.schemas import mlrun.utils.notifications from mlrun.errors import err_to_str from mlrun.utils import ( @@ -79,7 +79,7 @@ def __init__( # TODO: deprecated, remove in 1.5.0 ttl=None, args_schema: dict = None, - schedule: typing.Union[str, mlrun.api.schemas.ScheduleCronTrigger] = None, + schedule: typing.Union[str, mlrun.common.schemas.ScheduleCronTrigger] = None, cleanup_ttl: int = None, ): if ttl: @@ -116,7 +116,13 @@ def get_source_file(self, context=""): self._tmp_path = workflow_path = workflow_fh.name else: workflow_path = self.path or "" - if context and not workflow_path.startswith("/"): + if ( + context + and not workflow_path.startswith("/") + # since the user may provide a path the includes the context, + # we need to make sure we don't add it twice + and not workflow_path.startswith(context) + ): workflow_path = os.path.join(context, workflow_path) return workflow_path @@ -279,7 +285,7 @@ def _enrich_kfp_pod_security_context(kfp_pod_template, function): if ( mlrun.runtimes.RuntimeKinds.is_local_runtime(function.kind) or mlrun.mlconf.function.spec.security_context.enrichment_mode - == mlrun.api.schemas.SecurityContextEnrichmentModes.disabled.value + == mlrun.common.schemas.SecurityContextEnrichmentModes.disabled.value ): return @@ -405,7 +411,7 @@ def enrich_function_object( f.spec.build.source = project.spec.source 
f.spec.build.load_source_on_run = project.spec.load_source_on_run f.spec.workdir = project.spec.workdir or project.spec.subpath - f.verify_base_image() + f.prepare_image_for_deploy() if project.spec.default_requirements: f.with_requirements(project.spec.default_requirements) @@ -705,7 +711,7 @@ def run( trace = traceback.format_exc() logger.error(trace) project.notifiers.push( - f"Workflow {workflow_id} run failed!, error: {e}\n{trace}", "error" + f":x: Workflow {workflow_id} run failed!, error: {e}\n{trace}", "error" ) state = mlrun.run.RunStatuses.failed mlrun.run.wait_for_runs_completion(pipeline_context.runs_map.values()) @@ -755,6 +761,7 @@ def _prepare_load_and_run_function( artifact_path: str, workflow_handler: str, namespace: str, + subpath: str, ) -> typing.Tuple[mlrun.runtimes.RemoteRuntime, "mlrun.RunObject"]: """ Helper function for creating the runspec of the load and run function. @@ -767,6 +774,7 @@ def _prepare_load_and_run_function( :param artifact_path: path to store artifacts :param workflow_handler: workflow function handler (for running workflow function directly) :param namespace: kubernetes namespace if other than default + :param subpath: project subpath (within the archive) :return: """ # Creating the load project and workflow running function: @@ -792,6 +800,7 @@ def _prepare_load_and_run_function( "engine": workflow_spec.engine, "local": workflow_spec.run_local, "schedule": workflow_spec.schedule, + "subpath": subpath, }, handler="mlrun.projects.load_and_run", ), @@ -826,8 +835,11 @@ def run( ) if "://" not in current_source: raise mlrun.errors.MLRunInvalidArgumentError( - f"remote workflows can only be performed by a project with remote source," - f" the given source '{current_source}' is not remote" + f"Remote workflows can only be performed by a project with remote source (e.g git:// or http://)," + f" but the specified source '{current_source}' is not remote. 
" + f"Either put your code in Git, or archive it and then set a source to it." + f" For more details, read" + f" https://docs.mlrun.org/en/latest/concepts/scheduled-jobs.html#scheduling-a-workflow" ) # Creating the load project and workflow running function: @@ -840,6 +852,7 @@ def run( artifact_path=artifact_path, workflow_handler=workflow_handler, namespace=namespace, + subpath=project.spec.subpath, ) # The returned engine for this runner is the engine of the workflow. @@ -874,7 +887,8 @@ def run( trace = traceback.format_exc() logger.error(trace) project.notifiers.push( - f"Workflow {workflow_name} run failed!, error: {e}\n{trace}", "error" + f":x: Workflow {workflow_name} run failed!, error: {e}\n{trace}", + "error", ) state = mlrun.run.RunStatuses.failed return _PipelineRunStatus( @@ -928,7 +942,11 @@ def create_pipeline(project, pipeline, functions, secrets=None, handler=None): if not handler and hasattr(mod, "pipeline"): handler = "pipeline" if not handler or not hasattr(mod, handler): - raise ValueError(f"pipeline function ({handler or 'pipeline'}) not found") + raise ValueError( + f"'workflow_handler' is not defined. " + f"Either provide it as set_workflow argument, or include a function named" + f" '{handler or 'pipeline'}' in your workflow .py file." 
+ ) return getattr(mod, handler) @@ -967,7 +985,7 @@ def load_and_run( ttl: int = None, engine: str = None, local: bool = None, - schedule: typing.Union[str, mlrun.api.schemas.ScheduleCronTrigger] = None, + schedule: typing.Union[str, mlrun.common.schemas.ScheduleCronTrigger] = None, cleanup_ttl: int = None, ): """ @@ -1033,7 +1051,7 @@ def load_and_run( try: notification_pusher.push( message=message, - severity=mlrun.api.schemas.NotificationSeverity.ERROR, + severity=mlrun.common.schemas.NotificationSeverity.ERROR, ) except Exception as exc: diff --git a/mlrun/projects/project.py b/mlrun/projects/project.py index fac672ea52a7..f8851bdd1f03 100644 --- a/mlrun/projects/project.py +++ b/mlrun/projects/project.py @@ -14,7 +14,10 @@ import datetime import getpass import glob +import http +import importlib.util as imputil import json +import os.path import pathlib import shutil import tempfile @@ -31,21 +34,25 @@ import inflection import kfp import nuclio +import requests import yaml +from deprecated import deprecated -import mlrun.api.schemas +import mlrun.common.model_monitoring as model_monitoring_constants +import mlrun.common.schemas import mlrun.db import mlrun.errors +import mlrun.runtimes +import mlrun.runtimes.pod +import mlrun.runtimes.utils import mlrun.utils.regex -from mlrun.runtimes import RuntimeKinds from ..artifacts import Artifact, ArtifactProducer, DatasetArtifact, ModelArtifact from ..artifacts.manager import ArtifactManager, dict_to_artifact, extend_artifact_path from ..datastore import store_manager from ..features import Feature -from ..model import EntrypointParam, ModelObj +from ..model import EntrypointParam, ImageBuilder, ModelObj from ..run import code_to_function, get_object, import_function, new_function -from ..runtimes.utils import add_code_metadata from ..secrets import SecretsStore from ..utils import ( is_ipython, @@ -57,7 +64,6 @@ ) from ..utils.clones import clone_git, clone_tgz, clone_zip, get_repo_url from ..utils.helpers import 
ensure_git_branch, resolve_git_reference_from_source -from ..utils.model_monitoring import set_project_model_monitoring_credentials from ..utils.notifications import CustomNotificationPusher, NotificationTypes from .operations import ( BuildStatus, @@ -114,10 +120,10 @@ def new_project( example:: - # create a project with local and marketplace functions, a workflow, and an artifact + # create a project with local and hub functions, a workflow, and an artifact project = mlrun.new_project("myproj", "./", init_git=True, description="my new project") project.set_function('prep_data.py', 'prep-data', image='mlrun/mlrun', handler='prep_data') - project.set_function('hub://auto_trainer', 'train') + project.set_function('hub://auto-trainer', 'train') project.set_artifact('data', Artifact(target_path=data_url)) project.set_workflow('main', "./myflow.py") project.save() @@ -195,7 +201,7 @@ def new_project( if overwrite: logger.info(f"Deleting project {name} from MLRun DB due to overwrite") _delete_project_from_db( - name, secrets, mlrun.api.schemas.DeletionStrategy.cascade + name, secrets, mlrun.common.schemas.DeletionStrategy.cascade ) try: @@ -277,13 +283,22 @@ def load_project( clone_tgz(url, context, secrets, clone) elif url.endswith(".zip"): clone_zip(url, context, secrets, clone) - else: + elif url.startswith("db://") or "://" not in url: project = _load_project_from_db(url, secrets, user_project) project.spec.context = context if not path.isdir(context): makedirs(context) project.spec.subpath = subpath or project.spec.subpath + setup_file_path = path.join( + context, project.spec.subpath or "", "project_setup.py" + ) + project = _run_project_setup(project, setup_file_path) from_db = True + else: + raise mlrun.errors.MLRunInvalidArgumentError( + "Unsupported url scheme, supported schemes are: git://, db:// or " + ".zip/.tar.gz/.yaml file path (could be local or remote) or project name which will be loaded from DB" + ) if not repo: repo, url = init_repo(context, url, 
init_git) @@ -372,7 +387,7 @@ def get_or_create_project( # only loading project from db so no need to save it save=False, ) - logger.info(f"loaded project {name} from MLRun DB") + logger.info(f"Loaded project {name} from MLRun DB") return project except mlrun.errors.MLRunNotFoundError: @@ -390,7 +405,7 @@ def get_or_create_project( user_project=user_project, save=save, ) - message = f"loaded project {name} from {url or context}" + message = f"Loaded project {name} from {url or context}" if save: message = f"{message} and saved in MLRun DB" logger.info(message) @@ -406,16 +421,34 @@ def get_or_create_project( subpath=subpath, save=save, ) - message = f"created project {name}" + message = f"Created project {name}" if save: message = f"{message} and saved in MLRun DB" logger.info(message) return project +def _run_project_setup(project: "MlrunProject", setup_file_path: str): + """Run the project setup file if found""" + if not path.exists(setup_file_path): + return project + spec = imputil.spec_from_file_location("workflow", setup_file_path) + if spec is None: + raise ImportError(f"cannot import project setup file in {setup_file_path}") + mod = imputil.module_from_spec(spec) + spec.loader.exec_module(mod) + + if hasattr(mod, "setup"): + project = getattr(mod, "setup")(project) + else: + logger.warn("skipping setup, setup() handler was not found in project_setup.py") + return project + + def _load_project_dir(context, name="", subpath=""): subpath_str = subpath or "" fpath = path.join(context, subpath_str, "project.yaml") + setup_file_path = path.join(context, subpath_str, "project_setup.py") if path.isfile(fpath): with open(fpath) as fp: data = fp.read() @@ -435,6 +468,8 @@ def _load_project_dir(context, name="", subpath=""): }, } ) + elif path.exists(setup_file_path): + project = MlrunProject() else: raise mlrun.errors.MLRunNotFoundError( "project or function YAML not found in path" @@ -443,7 +478,7 @@ def _load_project_dir(context, name="", subpath=""): 
project.spec.context = context project.metadata.name = name or project.metadata.name project.spec.subpath = subpath - return project + return _run_project_setup(project, setup_file_path) def _add_username_to_project_name_if_needed(name, user_project): @@ -538,11 +573,13 @@ def __init__( goals=None, load_source_on_run=None, default_requirements: typing.Union[str, typing.List[str]] = None, - desired_state=mlrun.api.schemas.ProjectState.online.value, + desired_state=mlrun.common.schemas.ProjectState.online.value, owner=None, disable_auto_mount=None, workdir=None, default_image=None, + build=None, + custom_packagers: typing.List[typing.Tuple[str, bool]] = None, ): self.repo = None @@ -576,6 +613,13 @@ def __init__( self.disable_auto_mount = disable_auto_mount self.default_image = default_image + self.build = build + + # A list of custom packagers to include when running the functions of the project. A custom packager is stored + # in a tuple where the first index is the packager module's path (str) and the second is a flag (bool) for + # whether it is mandatory for a run (raise exception on collection error) or not. + self.custom_packagers = custom_packagers or [] + @property def source(self) -> str: """source url or git repo""" @@ -585,8 +629,6 @@ def source(self) -> str: if url: self._source = url - if self._source in [".", "./"]: - return path.abspath(self.context) return self._source @source.setter @@ -742,6 +784,54 @@ def remove_artifact(self, key): if key in self._artifacts: del self._artifacts[key] + @property + def build(self) -> ImageBuilder: + return self._build + + @build.setter + def build(self, build): + self._build = self._verify_dict(build, "build", ImageBuilder) + + def add_custom_packager(self, packager: str, is_mandatory: bool): + """ + Add a custom packager from the custom packagers list. + + :param packager: The packager module path to add. 
For example, if a packager `MyPackager` is in the + project's source at my_module.py, then the module path is: "my_module.MyPackager". + :param is_mandatory: Whether this packager must be collected during a run. If False, failing to collect it won't + raise an error during the packagers collection phase. + """ + # TODO: enable importing packagers from the hub. + if packager in [ + custom_packager[0] for custom_packager in self.custom_packagers + ]: + logger.warn( + f"The packager's module path '{packager}' is already registered in the project." + ) + return + self.custom_packagers.append((packager, is_mandatory)) + + def remove_custom_packager(self, packager: str): + """ + Remove a custom packager from the custom packagers list. + + :param packager: The packager module path to remove. + + :raise MLRunInvalidArgumentError: In case the packager was not in the list. + """ + # Look for the packager tuple in the list to remove it: + packager_tuple: typing.Tuple[str, bool] = None + for custom_packager in self.custom_packagers: + if custom_packager[0] == packager: + packager_tuple = custom_packager + + # If not found, raise an error, otherwise remove: + if packager_tuple is None: + raise mlrun.errors.MLRunInvalidArgumentError( + f"The packager module path '{packager}' is not registered in the project, hence it cannot be removed." + ) + self.custom_packagers.remove(packager_tuple) + def _source_repo(self): src = self.source if src: @@ -899,20 +989,28 @@ def source(self) -> str: def source(self, source): self.spec.source = source - def set_source(self, source, pull_at_runtime=False, workdir=None): + def set_source( + self, + source: str = "", + pull_at_runtime: bool = False, + workdir: Optional[str] = None, + ): """set the project source code path(can be git/tar/zip archive) - :param source: valid path to git, zip, or tar file, (or None for current) e.g. 
- git://github.com/mlrun/something.git - http://some/url/file.zip + :param source: valid absolute path or URL to git, zip, or tar file, (or None for current) e.g. + git://github.com/mlrun/something.git + http://some/url/file.zip + note path source must exist on the image or exist locally when run is local + (it is recommended to use 'workdir' when source is a filepath instead) :param pull_at_runtime: load the archive into the container at job runtime vs on build/deploy - :param workdir: the relative workdir path (under the context dir) + :param workdir: workdir path relative to the context dir or absolute """ + mlrun.utils.helpers.validate_builder_source(source, pull_at_runtime, workdir) + self.spec.load_source_on_run = pull_at_runtime self.spec.source = source or self.spec.source if self.spec.source.startswith("git://"): - source, reference, branch = resolve_git_reference_from_source(source) if not branch and not reference: logger.warn( @@ -925,20 +1023,23 @@ def set_source(self, source, pull_at_runtime=False, workdir=None): self.sync_functions() def get_artifact_uri( - self, key: str, category: str = "artifact", tag: str = None + self, key: str, category: str = "artifact", tag: str = None, iter: int = None ) -> str: """return the project artifact uri (store://..) from the artifact key example:: - uri = project.get_artifact_uri("my_model", category="model", tag="prod") + uri = project.get_artifact_uri("my_model", category="model", tag="prod", iter=0) :param key: artifact key/name :param category: artifact category (artifact, model, feature-vector, ..) 
:param tag: artifact version tag, default to latest version + :param iter: iteration number, default to no iteration """ uri = f"store://{category}s/{self.metadata.name}/{key}" - if tag: + if iter is not None: + uri = f"{uri}#{iter}" + if tag is not None: uri = f"{uri}:{tag}" return uri @@ -1022,7 +1123,7 @@ def set_workflow( engine=None, args_schema: typing.List[EntrypointParam] = None, handler=None, - schedule: typing.Union[str, mlrun.api.schemas.ScheduleCronTrigger] = None, + schedule: typing.Union[str, mlrun.common.schemas.ScheduleCronTrigger] = None, ttl=None, **args, ): @@ -1044,8 +1145,14 @@ def set_workflow( if not workflow_path: raise ValueError("valid workflow_path must be specified") if embed: - if self.spec.context and not workflow_path.startswith("/"): - workflow_path = path.join(self.spec.context, workflow_path) + if ( + self.context + and not workflow_path.startswith("/") + # since the user may provide a path the includes the context, + # we need to make sure we don't add it twice + and not workflow_path.startswith(self.context) + ): + workflow_path = path.join(self.context, workflow_path) with open(workflow_path, "r") as fp: txt = fp.read() workflow = {"name": name, "code": txt} @@ -1469,12 +1576,15 @@ def get_artifact(spec): with open(f"{temp_dir}/_body", "rb") as fp: artifact.spec._body = fp.read() artifact.target_path = "" + + # if the dataitem is not a file, it means we downloaded it from a remote source to a temp file, + # so we need to remove it after we're done with it + dataitem.remove_local() + return self.log_artifact( artifact, local_path=temp_dir, artifact_path=artifact_path ) - if dataitem.kind != "file": - remove(item_file) else: raise ValueError("unsupported file suffix, use .yaml, .json, or .zip") @@ -1514,6 +1624,7 @@ def set_function( with_repo: bool = None, tag: str = None, requirements: typing.Union[str, typing.List[str]] = None, + requirements_file: str = "", ) -> mlrun.runtimes.BaseRuntime: """update or add a function object to 
the project @@ -1522,7 +1633,7 @@ def set_function( object (s3://, v3io://, ..) MLRun DB e.g. db://project/func:ver - functions hub/market: e.g. hub://auto_trainer:master + functions hub/market: e.g. hub://auto-trainer:master examples:: @@ -1541,16 +1652,20 @@ def set_function( # by providing a path to a pip requirements file proj.set_function('my.py', requirements="requirements.txt") - :param func: function object or spec/code url, None refers to current Notebook - :param name: name of the function (under the project) - :param kind: runtime kind e.g. job, nuclio, spark, dask, mpijob - default: job - :param image: docker image to be used, can also be specified in - the function object/yaml - :param handler: default function handler to invoke (can only be set with .py/.ipynb files) - :param with_repo: add (clone) the current repo to the build source - :param tag: function version tag (none for 'latest', can only be set with .py/.ipynb files) - :param requirements: list of python packages or pip requirements file path + :param func: function object or spec/code url, None refers to current Notebook + :param name: name of the function (under the project), can be specified with a tag to support + versions (e.g. myfunc:v1) + :param kind: runtime kind e.g. job, nuclio, spark, dask, mpijob + default: job + :param image: docker image to be used, can also be specified in + the function object/yaml + :param handler: default function handler to invoke (can only be set with .py/.ipynb files) + :param with_repo: add (clone) the current repo to the build source + :param tag: function version tag (none for 'latest', can only be set with .py/.ipynb files) + if tag is specified and name is empty, the function key (under the project) + will be enriched with the tag value. (i.e. 
'function-name:tag') + :param requirements: a list of python packages + :param requirements_file: path to a python requirements file :returns: project object """ @@ -1585,10 +1700,14 @@ def set_function( "requirements": requirements, } func = {k: v for k, v in function_dict.items() if v} - name, function_object = _init_function_from_dict(func, self) - func["name"] = name + resolved_function_name, function_object = _init_function_from_dict( + func, self + ) + func["name"] = resolved_function_name elif hasattr(func, "to_dict"): - name, function_object = _init_function_from_obj(func, self, name=name) + resolved_function_name, function_object = _init_function_from_obj( + func, self, name=name + ) if handler: raise ValueError( "default handler cannot be set for existing function object" @@ -1596,15 +1715,23 @@ def set_function( if image: function_object.spec.image = image if with_repo: + # mark source to be enriched before run with project source (enrich_function_object) function_object.spec.build.source = "./" if requirements: - function_object.with_requirements(requirements) - if not name: + function_object.with_requirements( + requirements, requirements_file=requirements_file + ) + if not resolved_function_name: raise ValueError("function name must be specified") else: raise ValueError("func must be a function url or object") - self.spec.set_function(name, function_object, func) + # if function name was not explicitly provided, + # we use the resolved name (from the function object) and add the tag + if tag and not name and ":" not in resolved_function_name: + resolved_function_name = f"{resolved_function_name}:{tag}" + + self.spec.set_function(resolved_function_name, function_object, func) return function_object def remove_function(self, name): @@ -1621,33 +1748,65 @@ def get_function( enrich=False, ignore_cache=False, copy_function=True, + tag: str = "", ) -> mlrun.runtimes.BaseRuntime: """get function object by name - :param key: name of key for search - :param 
sync: will reload/reinit the function from the project spec - :param enrich: add project info/config/source info to the function object - :param ignore_cache: read the function object from the DB (ignore the local cache) - :param copy_function: return a copy of the function object + :param key: name of key for search + :param sync: will reload/reinit the function from the project spec + :param enrich: add project info/config/source info to the function object + :param ignore_cache: read the function object from the DB (ignore the local cache) + :param copy_function: return a copy of the function object + :param tag: provide if the function key is tagged under the project (function was set with a tag) :returns: function object """ - if key in self.spec._function_objects and not sync and not ignore_cache: - function = self.spec._function_objects[key] - elif key in self.spec._function_definitions and not ignore_cache: - self.sync_functions([key]) - function = self.spec._function_objects[key] - else: - function = get_db_function(self, key) - self.spec._function_objects[key] = function + if tag and ":" not in key: + key = f"{key}:{tag}" + + function, err = self._get_function( + mlrun.utils.normalize_name(key), sync, ignore_cache + ) + if not function and "_" in key: + function, err = self._get_function(key, sync, ignore_cache) + + if not function: + raise err + if enrich: function = enrich_function_object( self, function, copy_function=copy_function ) self.spec._function_objects[key] = function + return function - def get_function_objects(self) -> typing.Dict[str, mlrun.runtimes.BaseRuntime]: + def _get_function(self, key, sync, ignore_cache): + """ + Function can be retrieved from the project spec (cache) or from the database. + In sync mode, we first perform a sync of the function_objects from the function_definitions, + and then returning it from the function_objects (if exists). + When not in sync mode, we verify and return from the function objects directly. 
+ In ignore_cache mode, we query the function from the database rather than from the project spec. + """ + if key in self.spec._function_objects and not sync and not ignore_cache: + function = self.spec._function_objects[key] + + elif key in self.spec._function_definitions and not ignore_cache: + self.sync_functions([key]) + function = self.spec._function_objects[key] + else: + try: + function = get_db_function(self, key) + self.spec._function_objects[key] = function + except requests.HTTPError as exc: + if exc.response.status_code != http.HTTPStatus.NOT_FOUND.value: + raise exc + return None, exc + + return function, None + + def get_function_objects(self) -> FunctionsDict: """ "get a virtual dict with all the project functions ready for use in a pipeline""" self.sync_functions() return FunctionsDict(self) @@ -1746,7 +1905,7 @@ def sync_functions(self, names: list = None, always=True, save=False): if not names: names = self.spec._function_definitions.keys() funcs = {} - origin = add_code_metadata(self.spec.context) + origin = mlrun.runtimes.utils.add_code_metadata(self.spec.context) for name in names: f = self.spec._function_definitions.get(name) if not f: @@ -1813,7 +1972,7 @@ def set_secrets( self, secrets: dict = None, file_path: str = None, - provider: typing.Union[str, mlrun.api.schemas.SecretProviderName] = None, + provider: typing.Union[str, mlrun.common.schemas.SecretProviderName] = None, ): """set project secrets from dict or secrets env file when using a secrets file it should have lines in the form KEY=VALUE, comment line start with "#" @@ -1840,18 +1999,21 @@ def set_secrets( "must specify secrets OR file_path" ) if file_path: - secrets = dotenv.dotenv_values(file_path) - if None in secrets.values(): - raise mlrun.errors.MLRunInvalidArgumentError( - "env file lines must be in the form key=value" - ) + if path.isfile(file_path): + secrets = dotenv.dotenv_values(file_path) + if None in secrets.values(): + raise mlrun.errors.MLRunInvalidArgumentError( + 
"env file lines must be in the form key=value" + ) + else: + raise mlrun.errors.MLRunNotFoundError(f"{file_path} does not exist") # drop V3IO paths/credentials and MLrun service API address env_vars = { key: val for key, val in secrets.items() if key != "MLRUN_DBPATH" and not key.startswith("V3IO_") } - provider = provider or mlrun.api.schemas.SecretProviderName.kubernetes + provider = provider or mlrun.common.schemas.SecretProviderName.kubernetes mlrun.db.get_run_db().create_project_secrets( self.metadata.name, provider=provider, secrets=env_vars ) @@ -1896,7 +2058,9 @@ def run( ttl: int = None, engine: str = None, local: bool = None, - schedule: typing.Union[str, mlrun.api.schemas.ScheduleCronTrigger, bool] = None, + schedule: typing.Union[ + str, mlrun.common.schemas.ScheduleCronTrigger, bool + ] = None, timeout: int = None, overwrite: bool = False, source: str = None, @@ -1972,7 +2136,10 @@ def run( self.sync_functions(always=sync) if not self.spec._function_objects: - raise ValueError("no functions in the project") + raise ValueError( + "There are no functions in the project." + " Make sure you've set your functions with project.set_function()." 
+ ) if not name and not workflow_path and not workflow_handler: if self.spec.workflows: @@ -1985,9 +2152,9 @@ def run( else: workflow_spec = self.spec._workflows[name].copy() workflow_spec.merge_args(arguments) - workflow_spec.cleanup_ttl = ( - cleanup_ttl or ttl or workflow_spec.cleanup_ttl or workflow_spec.ttl - ) + workflow_spec.cleanup_ttl = ( + cleanup_ttl or ttl or workflow_spec.cleanup_ttl or workflow_spec.ttl + ) workflow_spec.run_local = local name = f"{self.metadata.name}-{name}" if name else self.metadata.name @@ -2073,14 +2240,43 @@ def get_run_status( notifiers=notifiers, ) + # TODO: remove in 1.6.0 + @deprecated( + version="1.4.0", + reason="'clear_context' will be removed in 1.6.0, this can cause unexpected issues", + category=FutureWarning, + ) def clear_context(self): """delete all files and clear the context dir""" - if ( - self.spec.context - and path.exists(self.spec.context) - and path.isdir(self.spec.context) - ): - shutil.rmtree(self.spec.context) + warnings.warn( + "This method deletes all files and clears the context directory or subpath (if defined)!" + " Please keep in mind that this method can produce unexpected outcomes and is not recommended," + " it will be deprecated in 1.6.0." + ) + # clear only if the context path exists and not relative + if self.spec.context and os.path.isabs(self.spec.context): + + # if a subpath is defined, will empty the subdir instead of the entire context + if self.spec.subpath: + path_to_clear = path.join(self.spec.context, self.spec.subpath) + logger.info(f"Subpath is defined, Clearing path: {path_to_clear}") + else: + path_to_clear = self.spec.context + logger.info( + f"Subpath is not defined, Clearing context: {path_to_clear}" + ) + if path.exists(path_to_clear) and path.isdir(path_to_clear): + shutil.rmtree(path_to_clear) + else: + logger.warn( + f"Attempt to clear {path_to_clear} failed. Path either does not exist or is not a directory." 
+ " Please ensure that your context or subdpath are properly defined." + ) + else: + logger.warn( + "Your context path is a relative path;" + " in order to avoid unexpected results, we do not allow the deletion of relative paths." + ) def save(self, filepath=None, store=True): """export project to yaml file and save project in database @@ -2139,15 +2335,43 @@ def export(self, filepath=None, include_files: str = None): mlrun.get_dataitem(filepath).upload(tmp_path) remove(tmp_path) - def set_model_monitoring_credentials(self, access_key: str): + def set_model_monitoring_credentials( + self, + access_key: str = None, + endpoint_store_connection: str = None, + stream_path: str = None, + ): """Set the credentials that will be used by the project's model monitoring infrastructure functions. - The supplied credentials must have data access - :param access_key: Model Monitoring access key for managing user permissions. + :param access_key: Model Monitoring access key for managing user permissions + :param endpoint_store_connection: Endpoint store connection string + :param stream_path: Path to the model monitoring stream """ - set_project_model_monitoring_credentials( - access_key=access_key, project=self.metadata.name + + secrets_dict = {} + if access_key: + secrets_dict[ + model_monitoring_constants.ProjectSecretKeys.ACCESS_KEY + ] = access_key + + if endpoint_store_connection: + secrets_dict[ + model_monitoring_constants.ProjectSecretKeys.ENDPOINT_STORE_CONNECTION + ] = endpoint_store_connection + + if stream_path: + if stream_path.startswith("kafka://") and "?topic" in stream_path: + raise mlrun.errors.MLRunInvalidArgumentError( + "Custom kafka topic is not allowed" + ) + secrets_dict[ + model_monitoring_constants.ProjectSecretKeys.STREAM_PATH + ] = stream_path + + self.set_secrets( + secrets=secrets_dict, + provider=mlrun.common.schemas.SecretProviderName.kubernetes, ) def run_function( @@ -2168,7 +2392,7 @@ def run_function( verbose: bool = None, selector: str = 
None, auto_build: bool = None, - schedule: typing.Union[str, mlrun.api.schemas.ScheduleCronTrigger] = None, + schedule: typing.Union[str, mlrun.common.schemas.ScheduleCronTrigger] = None, artifact_path: str = None, notifications: typing.List[mlrun.model.Notification] = None, returns: Optional[List[Union[str, Dict[str, str]]]] = None, @@ -2177,10 +2401,10 @@ def run_function( example (use with project):: - # create a project with two functions (local and from marketplace) + # create a project with two functions (local and from hub) project = mlrun.new_project(project_name, "./proj") project.set_function("mycode.py", "myfunc", image="mlrun/mlrun") - project.set_function("hub://auto_trainer", "train") + project.set_function("hub://auto-trainer", "train") # run functions (refer to them by name) run1 = project.run_function("myfunc", params={"x": 7}) @@ -2256,28 +2480,30 @@ def build_function( function: typing.Union[str, mlrun.runtimes.BaseRuntime], with_mlrun: bool = None, skip_deployed: bool = False, - image=None, - base_image=None, + image: str = None, + base_image: str = None, commands: list = None, - secret_name=None, + secret_name: str = None, requirements: typing.Union[str, typing.List[str]] = None, - mlrun_version_specifier=None, + mlrun_version_specifier: str = None, builder_env: dict = None, overwrite_build_params: bool = False, + requirements_file: str = None, ) -> typing.Union[BuildStatus, kfp.dsl.ContainerOp]: """deploy ML function, build container with its dependencies - :param function: name of the function (in the project) or function object - :param with_mlrun: add the current mlrun package to the container build - :param skip_deployed: skip the build if we already have an image for the function - :param image: target image name/path - :param base_image: base image name/path (commands and source code will be added to it) - :param commands: list of docker build (RUN) commands e.g. 
['pip install pandas'] - :param secret_name: k8s secret for accessing the docker registry - :param requirements: list of python packages or pip requirements file path, defaults to None + :param function: name of the function (in the project) or function object + :param with_mlrun: add the current mlrun package to the container build + :param skip_deployed: skip the build if we already have an image for the function + :param image: target image name/path + :param base_image: base image name/path (commands and source code will be added to it) + :param commands: list of docker build (RUN) commands e.g. ['pip install pandas'] + :param secret_name: k8s secret for accessing the docker registry + :param requirements: list of python packages, defaults to None + :param requirements_file: pip requirements file path, defaults to None :param mlrun_version_specifier: which mlrun package version to include (if not current) - :param builder_env: Kaniko builder pod env vars dict (for config/credentials) - e.g. builder_env={"GIT_TOKEN": token}, does not work yet in KFP + :param builder_env: Kaniko builder pod env vars dict (for config/credentials) + e.g. 
builder_env={"GIT_TOKEN": token}, does not work yet in KFP :param overwrite_build_params: overwrite the function build parameters with the provided ones, or attempt to add to existing parameters """ @@ -2290,12 +2516,141 @@ def build_function( commands=commands, secret_name=secret_name, requirements=requirements, + requirements_file=requirements_file, mlrun_version_specifier=mlrun_version_specifier, builder_env=builder_env, project_object=self, overwrite_build_params=overwrite_build_params, ) + def build_config( + self, + image: str = None, + set_as_default: bool = False, + with_mlrun: bool = None, + base_image: str = None, + commands: list = None, + secret_name: str = None, + requirements: typing.Union[str, typing.List[str]] = None, + overwrite_build_params: bool = False, + requirements_file: str = None, + ): + """specify builder configuration for the project + + :param image: target image name/path. If not specified the project's existing `default_image` name will be + used. If not set, the `mlconf.default_project_image_name` value will be used + :param set_as_default: set `image` to be the project's default image (default False) + :param with_mlrun: add the current mlrun package to the container build + :param base_image: base image name/path + :param commands: list of docker build (RUN) commands e.g. 
['pip install pandas'] + :param secret_name: k8s secret for accessing the docker registry + :param requirements: a list of packages to install on the built image + :param requirements_file: requirements file to install on the built image + :param overwrite_build_params: overwrite existing build configuration (default False) + + * False: the new params are merged with the existing (currently merge is applied to requirements and + commands) + * True: the existing params are replaced by the new ones + """ + default_image_name = mlrun.mlconf.default_project_image_name.format( + name=self.name + ) + image = image or self.default_image or default_image_name + + self.spec.build.build_config( + image=image, + base_image=base_image, + commands=commands, + secret=secret_name, + with_mlrun=with_mlrun, + requirements=requirements, + requirements_file=requirements_file, + overwrite=overwrite_build_params, + ) + + if set_as_default and image != self.default_image: + self.set_default_image(image) + + def build_image( + self, + image: str = None, + set_as_default: bool = True, + with_mlrun: bool = None, + skip_deployed: bool = False, + base_image: str = None, + commands: list = None, + secret_name: str = None, + requirements: typing.Union[str, typing.List[str]] = None, + mlrun_version_specifier: str = None, + builder_env: dict = None, + overwrite_build_params: bool = False, + requirements_file: str = None, + ) -> typing.Union[BuildStatus, kfp.dsl.ContainerOp]: + """Builder docker image for the project, based on the project's build config. Parameters allow to override + the build config. + + :param image: target image name/path. If not specified the project's existing `default_image` name will be + used. 
If not set, the `mlconf.default_project_image_name` value will be used + :param set_as_default: set `image` to be the project's default image (default False) + :param with_mlrun: add the current mlrun package to the container build + :param skip_deployed: skip the build if we already have the image specified built + :param base_image: base image name/path (commands and source code will be added to it) + :param commands: list of docker build (RUN) commands e.g. ['pip install pandas'] + :param secret_name: k8s secret for accessing the docker registry + :param requirements: list of python packages, defaults to None + :param requirements_file: pip requirements file path, defaults to None + :param mlrun_version_specifier: which mlrun package version to include (if not current) + :param builder_env: Kaniko builder pod env vars dict (for config/credentials) + e.g. builder_env={"GIT_TOKEN": token}, does not work yet in KFP + :param overwrite_build_params: overwrite existing build configuration (default False) + + * False: the new params are merged with the existing (currently merge is applied to requirements and + commands) + * True: the existing params are replaced by the new ones + """ + + self.build_config( + image=image, + set_as_default=set_as_default, + base_image=base_image, + commands=commands, + secret_name=secret_name, + with_mlrun=with_mlrun, + requirements=requirements, + requirements_file=requirements_file, + overwrite_build_params=overwrite_build_params, + ) + + function = mlrun.new_function("mlrun--project--image--builder", kind="job") + + build = self.spec.build + result = self.build_function( + function=function, + with_mlrun=build.with_mlrun, + image=build.image, + base_image=build.base_image, + commands=build.commands, + secret_name=build.secret, + requirements=build.requirements, + skip_deployed=skip_deployed, + overwrite_build_params=overwrite_build_params, + mlrun_version_specifier=mlrun_version_specifier, + builder_env=builder_env, + ) + + try: + 
mlrun.db.get_run_db(secrets=self._secrets).delete_function( + name=function.metadata.name + ) + except Exception as exc: + logger.warning( + f"Image was successfully built, but failed to delete temporary function {function.metadata.name}." + " To remove the function, attempt to manually delete it.", + exc=repr(exc), + ) + + return result + def deploy_function( self, function: typing.Union[str, mlrun.runtimes.BaseRuntime], @@ -2352,7 +2707,7 @@ def list_artifacts( iter: int = None, best_iteration: bool = False, kind: str = None, - category: typing.Union[str, mlrun.api.schemas.ArtifactCategories] = None, + category: typing.Union[str, mlrun.common.schemas.ArtifactCategories] = None, ) -> mlrun.lists.ArtifactList: """List artifacts filtered by various parameters. @@ -2460,17 +2815,17 @@ def list_functions(self, name=None, tag=None, labels=None): def list_runs( self, - name=None, - uid=None, - labels=None, - state=None, - sort=True, - last=0, - iter=False, - start_time_from: datetime.datetime = None, - start_time_to: datetime.datetime = None, - last_update_time_from: datetime.datetime = None, - last_update_time_to: datetime.datetime = None, + name: Optional[str] = None, + uid: Optional[Union[str, List[str]]] = None, + labels: Optional[Union[str, List[str]]] = None, + state: Optional[str] = None, + sort: bool = True, + last: int = 0, + iter: bool = False, + start_time_from: Optional[datetime.datetime] = None, + start_time_to: Optional[datetime.datetime] = None, + last_update_time_from: Optional[datetime.datetime] = None, + last_update_time_to: Optional[datetime.datetime] = None, **kwargs, ) -> mlrun.lists.RunList: """Retrieve a list of runs, filtered by various options. 
@@ -2484,6 +2839,10 @@ def list_runs( # return a list of runs matching the name and label and compare runs = project.list_runs(name='download', labels='owner=admin') runs.compare() + + # multi-label filter can also be provided + runs = project.list_runs(name='download', labels=["kind=job", "owner=admin"]) + # If running in Jupyter, can use the .show() function to display the results project.list_runs(name='').show() @@ -2491,8 +2850,8 @@ def list_runs( :param name: Name of the run to retrieve. :param uid: Unique ID of the run. :param project: Project that the runs belongs to. - :param labels: List runs that have a specific label assigned. Currently only a single label filter can be - applied, otherwise result will be empty. + :param labels: List runs that have specific labels assigned. a single or multi label filter can be + applied. :param state: List only runs whose state is specified. :param sort: Whether to sort the result according to their start time. Otherwise, results will be returned by their internal order in the DB (order will not be guaranteed). @@ -2521,13 +2880,53 @@ def list_runs( **kwargs, ) + def get_custom_packagers(self) -> typing.List[typing.Tuple[str, bool]]: + """ + Get the custom packagers registered in the project. + + :return: A list of the custom packagers module paths. + """ + # Return a copy so the user won't be able to edit the list by the reference returned (no need for deep copy as + # tuples do not support item assignment): + return self.spec.custom_packagers.copy() + + def add_custom_packager(self, packager: str, is_mandatory: bool): + """ + Add a custom packager from the custom packagers list. All project's custom packagers are added to each project + function. + + **Notice** that in order to run a function with the custom packagers included, you must set a source for the + project (using the `project.set_source` method) with the parameter `pull_at_runtime=True` so the source code of + the packagers will be able to be imported. 
+ + :param packager: The packager module path to add. For example, if a packager `MyPackager` is in the + project's source at my_module.py, then the module path is: "my_module.MyPackager". + :param is_mandatory: Whether this packager must be collected during a run. If False, failing to collect it won't + raise an error during the packagers collection phase. + """ + self.spec.add_custom_packager(packager=packager, is_mandatory=is_mandatory) + + def remove_custom_packager(self, packager: str): + """ + Remove a custom packager from the custom packagers list. + + :param packager: The packager module path to remove. + + :raise MLRunInvalidArgumentError: In case the packager was not in the list. + """ + self.spec.remove_custom_packager(packager=packager) + def _set_as_current_default_project(project: MlrunProject): mlrun.mlconf.default_project = project.metadata.name pipeline_context.set(project) -def _init_function_from_dict(f, project, name=None): +def _init_function_from_dict( + f: dict, + project: MlrunProject, + name: typing.Optional[str] = None, +) -> typing.Tuple[str, mlrun.runtimes.BaseRuntime]: name = name or f.get("name", "") url = f.get("url", "") kind = f.get("kind", "") @@ -2593,6 +2992,7 @@ def _init_function_from_dict(f, project, name=None): raise ValueError(f"unsupported function url:handler {url}:{handler} or no spec") if with_repo: + # mark source to be enriched before run with project source (enrich_function_object) func.spec.build.source = "./" if requirements: func.with_requirements(requirements) @@ -2600,7 +3000,11 @@ def _init_function_from_dict(f, project, name=None): return _init_function_from_obj(func, project, name) -def _init_function_from_obj(func, project, name=None): +def _init_function_from_obj( + func: mlrun.runtimes.BaseRuntime, + project: MlrunProject, + name: typing.Optional[str] = None, +) -> typing.Tuple[str, mlrun.runtimes.BaseRuntime]: build = func.spec.build if project.spec.origin_url: origin = project.spec.origin_url @@ -2620,7 
+3024,9 @@ def _init_function_from_obj(func, project, name=None): def _has_module(handler, kind): if not handler: return False - return (kind in RuntimeKinds.nuclio_runtimes() and ":" in handler) or "." in handler + return ( + kind in mlrun.runtimes.RuntimeKinds.nuclio_runtimes() and ":" in handler + ) or "." in handler def _is_imported_artifact(artifact): diff --git a/mlrun/run.py b/mlrun/run.py index 06873a5f8a39..39df3c8727e2 100644 --- a/mlrun/run.py +++ b/mlrun/run.py @@ -11,32 +11,27 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import functools -import importlib import importlib.util as imputil -import inspect import json import os import pathlib -import re import socket import tempfile import time import uuid import warnings from base64 import b64decode -from collections import OrderedDict from copy import deepcopy from os import environ, makedirs, path from pathlib import Path -from typing import Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Dict, List, Optional, Tuple, Union import nuclio import yaml from deprecated import deprecated from kfp import Client -import mlrun.api.schemas +import mlrun.common.schemas import mlrun.errors import mlrun.utils.helpers from mlrun.kfpops import format_summary_from_kfp_run, show_kfp_run @@ -44,6 +39,7 @@ from .config import config as mlconf from .datastore import store_manager from .db import get_or_set_dburl, get_run_db +from .errors import MLRunInvalidArgumentError, MLRunTimeoutError from .execution import MLClientCtx from .model import BaseMetadata, RunObject, RunTemplate from .runtimes import ( @@ -61,7 +57,6 @@ get_runtime_class, ) from .runtimes.funcdoc import update_function_entry_points -from .runtimes.package.context_handler import ArtifactType, ContextHandler from .runtimes.serving import serving_subkind from .runtimes.utils import add_code_metadata, 
global_context from .utils import ( @@ -111,6 +106,12 @@ def transient_statuses(): ] +# TODO: remove in 1.6.0 +@deprecated( + version="1.4.0", + reason="'run_local' will be removed in 1.6.0, use 'function.run(local=True)' instead", + category=FutureWarning, +) def run_local( task=None, command="", @@ -162,6 +163,7 @@ def run_local( (allows to have function which don't depend on having targets, e.g a function which accepts a feature vector uri and generate the offline vector e.g. parquet_ for it if it doesn't exist) + :param notifications: list of notifications to push when the run is completed :param returns: List of configurations for how to log the returning values from the handler's run (as artifacts or results). The list's length must be equal to the amount of returning objects. A configuration may be given as: @@ -236,7 +238,7 @@ def function_to_module(code="", workdir=None, secrets=None, silent=False): mod.my_job(context, p1=1, p2='x') print(context.to_yaml()) - fn = mlrun.import_function('hub://open_archive') + fn = mlrun.import_function('hub://open-archive') mod = mlrun.function_to_module(fn) data = mlrun.run.get_dataitem("https://fpsignals-public.s3.amazonaws.com/catsndogs.tar.gz") context = mlrun.get_or_create_ctx('myfunc') @@ -458,7 +460,7 @@ def import_function(url="", secrets=None, db="", project=None, new_name=None): examples:: - function = mlrun.import_function("hub://auto_trainer") + function = mlrun.import_function("hub://auto-trainer") function = mlrun.import_function("./func.yaml") function = mlrun.import_function("https://raw.githubusercontent.com/org/repo/func.yaml") @@ -550,6 +552,7 @@ def new_function( source: str = None, requirements: Union[str, List[str]] = None, kfp=None, + requirements_file: str = "", ): """Create a new ML function from base properties @@ -585,9 +588,13 @@ def new_function( (job, mpijob, ..) the handler can also be specified in the `.run()` command, when not specified the entire file will be executed (as main). 
for nuclio functions the handler is in the form of module:function, defaults to "main:handler" - :param source: valid path to git, zip, or tar file, e.g. `git://github.com/mlrun/something.git`, + :param source: valid absolute path or URL to git, zip, or tar file, e.g. + `git://github.com/mlrun/something.git`, `http://some/url/file.zip` - :param requirements: list of python packages or pip requirements file path, defaults to None + note path source must exist on the image or exist locally when run is local + (it is recommended to use 'function.spec.workdir' when source is a filepath instead) + :param requirements: a list of python packages, defaults to None + :param requirements_file: path to a python requirements file :param kfp: reserved, flag indicating running within kubeflow pipeline :return: function object @@ -643,7 +650,7 @@ def new_function( runner.spec.build.source = source if handler: if kind == RuntimeKinds.serving: - raise mlrun.errors.MLRunInvalidArgumentError( + raise MLRunInvalidArgumentError( "cannot set the handler for serving runtime" ) elif kind in RuntimeKinds.nuclio_runtimes(): @@ -652,8 +659,13 @@ def new_function( runner.spec.default_handler = handler if requirements: - runner.with_requirements(requirements) - runner.verify_base_image() + runner.with_requirements( + requirements, + requirements_file=requirements_file, + prepare_image_for_deploy=False, + ) + + runner.prepare_image_for_deploy() return runner @@ -693,6 +705,7 @@ def code_to_function( labels: Dict[str, str] = None, with_doc: bool = True, ignored_tags=None, + requirements_file: str = "", ) -> Union[ MpiRuntimeV1Alpha1, MpiRuntimeV1, @@ -729,8 +742,7 @@ def code_to_function( - spark: run distributed Spark job using Spark Kubernetes Operator - remote-spark: run distributed Spark job on remote Spark service - Learn more about function runtimes here: - https://docs.mlrun.org/en/latest/runtimes/functions.html#function-runtimes + Learn more about {Kinds of function 
(runtimes)](../concepts/functions-overview.html). :param name: function name, typically best to use hyphen-case :param project: project used to namespace the function, defaults to 'default' @@ -747,6 +759,8 @@ def code_to_function( defaults to True :param description: short function description, defaults to '' :param requirements: list of python packages or pip requirements file path, defaults to None + :param requirements: a list of python packages + :param requirements_file: path to a python requirements file :param categories: list of categories for mlrun Function Hub, defaults to None :param labels: immutable name/value pairs to tag the function with useful metadata, defaults to None :param with_doc: indicates whether to document the function parameters, defaults to True @@ -793,12 +807,13 @@ def add_name(origin, name=""): def update_common(fn, spec): fn.spec.image = image or get_in(spec, "spec.image", "") + fn.spec.filename = filename or get_in(spec, "spec.filename", "") fn.spec.build.base_image = get_in(spec, "spec.build.baseImage") fn.spec.build.commands = get_in(spec, "spec.build.commands") fn.spec.build.secret = get_in(spec, "spec.build.secret") if requirements: - fn.with_requirements(requirements) + fn.with_requirements(requirements, requirements_file=requirements_file) if embed_code: fn.spec.build.functionSourceCode = get_in( @@ -920,7 +935,7 @@ def resolve_nuclio_subkind(kind): build.image = get_in(spec, "spec.build.image") update_common(r, spec) - r.verify_base_image() + r.prepare_image_for_deploy() if with_doc: update_function_entry_points(r, code) @@ -1136,22 +1151,24 @@ def wait_for_pipeline_completion( if remote: mldb = mlrun.db.get_run_db() - def get_pipeline_if_completed(run_id, namespace=namespace): - resp = mldb.get_pipeline(run_id, namespace=namespace, project=project) - status = resp["run"]["status"] - show_kfp_run(resp, clear_output=True) - if status not in RunStatuses.stable_statuses(): - # TODO: think of nicer liveness indication and make 
it re-usable - # log '.' each retry as a liveness indication - logger.debug(".") + def _wait_for_pipeline_completion(): + pipeline = mldb.get_pipeline(run_id, namespace=namespace, project=project) + pipeline_status = pipeline["run"]["status"] + show_kfp_run(pipeline, clear_output=True) + if pipeline_status not in RunStatuses.stable_statuses(): + logger.debug( + "Waiting for pipeline completion", + run_id=run_id, + status=pipeline_status, + ) raise RuntimeError("pipeline run has not completed yet") - return resp + return pipeline if mldb.kind != "http": raise ValueError( - "get pipeline require access to remote api-service" - ", please set the dbpath url" + "get pipeline requires access to remote api-service" + ", set the dbpath url" ) resp = retry_until_successful( @@ -1159,9 +1176,7 @@ def get_pipeline_if_completed(run_id, namespace=namespace): timeout, logger, False, - get_pipeline_if_completed, - run_id, - namespace=namespace, + _wait_for_pipeline_completion, ) else: client = Client(namespace=namespace) @@ -1194,8 +1209,8 @@ def get_pipeline( run_id, namespace=None, format_: Union[ - str, mlrun.api.schemas.PipelinesFormat - ] = mlrun.api.schemas.PipelinesFormat.summary, + str, mlrun.common.schemas.PipelinesFormat + ] = mlrun.common.schemas.PipelinesFormat.summary, project: str = None, remote: bool = True, ): @@ -1231,7 +1246,7 @@ def get_pipeline( resp = resp.to_dict() if ( not format_ - or format_ == mlrun.api.schemas.PipelinesFormat.summary.value + or format_ == mlrun.common.schemas.PipelinesFormat.summary.value ): resp = format_summary_from_kfp_run(resp) @@ -1247,7 +1262,7 @@ def list_pipelines( filter_="", namespace=None, project="*", - format_: mlrun.api.schemas.PipelinesFormat = mlrun.api.schemas.PipelinesFormat.metadata_only, + format_: mlrun.common.schemas.PipelinesFormat = mlrun.common.schemas.PipelinesFormat.metadata_only, ) -> Tuple[int, Optional[int], List[dict]]: """List pipelines @@ -1267,7 +1282,7 @@ def list_pipelines( :param format_: Control 
what will be returned (full/metadata_only/name_only) """ if full: - format_ = mlrun.api.schemas.PipelinesFormat.full + format_ = mlrun.common.schemas.PipelinesFormat.full run_db = mlrun.db.get_run_db() pipelines = run_db.list_pipelines( project, namespace, sort_by, page_token, filter_, format_, page_size @@ -1333,291 +1348,7 @@ def wait_for_runs_completion(runs: list, sleep=3, timeout=0, silent=False): if timeout and total_time > timeout: if silent: break - raise mlrun.errors.MLRunTimeoutError( - "some runs did not reach terminal state on time" - ) + raise MLRunTimeoutError("some runs did not reach terminal state on time") runs = running return completed - - -def _parse_type_hint(type_hint: Union[Type, str]) -> Type: - """ - Parse a given type hint from string to its actual hinted type class object. The string must be one of the following: - - * Python builtin type - one of ``tuple``, ``list``, ``set``, ``dict`` and ``bytearray``. - * Full module import path. An alias is not allowed (if ``import pandas as pd`` is used, the type hint cannot be - ``pd.DataFrame`` but ``pandas.DataFrame``). - - The type class on its own (like `DataFrame`) cannot be used as the scope of the decorator is not the same as the - handler itself, hence modules and objects that were imported in the handler's scope are not available. This is the - same reason import aliases cannot be used as well. - - If the provided type hint is not a string, it will simply be returned as is. - - **Notice**: This method should only run on client side as it dependent on user requirements. - - :param type_hint: The type hint to parse. - - :return: The hinted type. - - :raise MLRunInvalidArgumentError: In case the type hint is not following the 2 options mentioned above. 
- """ - if not isinstance(type_hint, str): - return type_hint - - # TODO: Remove once Packager is implemented (it will support typing hints) - # If a typing hint is provided, we return a dummy Union type so the parser will skip the data item: - if type_hint.startswith("typing."): - return Union[int, str] - - # Validate the type hint is a valid module path: - if not bool( - re.fullmatch(r"([a-zA-Z_][a-zA-Z0-9_]*\.)*[a-zA-Z_][a-zA-Z0-9_]*", type_hint) - ): - raise mlrun.errors.MLRunInvalidArgumentError( - f"Invalid type hint. An input type hint must be a valid python class name or its module import path. For " - f"example: 'list', 'pandas.DataFrame', 'numpy.ndarray', 'sklearn.linear_model.LinearRegression'. Type hint " - f"given: '{type_hint}'." - ) - - # Look for a builtin type (rest of the builtin types like `int`, `str`, `float` should be treated as results, hence - # not given as an input to an MLRun function, but as a parameter): - builtin_types = { - tuple.__name__: tuple, - list.__name__: list, - set.__name__: set, - dict.__name__: dict, - bytearray.__name__: bytearray, - } - if type_hint in builtin_types: - return builtin_types[type_hint] - - # If it's not a builtin, its should have a full module path: - if "." not in type_hint: - raise mlrun.errors.MLRunInvalidArgumentError( - f"MLRun tried to get the type hint '{type_hint}' but it can't as it is not a valid builtin Python type " - f"(one of {', '.join(list(builtin_types.keys()))}). Pay attention using only the type as string is not " - f"allowed as the handler's scope is different then MLRun's. To properly give a type hint, please specify " - f"the full module path. For example: do not use `DataFrame`, use `pandas.DataFrame`." 
- ) - - # Import the module to receive the hinted type: - try: - # Get the module path and the type class (If we'll wish to support inner classes, the `rsplit` won't work): - module_path, type_hint = type_hint.rsplit(".", 1) - # Replace alias if needed (alias assumed to be imported already, hence we look in globals): - # For example: - # If in handler scope there was `import A.B.C as abc` and user gave a type hint "abc.Something" then: - # `module_path[0]` will be equal to "abc". Then, because it is an alias, it will appear in the globals, so we'll - # replace the alias with the full module name in order to import the module. - module_path = module_path.split(".") - if module_path[0] in globals(): - module_path[0] = globals()[module_path[0]].__name__ - module_path = ".".join(module_path) - # Import the module: - module = importlib.import_module(module_path) - # Get the class type from the module: - type_hint = getattr(module, type_hint) - except ModuleNotFoundError as module_not_found_error: - # May be raised from `importlib.import_module` in case the module does not exist. - raise mlrun.errors.MLRunInvalidArgumentError( - f"MLRun tried to get the type hint '{type_hint}' but the module '{module_path}' cannot be imported. " - f"Keep in mind that using alias in the module path (meaning: import module as alias) is not allowed. " - f"If the module path is correct, please make sure the module package is installed in the python " - f"interpreter." - ) from module_not_found_error - except AttributeError as attribute_error: - # May be raised from `getattr(module, type_hint)` in case the class type cannot be imported directly from the - # imported module. - raise mlrun.errors.MLRunInvalidArgumentError( - f"MLRun tried to get the type hint '{type_hint}' from the module '{module.__name__}' but it seems it " - f"doesn't exist. Make sure the class can be imported from the module with the exact module path you " - f"passed. 
Notice inner classes (a class inside of a class) are not supported." - ) from attribute_error - - return type_hint - - -def _parse_log_hint( - log_hint: Union[Dict[str, str], str, None] -) -> Union[Dict[str, str], None]: - """ - Parse a given log hint from string to a logging configuration dictionary. The string will be read as the artifact - key ('key' in the dictionary) and if the string have a single colon, the following structure is assumed: - " : ". The artifact type must be on of the values of `ArtifactType`'s enum. - - If a logging configuration dictionary is received, it will be validated to have a key field and valid artifact type - value. - - None will be returned as None. - - :param log_hint: The log hint to parse. - - :return: The hinted logging configuration. - - :raise MLRunInvalidArgumentError: In case the log hint is not following the string structure, the artifact type is - not valid or the dictionary is missing the key field. - """ - # Check for None value: - if log_hint is None: - return None - - # If the log hint was provided as a string, construct a dictionary out of it: - if isinstance(log_hint, str): - # Check if only key is given: - if ":" not in log_hint: - log_hint = {"key": log_hint} - # Check for valid " : " pattern: - else: - if log_hint.count(":") > 1: - raise mlrun.errors.MLRunInvalidArgumentError( - f"Incorrect log hint pattern. Output keys can have only a single ':' in them to specify the " - f"desired artifact type the returned value will be logged as: ' : ', " - f"but given: {log_hint}" - ) - # Split into key and type: - key, artifact_type = log_hint.replace(" ", "").split(":") - log_hint = {"key": key, "artifact_type": artifact_type} - - # TODO: Replace with constants keys once mlrun.package is implemented. - # Validate the log hint dictionary has the mandatory key: - if "key" not in log_hint: - raise mlrun.errors.MLRunInvalidArgumentError( - f"An output log hint dictionary must include the 'key' - the artifact key (it's name). 
The following " - f"log hint is missing the key: {log_hint}." - ) - - # Validate the artifact type is valid: - if "artifact_type" in log_hint: - valid_artifact_types = [t.value for t in ArtifactType.__members__.values()] - if log_hint["artifact_type"] not in valid_artifact_types: - raise mlrun.errors.MLRunInvalidArgumentError( - f"The following artifact type '{log_hint['artifact_type']}' is not a valid `ArtifactType`. " - f"Please select one of the following: {','.join(valid_artifact_types)}" - ) - - return log_hint - - -def handler( - labels: Dict[str, str] = None, - outputs: List[Union[str, Dict[str, str]]] = None, - inputs: Union[bool, Dict[str, Union[str, Type]]] = True, -): - """ - MLRun's handler is a decorator to wrap a function and enable setting labels, automatic `mlrun.DataItem` parsing and - outputs logging. - - :param labels: Labels to add to the run. Expecting a dictionary with the labels names as keys. Default: None. - :param outputs: Logging configurations for the function's returned values. Expecting a list of tuples and None - values: - - * str - A string in the format of '{key}:{artifact_type}'. If a string was given without ':' it will - indicate the key and the artifact type will be according to the returned value type. The artifact - types can be one of: "dataset", "directory", "file", "object", "plot" and "result". - - * Dict[str, str] - A dictionary of logging configuration. the key 'key' is mandatory for the logged - artifact key. - - * None - Do not log the output. - - The list length must be equal to the total amount of returned values from the function. Default is - None - meaning no outputs will be logged. - - :param inputs: Parsing configurations for the arguments passed as inputs via the `run` method of an MLRun function. - Can be passed as a boolean value or a dictionary: - - * True - Parse all found inputs to the assigned type hint in the function's signature. 
If there is no - type hint assigned, the value will remain an `mlrun.DataItem`. - * False - Do not parse inputs, leaving the inputs as `mlrun.DataItem`. - * Dict[str, Union[Type, str]] - A dictionary with argument name as key and the expected type to parse - the `mlrun.DataItem` to. The expected type can be a string as well, idicating the full module path. - - **Notice**: Type hints from the `typing` module (e.g. `typing.Optional`, `typing.Union`, - `typing.List` etc.) are currently not supported but will be in the future. - - Default: True. - - Example:: - - import mlrun - - @mlrun.handler(outputs=["my_array", None, "my_multiplier"]) - def my_handler(array: np.ndarray, m: int): - array = array * m - m += 1 - return array, "I won't be logged", m - - >>> mlrun_function = mlrun.code_to_function("my_code.py", kind="job") - >>> run_object = mlrun_function.run( - ... handler="my_handler", - ... inputs={"array": "store://my_array_Artifact"}, - ... params={"m": 2} - ... ) - >>> run_object.outputs - {'my_multiplier': 3, 'my_array': 'store://...'} - """ - - def decorator(func: Callable): - def wrapper(*args: tuple, **kwargs: dict): - nonlocal labels - nonlocal outputs - nonlocal inputs - - # Set default `inputs` - inspect the full signature and add the user's input on top of it: - if inputs: - # Get the available parameters type hints from the function's signature: - func_signature = inspect.signature(func) - parameters = OrderedDict( - { - parameter.name: parameter.annotation - for parameter in func_signature.parameters.values() - } - ) - # If user input is given, add it on top of the collected defaults (from signature), strings type hints - # will be parsed to their actual types: - if isinstance(inputs, dict): - parameters.update( - { - parameter_name: _parse_type_hint(type_hint=type_hint) - for parameter_name, type_hint in inputs.items() - } - ) - inputs = parameters - - # Create a context handler and look for a context: - context_handler = ContextHandler() - 
context_handler.look_for_context(args=args, kwargs=kwargs) - - # If an MLRun context is found, parse arguments pre-run (kwargs are parsed inplace): - if context_handler.is_context_available() and inputs: - args = context_handler.parse_inputs( - args=args, kwargs=kwargs, type_hints=inputs - ) - - # Call the original function and get the returning values: - func_outputs = func(*args, **kwargs) - - # If an MLRun context is found, set the given labels and log the returning values to MLRun via the context: - if context_handler.is_context_available(): - if labels: - context_handler.set_labels(labels=labels) - if outputs: - context_handler.log_outputs( - outputs=func_outputs - if isinstance(func_outputs, tuple) - else [func_outputs], - log_hints=[ - _parse_log_hint(log_hint=log_hint) for log_hint in outputs - ], - ) - return # Do not return any values as the returning values were logged to MLRun. - return func_outputs - - # Make sure to pass the wrapped function's signature (argument list, type hints and doc strings) to the wrapper: - wrapper = functools.wraps(func)(wrapper) - - return wrapper - - return decorator diff --git a/mlrun/runtimes/__init__.py b/mlrun/runtimes/__init__.py index af678c958315..e0f017602685 100644 --- a/mlrun/runtimes/__init__.py +++ b/mlrun/runtimes/__init__.py @@ -25,8 +25,6 @@ "RemoteSparkRuntime", ] - -from mlrun.runtimes.package.context_handler import ArtifactType, ContextHandler from mlrun.runtimes.utils import ( resolve_mpijob_crd_version, resolve_spark_operator_version, diff --git a/mlrun/runtimes/base.py b/mlrun/runtimes/base.py index d0ac3777aea9..72adafe98ec3 100644 --- a/mlrun/runtimes/base.py +++ b/mlrun/runtimes/base.py @@ -14,72 +14,53 @@ import enum import getpass import http -import os.path -import shlex import traceback -import typing -import uuid +import warnings from abc import ABC, abstractmethod -from ast import literal_eval from base64 import b64encode -from copy import deepcopy from datetime import datetime, timedelta, 
timezone from os import environ -from typing import Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union -import IPython import requests.exceptions +from deprecated import deprecated from kubernetes.client.rest import ApiException from nuclio.build import mlrun_footer from sqlalchemy.orm import Session import mlrun.api.db.sqldb.session import mlrun.api.utils.singletons.db +import mlrun.common.schemas import mlrun.errors +import mlrun.launcher.factory import mlrun.utils.helpers import mlrun.utils.notifications import mlrun.utils.regex -from mlrun.api import schemas from mlrun.api.constants import LogSources from mlrun.api.db.base import DBInterface from mlrun.utils.helpers import generate_object_uri, verify_field_regex -from ..config import config, is_running_as_api +from ..config import config from ..datastore import store_manager from ..db import RunDBError, get_or_set_dburl, get_run_db from ..errors import err_to_str -from ..execution import MLClientCtx -from ..k8s_utils import get_k8s_helper -from ..kfpops import mlrun_op, write_kfpmeta +from ..kfpops import mlrun_op from ..lists import RunList -from ..model import ( - BaseMetadata, - HyperParamOptions, - ImageBuilder, - ModelObj, - RunObject, - RunTemplate, -) -from ..secrets import SecretsStore +from ..model import BaseMetadata, HyperParamOptions, ImageBuilder, ModelObj, RunObject from ..utils import ( dict_to_json, dict_to_yaml, enrich_image_url, get_in, get_parsed_docker_registry, - get_ui_url, - is_ipython, logger, - normalize_name, now_date, update_in, ) from .constants import PodPhases, RunStates from .funcdoc import update_function_entry_points -from .generators import get_generator -from .utils import RunError, calc_hash, results_to_iter +from .utils import RunError, calc_hash, get_k8s -run_modes = ["pass"] spec_fields = [ "command", "args", @@ -93,6 +74,7 @@ "pythonpath", "disable_auto_mount", "allow_empty_resources", + "clone_target_dir", ] @@ -133,6 
+115,7 @@ def __init__( default_handler=None, pythonpath=None, disable_auto_mount=False, + clone_target_dir=None, ): self.command = command or "" @@ -151,6 +134,9 @@ def __init__( self.entry_points = entry_points or {} self.disable_auto_mount = disable_auto_mount self.allow_empty_resources = None + # the build.source is cloned/extracted to the specified clone_target_dir + # if a relative path is specified, it will be enriched with a temp dir path + self.clone_target_dir = clone_target_dir or "" @property def build(self) -> ImageBuilder: @@ -186,14 +172,12 @@ def __init__(self, metadata=None, spec=None): self.is_child = False self._status = None self.status = None - self._is_api_server = False self.verbose = False self._enriched_image = False def set_db_connection(self, conn): if not self._db_conn: self._db_conn = conn - self._is_api_server = mlrun.config.is_running_as_api() @property def metadata(self) -> BaseMetadata: @@ -219,9 +203,6 @@ def status(self) -> FunctionStatus: def status(self, status): self._status = self._verify_dict(status, "status", FunctionStatus) - def _get_k8s(self): - return get_k8s_helper() - def set_label(self, key, value): self.metadata.labels[key] = str(value) return self @@ -239,39 +220,6 @@ def _is_remote_api(self): return True return False - def _use_remote_api(self): - if ( - self._is_remote - and not self._is_api_server - and self._get_db() - and self._get_db().kind == "http" - ): - return True - return False - - def _enrich_on_client_side(self): - self.try_auto_mount_based_on_config() - self._fill_credentials() - - def _enrich_on_server_side(self): - pass - - def _enrich_on_server_and_client_sides(self): - """ - enrich function also in client side and also on server side - """ - pass - - def _enrich_function(self): - """ - enriches the function based on the flow state we run in (sdk or server) - """ - if self._use_remote_api(): - self._enrich_on_client_side() - else: - self._enrich_on_server_side() - 
self._enrich_on_server_and_client_sides() - def _function_uri(self, tag=None, hash_key=None): return generate_object_uri( self.metadata.project, @@ -284,11 +232,11 @@ def _ensure_run_db(self): self.spec.rundb = self.spec.rundb or get_or_set_dburl() def _get_db(self): + # TODO: remove this function and use the launcher db instead self._ensure_run_db() if not self._db_conn: if self.spec.rundb: self._db_conn = get_run_db(self.spec.rundb, secrets=self._secrets) - self._is_api_server = mlrun.config.is_running_as_api() return self._db_conn # This function is different than the auto_mount function, as it mounts to runtimes based on the configuration. @@ -324,61 +272,62 @@ def _fill_credentials(self): def run( self, - runspec: RunObject = None, - handler=None, - name: str = "", - project: str = "", - params: dict = None, - inputs: Dict[str, str] = None, - out_path: str = "", - workdir: str = "", - artifact_path: str = "", - watch: bool = True, - schedule: Union[str, schemas.ScheduleCronTrigger] = None, - hyperparams: Dict[str, list] = None, - hyper_param_options: HyperParamOptions = None, - verbose=None, - scrape_metrics: bool = None, - local=False, - local_code_path=None, - auto_build=None, - param_file_secrets: Dict[str, str] = None, - notifications: List[mlrun.model.Notification] = None, + runspec: Optional[ + Union["mlrun.run.RunTemplate", "mlrun.run.RunObject", dict] + ] = None, + handler: Optional[Union[str, Callable]] = None, + name: Optional[str] = "", + project: Optional[str] = "", + params: Optional[dict] = None, + inputs: Optional[Dict[str, str]] = None, + out_path: Optional[str] = "", + workdir: Optional[str] = "", + artifact_path: Optional[str] = "", + watch: Optional[bool] = True, + schedule: Optional[Union[str, mlrun.common.schemas.ScheduleCronTrigger]] = None, + hyperparams: Optional[Dict[str, list]] = None, + hyper_param_options: Optional[HyperParamOptions] = None, + verbose: Optional[bool] = None, + scrape_metrics: Optional[bool] = None, + local: 
Optional[bool] = False, + local_code_path: Optional[str] = None, + auto_build: Optional[bool] = None, + param_file_secrets: Optional[Dict[str, str]] = None, + notifications: Optional[List[mlrun.model.Notification]] = None, returns: Optional[List[Union[str, Dict[str, str]]]] = None, ) -> RunObject: """ Run a local or remote task. - :param runspec: run template object or dict (see RunTemplate) - :param handler: pointer or name of a function handler - :param name: execution name - :param project: project name - :param params: input parameters (dict) + :param runspec: The run spec to generate the RunObject from. Can be RunTemplate | RunObject | dict. + :param handler: Pointer or name of a function handler. + :param name: Execution name. + :param project: Project name. + :param params: Input parameters (dict). :param inputs: Input objects to pass to the handler. Type hints can be given so the input will be parsed during runtime from `mlrun.DataItem` to the given type hint. The type hint can be given in the key field of the dictionary after a colon, e.g: " : ". - :param out_path: default artifact output path - :param artifact_path: default artifact output path (will replace out_path) - :param workdir: default input artifacts path - :param watch: watch/follow run log + :param out_path: Default artifact output path. + :param artifact_path: Default artifact output path (will replace out_path). + :param workdir: Default input artifacts path. + :param watch: Watch/follow run log. :param schedule: ScheduleCronTrigger class instance or a standard crontab expression string (which will be converted to the class using its `from_crontab` constructor), see this link for help: https://apscheduler.readthedocs.io/en/3.x/modules/triggers/cron.html#module-apscheduler.triggers.cron - :param hyperparams: dict of param name and list of values to be enumerated e.g. {"p1": [1,2,3]} + :param hyperparams: Dict of param name and list of values to be enumerated e.g. 
{"p1": [1,2,3]} the default strategy is grid search, can specify strategy (grid, list, random) - and other options in the hyper_param_options parameter - :param hyper_param_options: dict or :py:class:`~mlrun.model.HyperParamOptions` struct of - hyper parameter options - :param verbose: add verbose prints/logs - :param scrape_metrics: whether to add the `mlrun/scrape-metrics` label to this run's resources - :param local: run the function locally vs on the runtime/cluster - :param local_code_path: path of the code for local runs & debug - :param auto_build: when set to True and the function require build it will be built on the first - function run, use only if you dont plan on changing the build config between runs - :param param_file_secrets: dictionary of secrets to be used only for accessing the hyper-param parameter file. - These secrets are only used locally and will not be stored anywhere - :param notifications: list of notifications to push when the run is completed + and other options in the hyper_param_options parameter. + :param hyper_param_options: Dict or :py:class:`~mlrun.model.HyperParamOptions` struct of hyperparameter options. + :param verbose: Add verbose prints/logs. + :param scrape_metrics: Whether to add the `mlrun/scrape-metrics` label to this run's resources. + :param local: Run the function locally vs on the runtime/cluster. + :param local_code_path: Path of the code for local runs & debug. + :param auto_build: When set to True and the function require build it will be built on the first + function run, use only if you don't plan on changing the build config between runs. + :param param_file_secrets: Dictionary of secrets to be used only for accessing the hyper-param parameter file. 
+ These secrets are only used locally and will not be stored anywhere + :param notifications: List of notifications to push when the run is completed :param returns: List of log hints - configurations for how to log the returning values from the handler's run (as artifacts or results). The list's length must be equal to the amount of returning objects. A log hint may be given as: @@ -390,212 +339,34 @@ def run( * A dictionary of configurations to use when logging. Further info per object type and artifact type can be given there. The artifact key must appear in the dictionary as "key": "the_key". - :return: run context object (RunObject) with run metadata, results and status + :return: Run context object (RunObject) with run metadata, results and status """ - mlrun.utils.helpers.verify_dict_items_type("Inputs", inputs, [str], [str]) - - if self.spec.mode and self.spec.mode not in run_modes: - raise ValueError(f'run mode can only be {",".join(run_modes)}') - - self._enrich_function() - - run = self._create_run_object(runspec) - - if local: - - # do not allow local function to be scheduled - if schedule is not None: - raise mlrun.errors.MLRunInvalidArgumentError( - "local and schedule cannot be used together" - ) - result = self._run_local( - run, - local_code_path, - project, - name, - workdir, - handler, - params, - inputs, - returns, - artifact_path, - notifications=notifications, - ) - self._save_or_push_notifications(result, local) - return result - - run = self._enrich_run( - run, - handler, - project, - name, - params, - inputs, - returns, - hyperparams, - hyper_param_options, - verbose, - scrape_metrics, - out_path, - artifact_path, - workdir, - notifications, - ) - self._validate_output_path(run) - db = self._get_db() - - if not self.is_deployed(): - if self.spec.build.auto_build or auto_build: - logger.info( - "Function is not deployed and auto_build flag is set, starting deploy..." 
- ) - self.deploy(skip_deployed=True, show_on_failure=True) - else: - raise RunError( - "function image is not built/ready, set auto_build=True or use .deploy() method first" - ) - - if self.verbose: - logger.info(f"runspec:\n{run.to_yaml()}") - - if "V3IO_USERNAME" in environ and "v3io_user" not in run.metadata.labels: - run.metadata.labels["v3io_user"] = environ.get("V3IO_USERNAME") - - if not self.is_child: - db_str = "self" if self._is_api_server else self.spec.rundb - logger.info( - "Storing function", - name=run.metadata.name, - uid=run.metadata.uid, - db=db_str, - ) - self._store_function(run, run.metadata, db) - - # execute the job remotely (to a k8s cluster via the API service) - if self._use_remote_api(): - return self._submit_job(run, schedule, db, watch) - - elif self._is_remote and not self._is_api_server and not self.kfp: - logger.warning( - "warning!, Api url not set, " "trying to exec remote runtime locally" - ) - - execution = MLClientCtx.from_dict( - run.to_dict(), - db, - autocommit=False, - is_api=self._is_api_server, - store_run=False, + launcher = mlrun.launcher.factory.LauncherFactory.create_launcher( + self._is_remote, local ) - - self._verify_run_params(run.spec.parameters) - - # create task generator (for child runs) from spec - task_generator = get_generator( - run.spec, execution, param_file_secrets=param_file_secrets + return launcher.launch( + runtime=self, + task=runspec, + handler=handler, + name=name, + project=project, + params=params, + inputs=inputs, + out_path=out_path, + workdir=workdir, + artifact_path=artifact_path, + watch=watch, + schedule=schedule, + hyperparams=hyperparams, + hyper_param_options=hyper_param_options, + verbose=verbose, + scrape_metrics=scrape_metrics, + local_code_path=local_code_path, + auto_build=auto_build, + param_file_secrets=param_file_secrets, + notifications=notifications, + returns=returns, ) - if task_generator: - # verify valid task parameters - tasks = task_generator.generate(run) - for task in 
tasks: - self._verify_run_params(task.spec.parameters) - - # post verifications, store execution in db and run pre run hooks - execution.store_run() - self._pre_run(run, execution) # hook for runtime specific prep - - last_err = None - # If the runtime is nested, it means the hyper-run will run within a single instance of the run. - # So while in the API, we consider the hyper-run as a single run, and then in the runtime itself when the - # runtime is now a local runtime and therefore `self._is_nested == False`, we run each task as a separate run by - # using the task generator - if task_generator and not self._is_nested: - # multiple runs (based on hyper params or params file) - runner = self._run_many - if hasattr(self, "_parallel_run_many") and task_generator.use_parallel(): - runner = self._parallel_run_many - results = runner(task_generator, execution, run) - results_to_iter(results, run, execution) - result = execution.to_dict() - result = self._update_run_state(result, task=run) - - else: - # single run - try: - resp = self._run(run, execution) - if ( - watch - and mlrun.runtimes.RuntimeKinds.is_watchable(self.kind) - # API shouldn't watch logs, its the client job to query the run logs - and not mlrun.config.is_running_as_api() - ): - state, _ = run.logs(True, self._get_db()) - if state not in ["succeeded", "completed"]: - logger.warning(f"run ended with state {state}") - result = self._update_run_state(resp, task=run) - except RunError as err: - last_err = err - result = self._update_run_state(task=run, err=err) - - self._save_or_push_notifications(run) - - self._post_run(result, execution) # hook for runtime specific cleanup - - return self._wrap_run_result(result, run, schedule=schedule, err=last_err) - - def _wrap_run_result( - self, result: dict, runspec: RunObject, schedule=None, err=None - ): - # if the purpose was to schedule (and not to run) nothing to wrap - if schedule: - return - - if result and self.kfp and err is None: - write_kfpmeta(result) - 
- # show ipython/jupyter result table widget - results_tbl = RunList() - if result: - results_tbl.append(result) - else: - logger.info("no returned result (job may still be in progress)") - results_tbl.append(runspec.to_dict()) - - uid = runspec.metadata.uid - project = runspec.metadata.project - if is_ipython and config.ipython_widget: - results_tbl.show() - print() - ui_url = get_ui_url(project, uid) - if ui_url: - ui_url = f' or click here to open in UI' - IPython.display.display( - IPython.display.HTML( - f" > to track results use the .show() or .logs() methods {ui_url}" - ) - ) - elif not (self.is_child and is_running_as_api()): - project_flag = f"-p {project}" if project else "" - info_cmd = f"mlrun get run {uid} {project_flag}" - logs_cmd = f"mlrun logs {uid} {project_flag}" - logger.info( - "To track results use the CLI", info_cmd=info_cmd, logs_cmd=logs_cmd - ) - ui_url = get_ui_url(project, uid) - if ui_url: - logger.info("Or click for UI", ui_url=ui_url) - if result: - run = RunObject.from_dict(result) - logger.info( - f"run executed, status={run.status.state}", name=run.metadata.name - ) - if run.status.state == "error": - if self._is_remote and not self.is_child: - logger.error(f"runtime error: {run.status.error}") - raise RunError(run.status.error) - return run - - return None def _get_db_run(self, task: RunObject = None): if self._get_db() and task: @@ -624,242 +395,6 @@ def _generate_runtime_env(self, runobj: RunObject): runtime_env["MLRUN_NAMESPACE"] = self.metadata.namespace or config.namespace return runtime_env - def _run_local( - self, - runspec, - local_code_path, - project, - name, - workdir, - handler, - params, - inputs, - returns, - artifact_path, - notifications: List[mlrun.model.Notification] = None, - ): - # allow local run simulation with a flip of a flag - command = self - if local_code_path: - project = project or self.metadata.project - name = name or self.metadata.name - command = local_code_path - return mlrun.run_local( - 
runspec, - command, - name, - self.spec.args, - workdir=workdir, - project=project, - handler=handler, - params=params, - inputs=inputs, - artifact_path=artifact_path, - mode=self.spec.mode, - allow_empty_resources=self.spec.allow_empty_resources, - notifications=notifications, - returns=returns, - ) - - def _create_run_object(self, runspec): - # TODO: Once implemented the `Runtime` handlers configurations (doc strings, params type hints and returning - # log hints, possible parameter values, etc), the configured type hints and log hints should be set into - # the `RunObject` from the `Runtime`. - if runspec: - runspec = deepcopy(runspec) - if isinstance(runspec, str): - runspec = literal_eval(runspec) - if not isinstance(runspec, (dict, RunTemplate, RunObject)): - raise ValueError( - "task/runspec is not a valid task object," f" type={type(runspec)}" - ) - - if isinstance(runspec, RunTemplate): - runspec = RunObject.from_template(runspec) - if isinstance(runspec, dict) or runspec is None: - runspec = RunObject.from_dict(runspec) - return runspec - - def _enrich_run( - self, - runspec, - handler, - project_name, - name, - params, - inputs, - returns, - hyperparams, - hyper_param_options, - verbose, - scrape_metrics, - out_path, - artifact_path, - workdir, - notifications: List[mlrun.model.Notification] = None, - ): - runspec.spec.handler = ( - handler or runspec.spec.handler or self.spec.default_handler or "" - ) - if runspec.spec.handler and self.kind not in ["handler", "dask"]: - runspec.spec.handler = runspec.spec.handler_name - - def_name = self.metadata.name - if runspec.spec.handler_name: - short_name = runspec.spec.handler_name - for separator in ["#", "::", "."]: - # drop paths, module or class name from short name - if separator in short_name: - short_name = short_name.split(separator)[-1] - def_name += "-" + short_name - - runspec.metadata.name = normalize_name( - name=name or runspec.metadata.name or def_name, - # if name or runspec.metadata.name are set 
then it means that is user defined name and we want to warn the - # user that the passed name needs to be set without underscore, if its not user defined but rather enriched - # from the handler(function) name then we replace the underscore without warning the user. - # most of the time handlers will have `_` in the handler name (python convention is to separate function - # words with `_`), therefore we don't want to be noisy when normalizing the run name - verbose=bool(name or runspec.metadata.name), - ) - verify_field_regex( - "run.metadata.name", runspec.metadata.name, mlrun.utils.regex.run_name - ) - runspec.metadata.project = ( - project_name - or runspec.metadata.project - or self.metadata.project - or config.default_project - ) - runspec.spec.parameters = params or runspec.spec.parameters - runspec.spec.inputs = inputs or runspec.spec.inputs - runspec.spec.returns = returns or runspec.spec.returns - runspec.spec.hyperparams = hyperparams or runspec.spec.hyperparams - runspec.spec.hyper_param_options = ( - hyper_param_options or runspec.spec.hyper_param_options - ) - runspec.spec.verbose = verbose or runspec.spec.verbose - if scrape_metrics is None: - if runspec.spec.scrape_metrics is None: - scrape_metrics = config.scrape_metrics - else: - scrape_metrics = runspec.spec.scrape_metrics - runspec.spec.scrape_metrics = scrape_metrics - runspec.spec.input_path = ( - workdir or runspec.spec.input_path or self.spec.workdir - ) - if self.spec.allow_empty_resources: - runspec.spec.allow_empty_resources = self.spec.allow_empty_resources - - spec = runspec.spec - if spec.secret_sources: - self._secrets = SecretsStore.from_list(spec.secret_sources) - - # update run metadata (uid, labels) and store in DB - meta = runspec.metadata - meta.uid = meta.uid or uuid.uuid4().hex - - runspec.spec.output_path = out_path or artifact_path or runspec.spec.output_path - - if not runspec.spec.output_path: - if runspec.metadata.project: - if ( - mlrun.pipeline_context.project - and 
runspec.metadata.project - == mlrun.pipeline_context.project.metadata.name - ): - runspec.spec.output_path = ( - mlrun.pipeline_context.project.spec.artifact_path - or mlrun.pipeline_context.workflow_artifact_path - ) - - if not runspec.spec.output_path and self._get_db(): - try: - # not passing or loading the DB before the enrichment on purpose, because we want to enrich the - # spec first as get_db() depends on it - project = self._get_db().get_project(runspec.metadata.project) - # this is mainly for tests, so we won't need to mock get_project for so many tests - # in normal use cases if no project is found we will get an error - if project: - runspec.spec.output_path = project.spec.artifact_path - except mlrun.errors.MLRunNotFoundError: - logger.warning( - f"project {project_name} is not saved in DB yet, " - f"enriching output path with default artifact path: {config.artifact_path}" - ) - - if not runspec.spec.output_path: - runspec.spec.output_path = config.artifact_path - - if runspec.spec.output_path: - runspec.spec.output_path = runspec.spec.output_path.replace( - "{{run.uid}}", meta.uid - ) - runspec.spec.output_path = mlrun.utils.helpers.fill_artifact_path_template( - runspec.spec.output_path, runspec.metadata.project - ) - - runspec.spec.notifications = notifications or runspec.spec.notifications or [] - return runspec - - def _submit_job(self, run: RunObject, schedule, db, watch): - if self._secrets: - run.spec.secret_sources = self._secrets.to_serial() - try: - resp = db.submit_job(run, schedule=schedule) - if schedule: - action = resp.pop("action", "created") - logger.info(f"task schedule {action}", **resp) - return - - except (requests.HTTPError, Exception) as err: - logger.error(f"got remote run err, {err_to_str(err)}") - - if isinstance(err, requests.HTTPError): - self._handle_submit_job_http_error(err) - - result = None - # if we got a schedule no reason to do post_run stuff (it purposed to update the run status with error, - # but there's no run 
in case of schedule) - if not schedule: - result = self._update_run_state(task=run, err=err_to_str(err)) - return self._wrap_run_result(result, run, schedule=schedule, err=err) - - if resp: - txt = get_in(resp, "status.status_text") - if txt: - logger.info(txt) - # watch is None only in scenario where we run from pipeline step, in this case we don't want to watch the run - # logs too frequently but rather just pull the state of the run from the DB and pull the logs every x seconds - # which ideally greater than the pull state interval, this reduces unnecessary load on the API server, as - # running a pipeline is mostly not an interactive process which means the logs pulling doesn't need to be pulled - # in real time - if ( - watch is None - and self.kfp - and config.httpdb.logs.pipelines.pull_state.mode == "enabled" - ): - state_interval = int( - config.httpdb.logs.pipelines.pull_state.pull_state_interval - ) - logs_interval = int( - config.httpdb.logs.pipelines.pull_state.pull_logs_interval - ) - - run.wait_for_completion( - show_logs=True, - sleep=state_interval, - logs_interval=logs_interval, - raise_on_failure=False, - ) - resp = self._get_db_run(run) - - elif watch or self.kfp: - run.logs(True, self._get_db()) - resp = self._get_db_run(run) - - return self._wrap_run_result(resp, run, schedule=schedule) - @staticmethod def _handle_submit_job_http_error(error: requests.HTTPError): # if we receive a 400 status code, this means the request was invalid and the run wasn't created in the DB. 
@@ -1060,47 +595,6 @@ def _update_run_state( return resp - def _save_or_push_notifications(self, runobj: RunObject, local: bool = False): - - if not runobj.spec.notifications: - logger.debug( - "No notifications to push for run", run_uid=runobj.metadata.uid - ) - return - - # TODO: add support for other notifications per run iteration - if runobj.metadata.iteration and runobj.metadata.iteration > 0: - logger.debug( - "Notifications per iteration are not supported, skipping", - run_uid=runobj.metadata.uid, - ) - return - - # If the run is remote, and we are in the SDK, we let the api deal with the notifications - # so there's nothing to do here. - # Otherwise, we continue on. - if is_running_as_api(): - - # import here to avoid circular imports and to avoid importing api requirements - from mlrun.api.crud import Notifications - - # If in the api server, we can assume that watch=False, so we save notification - # configs to the DB, for the run monitor to later pick up and push. - session = mlrun.api.db.sqldb.session.create_session() - Notifications().store_run_notifications( - session, - runobj.spec.notifications, - runobj.metadata.uid, - runobj.metadata.project, - ) - - elif local: - # If the run is local, we can assume that watch=True, therefore this code runs - # once the run is completed, and we can just push the notifications. - # TODO: add store_notifications API endpoint so we can store notifications pushed from the - # SDK for documentation purposes. - mlrun.utils.notifications.NotificationPusher([runobj]).push() - def _force_handler(self, handler): if not handler: raise RunError(f"handler must be provided for {self.kind} runtime") @@ -1263,104 +757,93 @@ def with_requirements( self, requirements: Union[str, List[str]], overwrite: bool = False, - verify_base_image: bool = True, + verify_base_image: bool = False, + prepare_image_for_deploy: bool = True, + requirements_file: str = "", ): """add package requirements from file or list to build spec. 
- :param requirements: python requirements file path or list of packages - :param overwrite: overwrite existing requirements - :param verify_base_image: verify that the base image is configured + :param requirements: a list of python packages + :param requirements_file: a local python requirements file path + :param overwrite: overwrite existing requirements + :param verify_base_image: verify that the base image is configured + (deprecated, use prepare_image_for_deploy) + :param prepare_image_for_deploy: prepare the image/base_image spec for deployment :return: function object """ - encoded_requirements = self._encode_requirements(requirements) - commands = self.spec.build.commands or [] if not overwrite else [] - new_command = f"python -m pip install {encoded_requirements}" - # make sure we dont append the same line twice - if new_command not in commands: - commands.append(new_command) - self.spec.build.commands = commands - if verify_base_image: - self.verify_base_image() + self.spec.build.with_requirements(requirements, requirements_file, overwrite) + + if verify_base_image or prepare_image_for_deploy: + # TODO: remove verify_base_image in 1.6.0 + if verify_base_image: + warnings.warn( + "verify_base_image is deprecated in 1.4.0 and will be removed in 1.6.0, " + "use prepare_image_for_deploy", + category=FutureWarning, + ) + self.prepare_image_for_deploy() + return self def with_commands( self, commands: List[str], overwrite: bool = False, - verify_base_image: bool = True, + verify_base_image: bool = False, + prepare_image_for_deploy: bool = True, ): """add commands to build spec. 
- :param commands: list of commands to run during build + :param commands: list of commands to run during build + :param overwrite: overwrite existing commands + :param verify_base_image: verify that the base image is configured + (deprecated, use prepare_image_for_deploy) + :param prepare_image_for_deploy: prepare the image/base_image spec for deployment :return: function object """ - if not isinstance(commands, list): - raise ValueError("commands must be a string list") - if not self.spec.build.commands or overwrite: - self.spec.build.commands = commands - else: - # add commands to existing build commands - for command in commands: - if command not in self.spec.build.commands: - self.spec.build.commands.append(command) - # using list(set(x)) won't retain order, - # solution inspired from https://stackoverflow.com/a/17016257/8116661 - self.spec.build.commands = list(dict.fromkeys(self.spec.build.commands)) - if verify_base_image: - self.verify_base_image() + self.spec.build.with_commands(commands, overwrite) + + if verify_base_image or prepare_image_for_deploy: + # TODO: remove verify_base_image in 1.6.0 + if verify_base_image: + warnings.warn( + "verify_base_image is deprecated in 1.4.0 and will be removed in 1.6.0, " + "use prepare_image_for_deploy", + category=FutureWarning, + ) + + self.prepare_image_for_deploy() return self def clean_build_params(self): - # when using `with_requirements` we also execute `verify_base_image` which adds the base image and cleans the - # spec.image, so we need to restore the image back + # when using `with_requirements` we also execute `prepare_image_for_deploy` which adds the base image + # and cleans the spec.image, so we need to restore the image back if self.spec.build.base_image and not self.spec.image: self.spec.image = self.spec.build.base_image self.spec.build = {} return self + # TODO: remove in 1.6.0 + @deprecated( + version="1.4.0", + reason="'verify_base_image' will be removed in 1.6.0, use 'prepare_image_for_deploy' 
instead", + category=FutureWarning, + ) def verify_base_image(self): - build = self.spec.build - require_build = build.commands or ( - build.source and not build.load_source_on_run + self.prepare_image_for_deploy() + + def prepare_image_for_deploy(self): + """ + if a function has a 'spec.image' it is considered to be deployed, + but because we allow the user to set 'spec.image' for usability purposes, + we need to check whether this is a built image or it requires to be built on top. + """ + launcher = mlrun.launcher.factory.LauncherFactory.create_launcher( + is_remote=self._is_remote ) - image = self.spec.image - # we allow users to not set an image, in that case we'll use the default - if ( - not image - and self.kind in mlrun.mlconf.function_defaults.image_by_kind.to_dict() - ): - image = mlrun.mlconf.function_defaults.image_by_kind.to_dict()[self.kind] - - if ( - self.kind not in mlrun.runtimes.RuntimeKinds.nuclio_runtimes() - # TODO: need a better way to decide whether a function requires a build - and require_build - and image - and not self.spec.build.base_image - # when submitting a run we are loading the function from the db, and using new_function for it, - # this results reaching here, but we are already after deploy of the image, meaning we don't need to prepare - # the base image for deployment - and self._is_remote_api() - ): - # when the function require build use the image as the base_image for the build - self.spec.build.base_image = image - self.spec.image = "" - - def _verify_run_params(self, parameters: typing.Dict[str, typing.Any]): - for param_name, param_value in parameters.items(): - - if isinstance(param_value, dict): - # if the parameter is a dict, we might have some nested parameters, - # in this case we need to verify them as well recursively - self._verify_run_params(param_value) - - # verify that integer parameters don't exceed a int64 - if isinstance(param_value, int) and abs(param_value) >= 2**63: - raise 
mlrun.errors.MLRunInvalidArgumentError( - f"parameter {param_name} value {param_value} exceeds int64" - ) + launcher.prepare_image_for_deploy(self) def export(self, target="", format=".yaml", secrets=None, strip=True): """save function spec to a local/remote path (default to./function.yaml) @@ -1391,35 +874,12 @@ def export(self, target="", format=".yaml", secrets=None, strip=True): return self def save(self, tag="", versioned=False, refresh=False) -> str: - db = self._get_db() - if not db: - logger.error("database connection is not configured") - return "" - - if refresh and self._is_remote_api(): - try: - meta = self.metadata - db_func = db.get_function(meta.name, meta.project, meta.tag) - if db_func and "status" in db_func: - self.status = db_func["status"] - if ( - self.status.state - and self.status.state == "ready" - and not hasattr(self.status, "nuclio_name") - ): - self.spec.image = get_in(db_func, "spec.image", self.spec.image) - except mlrun.errors.MLRunNotFoundError: - pass - - tag = tag or self.metadata.tag - - obj = self.to_dict() - logger.debug(f"saving function: {self.metadata.name}, tag: {tag}") - hash_key = db.store_function( - obj, self.metadata.name, self.metadata.project, tag, versioned + launcher = mlrun.launcher.factory.LauncherFactory.create_launcher( + is_remote=self._is_remote + ) + return launcher.save_function( + self, tag=tag, versioned=versioned, refresh=refresh ) - hash_key = hash_key if versioned else None - return "db://" + self._function_uri(hash_key=hash_key, tag=tag) def to_dict(self, fields=None, exclude=None, strip=False): struct = super().to_dict(fields, exclude=exclude) @@ -1448,76 +908,11 @@ def doc(self): line += f", default={p['default']}" print(" " + line) - def _encode_requirements(self, requirements_to_encode): - - # if a string, read the file then encode - if isinstance(requirements_to_encode, str): - with open(requirements_to_encode, "r") as fp: - requirements_to_encode = fp.read().splitlines() - - requirements = [] - 
for requirement in requirements_to_encode: - requirement = requirement.strip() - - # ignore empty lines - # ignore comments - if not requirement or requirement.startswith("#"): - continue - - # ignore inline comments as well - inline_comment = requirement.split(" #") - if len(inline_comment) > 1: - requirement = inline_comment[0].strip() - - # -r / --requirement are flags and should not be escaped - # we allow such flags (could be passed within the requirements.txt file) and do not - # try to open the file and include its content since it might be a remote file - # given on the base image. - for req_flag in ["-r", "--requirement"]: - if requirement.startswith(req_flag): - requirement = requirement[len(req_flag) :].strip() - requirements.append(req_flag) - break - - # wrap in single quote to ensure that the requirement is treated as a single string - # quote the requirement to avoid issues with special characters, double quotes, etc. - requirements.append(shlex.quote(requirement)) - - return " ".join(requirements) - - def _validate_output_path(self, run): - if is_local(run.spec.output_path): - message = "" - if not os.path.isabs(run.spec.output_path): - message = ( - "artifact/output path is not defined or is local and relative," - " artifacts will not be visible in the UI" - ) - if mlrun.runtimes.RuntimeKinds.requires_absolute_artifacts_path( - self.kind - ): - raise mlrun.errors.MLRunPreconditionFailedError( - "artifact path (`artifact_path`) must be absolute for remote tasks" - ) - elif hasattr(self.spec, "volume_mounts") and not self.spec.volume_mounts: - message = ( - "artifact output path is local while no volume mount is specified. " - "artifacts would not be visible via UI." 
- ) - if message: - logger.warning(message, output_path=run.spec.output_path) - - -def is_local(url): - if not url: - return True - return "://" not in url - class BaseRuntimeHandler(ABC): # setting here to allow tests to override kind = "base" - class_modes: typing.Dict[RuntimeClassMode, str] = {} + class_modes: Dict[RuntimeClassMode, str] = {} wait_for_deletion_interval = 10 @staticmethod @@ -1531,12 +926,12 @@ def _get_object_label_selector(object_id: str) -> str: def _should_collect_logs(self) -> bool: """ There are some runtimes which we don't collect logs for using the log collector - :return: whether should collect log for it + :return: whether it should collect log for it """ return True def _get_possible_mlrun_class_label_values( - self, class_mode: typing.Union[RuntimeClassMode, str] = None + self, class_mode: Union[RuntimeClassMode, str] = None ) -> List[str]: """ Should return the possible values of the mlrun/class label for runtime resources that are of this runtime @@ -1550,21 +945,20 @@ def _get_possible_mlrun_class_label_values( def list_resources( self, project: str, - object_id: typing.Optional[str] = None, + object_id: Optional[str] = None, label_selector: str = None, - group_by: Optional[mlrun.api.schemas.ListRuntimeResourcesGroupByField] = None, + group_by: Optional[ + mlrun.common.schemas.ListRuntimeResourcesGroupByField + ] = None, ) -> Union[ - mlrun.api.schemas.RuntimeResources, - mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, - mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, + mlrun.common.schemas.RuntimeResources, + mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput, + mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, ]: # We currently don't support removing runtime resources in non k8s env - if not mlrun.k8s_utils.get_k8s_helper( - silent=True - ).is_running_inside_kubernetes_cluster(): + if not get_k8s().is_running_inside_kubernetes_cluster(): return {} - k8s_helper = get_k8s_helper() - namespace = 
k8s_helper.resolve_namespace() + namespace = get_k8s().resolve_namespace() label_selector = self.resolve_label_selector(project, object_id, label_selector) pods = self._list_pods(namespace, label_selector) pod_resources = self._build_pod_resources(pods) @@ -1580,8 +974,10 @@ def list_resources( def build_output_from_runtime_resources( self, - runtime_resources_list: List[mlrun.api.schemas.RuntimeResources], - group_by: Optional[mlrun.api.schemas.ListRuntimeResourcesGroupByField] = None, + runtime_resources_list: List[mlrun.common.schemas.RuntimeResources], + group_by: Optional[ + mlrun.common.schemas.ListRuntimeResourcesGroupByField + ] = None, ): pod_resources = [] crd_resources = [] @@ -1607,12 +1003,9 @@ def delete_resources( if grace_period is None: grace_period = config.runtime_resources_deletion_grace_period # We currently don't support removing runtime resources in non k8s env - if not mlrun.k8s_utils.get_k8s_helper( - silent=True - ).is_running_inside_kubernetes_cluster(): + if not get_k8s().is_running_inside_kubernetes_cluster(): return - k8s_helper = get_k8s_helper() - namespace = k8s_helper.resolve_namespace() + namespace = get_k8s().resolve_namespace() label_selector = self.resolve_label_selector("*", label_selector=label_selector) crd_group, crd_version, crd_plural = self._get_crd_info() if crd_group and crd_version and crd_plural: @@ -1660,8 +1053,7 @@ def delete_runtime_object_resources( self.delete_resources(db, db_session, label_selector, force, grace_period) def monitor_runs(self, db: DBInterface, db_session: Session): - k8s_helper = get_k8s_helper() - namespace = k8s_helper.resolve_namespace() + namespace = get_k8s().resolve_namespace() label_selector = self._get_default_label_selector() crd_group, crd_version, crd_plural = self._get_crd_info() runtime_resource_is_crd = False @@ -1810,8 +1202,8 @@ def _ensure_run_not_stuck_on_non_terminal_state( def _add_object_label_selector_if_needed( self, - object_id: typing.Optional[str] = None, - 
label_selector: typing.Optional[str] = None, + object_id: Optional[str] = None, + label_selector: Optional[str] = None, ): if object_id: object_label_selector = self._get_object_label_selector(object_id) @@ -1833,17 +1225,19 @@ def _get_main_runtime_resource_label_selector() -> str: def _enrich_list_resources_response( self, response: Union[ - mlrun.api.schemas.RuntimeResources, - mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, - mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, + mlrun.common.schemas.RuntimeResources, + mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput, + mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, ], namespace: str, label_selector: str = None, - group_by: Optional[mlrun.api.schemas.ListRuntimeResourcesGroupByField] = None, + group_by: Optional[ + mlrun.common.schemas.ListRuntimeResourcesGroupByField + ] = None, ) -> Union[ - mlrun.api.schemas.RuntimeResources, - mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, - mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, + mlrun.common.schemas.RuntimeResources, + mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput, + mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, ]: """ Override this to list resources other then pods or CRDs (which are handled by the base class) @@ -1853,12 +1247,14 @@ def _enrich_list_resources_response( def _build_output_from_runtime_resources( self, response: Union[ - mlrun.api.schemas.RuntimeResources, - mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, - mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, + mlrun.common.schemas.RuntimeResources, + mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput, + mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, ], - runtime_resources_list: List[mlrun.api.schemas.RuntimeResources], - group_by: Optional[mlrun.api.schemas.ListRuntimeResourcesGroupByField] = None, + runtime_resources_list: List[mlrun.common.schemas.RuntimeResources], + group_by: Optional[ + 
mlrun.common.schemas.ListRuntimeResourcesGroupByField + ] = None, ): """ Override this to add runtime resources other than pods or CRDs (which are handled by the base class) to the @@ -1940,7 +1336,7 @@ def _resolve_pod_status_info( return in_terminal_state, last_container_completion_time, run_state def _get_default_label_selector( - self, class_mode: typing.Union[RuntimeClassMode, str] = None + self, class_mode: Union[RuntimeClassMode, str] = None ) -> str: """ Override this to add a default label selector @@ -1989,20 +1385,18 @@ def _expect_pods_without_uid() -> bool: return False def _list_pods(self, namespace: str, label_selector: str = None) -> List: - k8s_helper = get_k8s_helper() - pods = k8s_helper.list_pods(namespace, selector=label_selector) + pods = get_k8s().list_pods(namespace, selector=label_selector) # when we work with custom objects (list_namespaced_custom_object) it's always a dict, to be able to generalize # code working on runtime resource (either a custom object or a pod) we're transforming to dicts pods = [pod.to_dict() for pod in pods] return pods def _list_crd_objects(self, namespace: str, label_selector: str = None) -> List: - k8s_helper = get_k8s_helper() crd_group, crd_version, crd_plural = self._get_crd_info() crd_objects = [] if crd_group and crd_version and crd_plural: try: - crd_objects = k8s_helper.crdapi.list_namespaced_custom_object( + crd_objects = get_k8s().crdapi.list_namespaced_custom_object( crd_group, crd_version, namespace, @@ -2020,9 +1414,9 @@ def _list_crd_objects(self, namespace: str, label_selector: str = None) -> List: def resolve_label_selector( self, project: str, - object_id: typing.Optional[str] = None, - label_selector: typing.Optional[str] = None, - class_mode: typing.Union[RuntimeClassMode, str] = None, + object_id: Optional[str] = None, + label_selector: Optional[str] = None, + class_mode: Union[RuntimeClassMode, str] = None, with_main_runtime_resource_label_selector: bool = False, ) -> str: 
default_label_selector = self._get_default_label_selector(class_mode=class_mode) @@ -2053,7 +1447,7 @@ def resolve_label_selector( @staticmethod def resolve_object_id( run: dict, - ) -> typing.Optional[str]: + ) -> Optional[str]: """ Get the object id from the run object Override this if the object id is not the run uid @@ -2068,11 +1462,10 @@ def _wait_for_pods_deletion( deleted_pods: List[Dict], label_selector: str = None, ): - k8s_helper = get_k8s_helper() deleted_pod_names = [pod_dict["metadata"]["name"] for pod_dict in deleted_pods] def _verify_pods_removed(): - pods = k8s_helper.v1api.list_namespaced_pod( + pods = get_k8s().v1api.list_namespaced_pod( namespace, label_selector=label_selector ) existing_pod_names = [pod.metadata.name for pod in pods.items] @@ -2125,10 +1518,10 @@ def _verify_crds_underlying_pods_removed(): "name" ] still_in_deletion_crds_to_pod_names = {} - jobs_runtime_resources: mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput = self.list_resources( + jobs_runtime_resources: mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput = self.list_resources( "*", label_selector=label_selector, - group_by=mlrun.api.schemas.ListRuntimeResourcesGroupByField.job, + group_by=mlrun.common.schemas.ListRuntimeResourcesGroupByField.job, ) for project, project_jobs in jobs_runtime_resources.items(): if project not in project_uid_crd_map: @@ -2176,8 +1569,7 @@ def _delete_pod_resources( ) -> List[Dict]: if grace_period is None: grace_period = config.runtime_resources_deletion_grace_period - k8s_helper = get_k8s_helper() - pods = k8s_helper.v1api.list_namespaced_pod( + pods = get_k8s().v1api.list_namespaced_pod( namespace, label_selector=label_selector ) deleted_pods = [] @@ -2218,7 +1610,7 @@ def _delete_pod_resources( pod_name=pod.metadata.name, ) - get_k8s_helper().delete_pod(pod.metadata.name, namespace) + get_k8s().delete_pod(pod.metadata.name, namespace) deleted_pods.append(pod_dict) except Exception as exc: logger.warning( @@ -2239,11 +1631,10 @@ 
def _delete_crd_resources( ) -> List[Dict]: if grace_period is None: grace_period = config.runtime_resources_deletion_grace_period - k8s_helper = get_k8s_helper() crd_group, crd_version, crd_plural = self._get_crd_info() deleted_crds = [] try: - crd_objects = k8s_helper.crdapi.list_namespaced_custom_object( + crd_objects = get_k8s().crdapi.list_namespaced_custom_object( crd_group, crd_version, namespace, @@ -2295,7 +1686,7 @@ def _delete_crd_resources( crd_object_name=crd_object["metadata"]["name"], ) - get_k8s_helper().delete_crd( + get_k8s().delete_crd( crd_object["metadata"]["name"], crd_group, crd_version, @@ -2471,13 +1862,15 @@ def _monitor_runtime_resource( def _build_list_resources_response( self, - pod_resources: List[mlrun.api.schemas.RuntimeResource] = None, - crd_resources: List[mlrun.api.schemas.RuntimeResource] = None, - group_by: Optional[mlrun.api.schemas.ListRuntimeResourcesGroupByField] = None, + pod_resources: List[mlrun.common.schemas.RuntimeResource] = None, + crd_resources: List[mlrun.common.schemas.RuntimeResource] = None, + group_by: Optional[ + mlrun.common.schemas.ListRuntimeResourcesGroupByField + ] = None, ) -> Union[ - mlrun.api.schemas.RuntimeResources, - mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, - mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, + mlrun.common.schemas.RuntimeResources, + mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput, + mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, ]: if crd_resources is None: crd_resources = [] @@ -2485,15 +1878,18 @@ def _build_list_resources_response( pod_resources = [] if group_by is None: - return mlrun.api.schemas.RuntimeResources( + return mlrun.common.schemas.RuntimeResources( crd_resources=crd_resources, pod_resources=pod_resources ) else: - if group_by == mlrun.api.schemas.ListRuntimeResourcesGroupByField.job: + if group_by == mlrun.common.schemas.ListRuntimeResourcesGroupByField.job: return self._build_grouped_by_job_list_resources_response( 
pod_resources, crd_resources ) - elif group_by == mlrun.api.schemas.ListRuntimeResourcesGroupByField.project: + elif ( + group_by + == mlrun.common.schemas.ListRuntimeResourcesGroupByField.project + ): return self._build_grouped_by_project_list_resources_response( pod_resources, crd_resources ) @@ -2504,9 +1900,9 @@ def _build_list_resources_response( def _build_grouped_by_project_list_resources_response( self, - pod_resources: List[mlrun.api.schemas.RuntimeResource] = None, - crd_resources: List[mlrun.api.schemas.RuntimeResource] = None, - ) -> mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput: + pod_resources: List[mlrun.common.schemas.RuntimeResource] = None, + crd_resources: List[mlrun.common.schemas.RuntimeResource] = None, + ) -> mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput: resources = {} for pod_resource in pod_resources: self._add_resource_to_grouped_by_project_resources_response( @@ -2520,9 +1916,9 @@ def _build_grouped_by_project_list_resources_response( def _build_grouped_by_job_list_resources_response( self, - pod_resources: List[mlrun.api.schemas.RuntimeResource] = None, - crd_resources: List[mlrun.api.schemas.RuntimeResource] = None, - ) -> mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput: + pod_resources: List[mlrun.common.schemas.RuntimeResource] = None, + crd_resources: List[mlrun.common.schemas.RuntimeResource] = None, + ) -> mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput: resources = {} for pod_resource in pod_resources: self._add_resource_to_grouped_by_job_resources_response( @@ -2536,9 +1932,9 @@ def _build_grouped_by_job_list_resources_response( def _add_resource_to_grouped_by_project_resources_response( self, - resources: mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, + resources: mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput, resource_field_name: str, - resource: mlrun.api.schemas.RuntimeResource, + resource: mlrun.common.schemas.RuntimeResource, ): if "mlrun/class" in resource.labels: 
project = resource.labels.get("mlrun/project", "") @@ -2550,9 +1946,9 @@ def _add_resource_to_grouped_by_project_resources_response( def _add_resource_to_grouped_by_job_resources_response( self, - resources: mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, + resources: mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput, resource_field_name: str, - resource: mlrun.api.schemas.RuntimeResource, + resource: mlrun.common.schemas.RuntimeResource, ): if "mlrun/uid" in resource.labels: project = resource.labels.get("mlrun/project", config.default_project) @@ -2565,16 +1961,18 @@ def _add_resource_to_grouped_by_job_resources_response( def _add_resource_to_grouped_by_field_resources_response( first_field_value: str, second_field_value: str, - resources: mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, + resources: mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput, resource_field_name: str, - resource: mlrun.api.schemas.RuntimeResource, + resource: mlrun.common.schemas.RuntimeResource, ): if first_field_value not in resources: resources[first_field_value] = {} if second_field_value not in resources[first_field_value]: resources[first_field_value][ second_field_value - ] = mlrun.api.schemas.RuntimeResources(pod_resources=[], crd_resources=[]) + ] = mlrun.common.schemas.RuntimeResources( + pod_resources=[], crd_resources=[] + ) if not getattr( resources[first_field_value][second_field_value], resource_field_name ): @@ -2708,11 +2106,11 @@ def _resolve_runtime_resource_run(runtime_resource: Dict) -> Tuple[str, str, str return project, uid, name @staticmethod - def _build_pod_resources(pods) -> List[mlrun.api.schemas.RuntimeResource]: + def _build_pod_resources(pods) -> List[mlrun.common.schemas.RuntimeResource]: pod_resources = [] for pod in pods: pod_resources.append( - mlrun.api.schemas.RuntimeResource( + mlrun.common.schemas.RuntimeResource( name=pod["metadata"]["name"], labels=pod["metadata"]["labels"], status=pod["status"], @@ -2721,11 +2119,13 @@ def 
_build_pod_resources(pods) -> List[mlrun.api.schemas.RuntimeResource]: return pod_resources @staticmethod - def _build_crd_resources(custom_objects) -> List[mlrun.api.schemas.RuntimeResource]: + def _build_crd_resources( + custom_objects, + ) -> List[mlrun.common.schemas.RuntimeResource]: crd_resources = [] for custom_object in custom_objects: crd_resources.append( - mlrun.api.schemas.RuntimeResource( + mlrun.common.schemas.RuntimeResource( name=custom_object["metadata"]["name"], labels=custom_object["metadata"]["labels"], status=custom_object.get("status", {}), diff --git a/mlrun/runtimes/daskjob.py b/mlrun/runtimes/daskjob.py index b04b32897cf8..58f0d3080b19 100644 --- a/mlrun/runtimes/daskjob.py +++ b/mlrun/runtimes/daskjob.py @@ -23,8 +23,9 @@ from kubernetes.client.rest import ApiException from sqlalchemy.orm import Session -import mlrun.api.schemas +import mlrun.common.schemas import mlrun.errors +import mlrun.k8s_utils import mlrun.utils import mlrun.utils.regex from mlrun.api.db.base import DBInterface @@ -33,7 +34,6 @@ from ..config import config from ..execution import MLClientCtx -from ..k8s_utils import get_k8s_helper from ..model import RunObject from ..render import ipython_display from ..utils import logger, normalize_name, update_in @@ -41,7 +41,7 @@ from .kubejob import KubejobRuntime from .local import exec_from_params, load_module from .pod import KubeResourceSpec, kube_resource_spec_to_pod_spec -from .utils import RunError, get_func_selector, get_resource_labels, log_std +from .utils import RunError, get_func_selector, get_k8s, get_resource_labels, log_std def get_dask_resource(): @@ -106,6 +106,7 @@ def __init__( tolerations=None, preemption_mode=None, security_context=None, + clone_target_dir=None, ): super().__init__( @@ -135,6 +136,7 @@ def __init__( tolerations=tolerations, preemption_mode=preemption_mode, security_context=security_context, + clone_target_dir=clone_target_dir, ) self.args = args @@ -201,9 +203,7 @@ class 
DaskCluster(KubejobRuntime): def __init__(self, spec=None, metadata=None): super().__init__(spec, metadata) self._cluster = None - self.use_remote = not get_k8s_helper( - silent=True - ).is_running_inside_kubernetes_cluster() + self.use_remote = not mlrun.k8s_utils.is_running_inside_kubernetes_cluster() self.spec.build.base_image = self.spec.build.base_image or "daskdev/dask:latest" @property @@ -271,11 +271,11 @@ def _start(self, watch=True): ) if ( background_task.status.state - in mlrun.api.schemas.BackgroundTaskState.terminal_states() + in mlrun.common.schemas.BackgroundTaskState.terminal_states() ): if ( background_task.status.state - == mlrun.api.schemas.BackgroundTaskState.failed + == mlrun.common.schemas.BackgroundTaskState.failed ): raise mlrun.errors.MLRunRuntimeError( "Failed bringing up dask cluster" @@ -352,10 +352,6 @@ def client(self): f"remote scheduler at {addr} not ready, will try to restart {err_to_str(exc)}" ) - # todo: figure out if test is needed - # if self._is_remote_api(): - # raise Exception('no access to Kubernetes API') - status = self.get_status() if status != "running": self._start() @@ -671,7 +667,9 @@ def get_obj_status(selector=None, namespace=None): if selector is None: selector = [] - k8s = get_k8s_helper() + import mlrun.api.utils.singletons.k8s + + k8s = mlrun.api.utils.singletons.k8s.get_k8s_helper() namespace = namespace or config.namespace selector = ",".join(["dask.org/component=scheduler"] + selector) pods = k8s.list_pods(namespace, selector=selector) @@ -730,17 +728,19 @@ def resolve_object_id( def _enrich_list_resources_response( self, response: Union[ - mlrun.api.schemas.RuntimeResources, - mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, - mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, + mlrun.common.schemas.RuntimeResources, + mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput, + mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, ], namespace: str, label_selector: str = None, - group_by: 
Optional[mlrun.api.schemas.ListRuntimeResourcesGroupByField] = None, + group_by: Optional[ + mlrun.common.schemas.ListRuntimeResourcesGroupByField + ] = None, ) -> Union[ - mlrun.api.schemas.RuntimeResources, - mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, - mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, + mlrun.common.schemas.RuntimeResources, + mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput, + mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, ]: """ Handling listing service resources @@ -748,14 +748,13 @@ def _enrich_list_resources_response( enrich_needed = self._validate_if_enrich_is_needed_by_group_by(group_by) if not enrich_needed: return response - k8s_helper = get_k8s_helper() - services = k8s_helper.v1api.list_namespaced_service( + services = get_k8s().v1api.list_namespaced_service( namespace, label_selector=label_selector ) service_resources = [] for service in services.items: service_resources.append( - mlrun.api.schemas.RuntimeResource( + mlrun.common.schemas.RuntimeResource( name=service.metadata.name, labels=service.metadata.labels ) ) @@ -766,12 +765,14 @@ def _enrich_list_resources_response( def _build_output_from_runtime_resources( self, response: Union[ - mlrun.api.schemas.RuntimeResources, - mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, - mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, + mlrun.common.schemas.RuntimeResources, + mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput, + mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, ], - runtime_resources_list: List[mlrun.api.schemas.RuntimeResources], - group_by: Optional[mlrun.api.schemas.ListRuntimeResourcesGroupByField] = None, + runtime_resources_list: List[mlrun.common.schemas.RuntimeResources], + group_by: Optional[ + mlrun.common.schemas.ListRuntimeResourcesGroupByField + ] = None, ): enrich_needed = self._validate_if_enrich_is_needed_by_group_by(group_by) if not enrich_needed: @@ -786,13 +787,15 @@ def 
_build_output_from_runtime_resources( def _validate_if_enrich_is_needed_by_group_by( self, - group_by: Optional[mlrun.api.schemas.ListRuntimeResourcesGroupByField] = None, + group_by: Optional[ + mlrun.common.schemas.ListRuntimeResourcesGroupByField + ] = None, ) -> bool: # Dask runtime resources are per function (and not per job) therefore, when grouping by job we're simply # omitting the dask runtime resources - if group_by == mlrun.api.schemas.ListRuntimeResourcesGroupByField.job: + if group_by == mlrun.common.schemas.ListRuntimeResourcesGroupByField.job: return False - elif group_by == mlrun.api.schemas.ListRuntimeResourcesGroupByField.project: + elif group_by == mlrun.common.schemas.ListRuntimeResourcesGroupByField.project: return True elif group_by is not None: raise NotImplementedError( @@ -803,14 +806,16 @@ def _validate_if_enrich_is_needed_by_group_by( def _enrich_service_resources_in_response( self, response: Union[ - mlrun.api.schemas.RuntimeResources, - mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, - mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, + mlrun.common.schemas.RuntimeResources, + mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput, + mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, ], - service_resources: List[mlrun.api.schemas.RuntimeResource], - group_by: Optional[mlrun.api.schemas.ListRuntimeResourcesGroupByField] = None, + service_resources: List[mlrun.common.schemas.RuntimeResource], + group_by: Optional[ + mlrun.common.schemas.ListRuntimeResourcesGroupByField + ] = None, ): - if group_by == mlrun.api.schemas.ListRuntimeResourcesGroupByField.project: + if group_by == mlrun.common.schemas.ListRuntimeResourcesGroupByField.project: for service_resource in service_resources: self._add_resource_to_grouped_by_project_resources_response( response, "service_resources", service_resource @@ -845,14 +850,13 @@ def _delete_extra_resources( if dask_component == "scheduler" and cluster_name: 
service_names.append(cluster_name) - k8s_helper = get_k8s_helper() - services = k8s_helper.v1api.list_namespaced_service( + services = get_k8s().v1api.list_namespaced_service( namespace, label_selector=label_selector ) for service in services.items: try: if force or service.metadata.name in service_names: - k8s_helper.v1api.delete_namespaced_service( + get_k8s().v1api.delete_namespaced_service( service.metadata.name, namespace ) logger.info(f"Deleted service: {service.metadata.name}") diff --git a/mlrun/runtimes/funcdoc.py b/mlrun/runtimes/funcdoc.py index e53213a70059..a98602bcbecb 100644 --- a/mlrun/runtimes/funcdoc.py +++ b/mlrun/runtimes/funcdoc.py @@ -15,6 +15,7 @@ import ast import inspect import re +import sys from mlrun.model import FunctionEntrypoint @@ -49,13 +50,23 @@ def param_dict(name="", type="", doc="", default=""): } -def func_dict(name, doc, params, returns, lineno): +def func_dict( + name, + doc, + params, + returns, + lineno, + has_varargs: bool = False, + has_kwargs: bool = False, +): return { "name": name, "doc": doc, "params": params, "return": returns, "lineno": lineno, + "has_varargs": has_varargs, + "has_kwargs": has_kwargs, } @@ -165,6 +176,9 @@ def ast_func_info(func: ast.FunctionDef): doc = ast.get_docstring(func) or "" rtype = getattr(func.returns, "id", "") params = [ast_param_dict(p) for p in func.args.args] + # adds info about *args and **kwargs to the function doc + has_varargs = func.args.vararg is not None + has_kwargs = func.args.kwarg is not None defaults = func.args.defaults if defaults: for param, default in zip(params[-len(defaults) :], defaults): @@ -176,6 +190,8 @@ def ast_func_info(func: ast.FunctionDef): params=params, returns=param_dict(type=rtype), lineno=func.lineno, + has_varargs=has_varargs, + has_kwargs=has_kwargs, ) if not doc.strip(): @@ -195,16 +211,33 @@ def ast_param_dict(param: ast.arg) -> dict: def ann_type(ann): if hasattr(ann, "slice"): - name = ann.value.id + if isinstance(ann.value, ast.Attribute): + # 
value is an attribute, e.g. b of a.b - get the full path + name = get_attr_path(ann.value) + else: + name = ann.value.id inner = ", ".join(ann_type(e) for e in iter_elems(ann.slice)) return f"{name}[{inner}]" if isinstance(ann, ast.Attribute): + if isinstance(ann.value, ast.Attribute): + # value is an attribute, e.g. b of a.b - get the full path + return get_attr_path(ann) + return ann.attr return getattr(ann, "id", "") +def get_attr_path(ann: ast.Attribute): + if isinstance(ann.value, ast.Attribute): + # value is an attribute, e.g. b of a.b - get the full path + return f"{get_attr_path(ann.value)}.{ann.attr}" + + # value can be a subscript or name - get its annotation type and append the attribute + return f"{ann_type(ann.value)}.{ann.attr}" + + def iter_elems(ann): """ Gets the elements of an ast.Subscript.slice, e.g. Union[int, str] -> [int, str] @@ -219,10 +252,13 @@ def iter_elems(ann): return [ann.value] # From python 3.9, slice is an expr and we should evaluate it recursively. Left this for backward compatibility. 
- elif hasattr(ann.slice, "elts"): - return ann.slice.elts - elif hasattr(ann.slice, "value"): - return [ann.slice.value] + # TODO: Remove this in 1.5.0 when we drop support for python 3.7 + if sys.version_info < (3, 9): + if hasattr(ann.slice, "elts"): + return ann.slice.elts + elif hasattr(ann.slice, "value"): + return [ann.slice.value] + return [ann] diff --git a/mlrun/runtimes/function.py b/mlrun/runtimes/function.py index 03fbaff5a480..40f7a2e6b990 100644 --- a/mlrun/runtimes/function.py +++ b/mlrun/runtimes/function.py @@ -16,10 +16,8 @@ import json import typing import warnings -from base64 import b64encode from datetime import datetime from time import sleep -from urllib.parse import urlparse import nuclio import nuclio.utils @@ -31,15 +29,13 @@ from nuclio.triggers import V3IOStreamTrigger import mlrun.errors +import mlrun.k8s_utils import mlrun.utils -from mlrun.datastore import parse_s3_bucket_and_key +from mlrun.common.schemas import AuthInfo from mlrun.db import RunDBError -from ..api.schemas import AuthInfo from ..config import config as mlconf -from ..config import is_running_as_api from ..errors import err_to_str -from ..k8s_utils import get_k8s_helper from ..kfpops import deploy_op from ..lists import RunList from ..model import RunObject @@ -50,9 +46,8 @@ split_path, v3io_cred, ) -from ..utils import as_number, enrich_image_url, get_in, logger, update_in +from ..utils import get_in, logger, update_in from .base import FunctionStatus, RunError -from .constants import NuclioIngressAddTemplatedIngressModes from .pod import KubeResource, KubeResourceSpec from .utils import get_item_name, log_std @@ -183,6 +178,7 @@ def __init__( security_context=None, service_type=None, add_templated_ingress_host_mode=None, + clone_target_dir=None, ): super().__init__( @@ -212,6 +208,7 @@ def __init__( tolerations=tolerations, preemption_mode=preemption_mode, security_context=security_context, + clone_target_dir=clone_target_dir, ) self.base_spec = base_spec or {} @@ 
-566,66 +563,39 @@ def deploy( if tag: self.metadata.tag = tag - save_record = False - if not dashboard: - # Attempt auto-mounting, before sending to remote build - self.try_auto_mount_based_on_config() - self._fill_credentials() - db = self._get_db() - logger.info("Starting remote function deploy") - data = db.remote_builder(self, False, builder_env=builder_env) - self.status = data["data"].get("status") - self._update_credentials_from_remote_build(data["data"]) - - # when a function is deployed, we wait for it to be ready by default - # this also means that the function object will be updated with the function status - self._wait_for_function_deployment(db, verbose=verbose) - - # NOTE: on older mlrun versions & nuclio versions, function are exposed via NodePort - # now, functions can be not exposed (using service type ClusterIP) and hence - # for BC we first try to populate the external invocation url, and then - # if not exists, take the internal invocation url - if self.status.external_invocation_urls: - self.spec.command = f"http://{self.status.external_invocation_urls[0]}" - save_record = True - elif self.status.internal_invocation_urls: - self.spec.command = f"http://{self.status.internal_invocation_urls[0]}" - save_record = True - elif self.status.address: - self.spec.command = f"http://{self.status.address}" - save_record = True - - else: - + if dashboard: warnings.warn( - "'dashboard' is deprecated in 1.3.0, and will be removed in 1.5.0, " - "Keep 'dashboard' value empty to allow auto-detection by MLRun API.", - # TODO: Remove in 1.5.0 - FutureWarning, + "'dashboard' parameter is no longer supported on client side, " + "it is being configured through the MLRun API.", ) - self.save(versioned=False) - self._ensure_run_db() - internal_invocation_urls, external_invocation_urls = deploy_nuclio_function( - self, - dashboard=dashboard, - watch=True, - auth_info=auth_info, - ) - self.status.internal_invocation_urls = internal_invocation_urls - 
self.status.external_invocation_urls = external_invocation_urls - - # save the (first) function external invocation url - # this is made for backwards compatability because the user, at this point, may - # work remotely and need the external invocation url on the spec.command - # TODO: when using `ClusterIP`, this block might not fulfilled - # as long as function doesnt have ingresses - if self.status.external_invocation_urls: - address = self.status.external_invocation_urls[0] - self.spec.command = f"http://{address}" - self.status.state = "ready" - self.status.address = address - save_record = True + save_record = False + # Attempt auto-mounting, before sending to remote build + self.try_auto_mount_based_on_config() + self._fill_credentials() + db = self._get_db() + logger.info("Starting remote function deploy") + data = db.remote_builder(self, False, builder_env=builder_env) + self.status = data["data"].get("status") + self._update_credentials_from_remote_build(data["data"]) + + # when a function is deployed, we wait for it to be ready by default + # this also means that the function object will be updated with the function status + self._wait_for_function_deployment(db, verbose=verbose) + + # NOTE: on older mlrun versions & nuclio versions, function are exposed via NodePort + # now, functions can be not exposed (using service type ClusterIP) and hence + # for BC we first try to populate the external invocation url, and then + # if not exists, take the internal invocation url + if self.status.external_invocation_urls: + self.spec.command = f"http://{self.status.external_invocation_urls[0]}" + save_record = True + elif self.status.internal_invocation_urls: + self.spec.command = f"http://{self.status.internal_invocation_urls[0]}" + save_record = True + elif self.status.address: + self.spec.command = f"http://{self.status.address}" + save_record = True logger.info( "successfully deployed function", @@ -691,7 +661,7 @@ def with_preemption_mode(self, mode): The 
default preemption mode is configurable in mlrun.mlconf.function_defaults.preemption_mode, by default it's set to **prevent** - :param mode: allow | constrain | prevent | none defined in :py:class:`~mlrun.api.schemas.PreemptionModes` + :param mode: allow | constrain | prevent | none defined in :py:class:`~mlrun.common.schemas.PreemptionModes` """ super().with_preemption_mode(mode=mode) @@ -1080,9 +1050,7 @@ def _resolve_invocation_url(self, path, force_external_address): if ( not force_external_address and self.status.internal_invocation_urls - and get_k8s_helper( - silent=True, log=False - ).is_running_inside_kubernetes_cluster() + and mlrun.k8s_utils.is_running_inside_kubernetes_cluster() ): return f"http://{self.status.internal_invocation_urls[0]}{path}" @@ -1201,430 +1169,6 @@ def get_fullname(name, project, tag): return name -def deploy_nuclio_function( - function: RemoteRuntime, - dashboard="", - watch=False, - auth_info: AuthInfo = None, - client_version: str = None, - builder_env: dict = None, - client_python_version: str = None, -): - """Deploys a nuclio function. - - :param function: nuclio function object - :param dashboard: DEPRECATED. Keep empty to allow auto-detection by MLRun API. 
- :param watch: wait for function to be ready - :param auth_info: service AuthInfo - :param client_version: mlrun client version - :param builder_env: mlrun builder environment (for config/credentials) - :param client_python_version: mlrun client python version - """ - dashboard = dashboard or mlconf.nuclio_dashboard_url - function_name, project_name, function_config = compile_function_config( - function, - client_version=client_version, - client_python_version=client_python_version, - builder_env=builder_env or {}, - auth_info=auth_info, - ) - - # if mode allows it, enrich function http trigger with an ingress - enrich_function_with_ingress( - function_config, - function.spec.add_templated_ingress_host_mode - or mlconf.httpdb.nuclio.add_templated_ingress_host_mode, - function.spec.service_type or mlconf.httpdb.nuclio.default_service_type, - ) - - try: - return nuclio.deploy.deploy_config( - function_config, - dashboard_url=dashboard, - name=function_name, - project=project_name, - tag=function.metadata.tag, - verbose=function.verbose, - create_new=True, - watch=watch, - return_address_mode=nuclio.deploy.ReturnAddressModes.all, - auth_info=auth_info.to_nuclio_auth_info() if auth_info else None, - ) - except nuclio.utils.DeployError as exc: - if exc.err: - err_message = ( - f"Failed to deploy nuclio function {project_name}/{function_name}" - ) - - try: - - # the error might not be jsonable, so we'll try to parse it - # and extract the error message - json_err = exc.err.response.json() - if "error" in json_err: - err_message += f" {json_err['error']}" - if "errorStackTrace" in json_err: - logger.warning( - "Failed to deploy nuclio function", - nuclio_stacktrace=json_err["errorStackTrace"], - ) - except Exception as parse_exc: - logger.warning( - "Failed to parse nuclio deploy error", - parse_exc=err_to_str(parse_exc), - ) - - mlrun.errors.raise_for_status( - exc.err.response, - err_message, - ) - raise - - -def resolve_function_ingresses(function_spec): - 
http_trigger = resolve_function_http_trigger(function_spec) - if not http_trigger: - return [] - - ingresses = [] - for _, ingress_config in ( - http_trigger.get("attributes", {}).get("ingresses", {}).items() - ): - ingresses.append(ingress_config) - return ingresses - - -def resolve_function_http_trigger(function_spec): - for trigger_name, trigger_config in function_spec.get("triggers", {}).items(): - if trigger_config.get("kind") != "http": - continue - return trigger_config - - -def _resolve_function_image_pull_secret(function): - """ - the corresponding attribute for 'build.secret' in nuclio is imagePullSecrets, attached link for reference - https://github.com/nuclio/nuclio/blob/e4af2a000dc52ee17337e75181ecb2652b9bf4e5/pkg/processor/build/builder.go#L1073 - if only one of the secrets is set, use it. - if both are set, use the non default one and give precedence to image_pull_secret - """ - # enrich only on server side - if not is_running_as_api(): - return function.spec.image_pull_secret or function.spec.build.secret - - if function.spec.image_pull_secret is None: - function.spec.image_pull_secret = ( - mlrun.mlconf.function.spec.image_pull_secret.default - ) - elif ( - function.spec.image_pull_secret - != mlrun.mlconf.function.spec.image_pull_secret.default - ): - return function.spec.image_pull_secret - - if function.spec.build.secret is None: - function.spec.build.secret = mlrun.mlconf.httpdb.builder.docker_registry_secret - elif ( - function.spec.build.secret != mlrun.mlconf.httpdb.builder.docker_registry_secret - ): - return function.spec.build.secret - - return function.spec.image_pull_secret or function.spec.build.secret - - -def compile_function_config( - function: RemoteRuntime, - client_version: str = None, - client_python_version: str = None, - builder_env=None, - auth_info=None, -): - labels = function.metadata.labels or {} - labels.update({"mlrun/class": function.kind}) - for key, value in labels.items(): - # Adding escaping to the key to prevent 
it from being split by dots if it contains any - function.set_config(f"metadata.labels.\\{key}\\", value) - - # Add secret configurations to function's pod spec, if secret sources were added. - # Needs to be here, since it adds env params, which are handled in the next lines. - # This only needs to run if we're running within k8s context. If running in Docker, for example, skip. - if get_k8s_helper(silent=True).is_running_inside_kubernetes_cluster(): - function.add_secrets_config_to_spec() - - env_dict, external_source_env_dict = function._get_nuclio_config_spec_env() - - nuclio_runtime = ( - function.spec.nuclio_runtime - or _resolve_nuclio_runtime_python_image( - mlrun_client_version=client_version, python_version=client_python_version - ) - ) - - if is_nuclio_version_in_range("0.0.0", "1.6.0") and nuclio_runtime in [ - "python:3.7", - "python:3.8", - ]: - nuclio_runtime_set_from_spec = nuclio_runtime == function.spec.nuclio_runtime - if nuclio_runtime_set_from_spec: - raise mlrun.errors.MLRunInvalidArgumentError( - f"Nuclio version does not support the configured runtime: {nuclio_runtime}" - ) - else: - # our default is python:3.9, simply set it to python:3.6 to keep supporting envs with old Nuclio - nuclio_runtime = "python:3.6" - - # In nuclio 1.6.0<=v<1.8.0, python runtimes default behavior was to not decode event strings - # Our code is counting on the strings to be decoded, so add the needed env var for those versions - if ( - is_nuclio_version_in_range("1.6.0", "1.8.0") - and "NUCLIO_PYTHON_DECODE_EVENT_STRINGS" not in env_dict - ): - env_dict["NUCLIO_PYTHON_DECODE_EVENT_STRINGS"] = "true" - - nuclio_spec = nuclio.ConfigSpec( - env=env_dict, - external_source_env=external_source_env_dict, - config=function.spec.config, - ) - nuclio_spec.cmd = function.spec.build.commands or [] - project = function.metadata.project or "default" - tag = function.metadata.tag - handler = function.spec.function_handler - - if function.spec.build.source: - 
_compile_nuclio_archive_config( - nuclio_spec, function, builder_env, project, auth_info=auth_info - ) - - nuclio_spec.set_config("spec.runtime", nuclio_runtime) - - # In Nuclio >= 1.6.x default serviceType has changed to "ClusterIP". - nuclio_spec.set_config( - "spec.serviceType", - function.spec.service_type or mlconf.httpdb.nuclio.default_service_type, - ) - if function.spec.readiness_timeout: - nuclio_spec.set_config( - "spec.readinessTimeoutSeconds", function.spec.readiness_timeout - ) - if function.spec.resources: - nuclio_spec.set_config("spec.resources", function.spec.resources) - if function.spec.no_cache: - nuclio_spec.set_config("spec.build.noCache", True) - if function.spec.build.functionSourceCode: - nuclio_spec.set_config( - "spec.build.functionSourceCode", function.spec.build.functionSourceCode - ) - - image_pull_secret = _resolve_function_image_pull_secret(function) - if image_pull_secret: - nuclio_spec.set_config("spec.imagePullSecrets", image_pull_secret) - - if function.spec.base_image_pull: - nuclio_spec.set_config("spec.build.noBaseImagesPull", False) - # don't send node selections if nuclio is not compatible - if validate_nuclio_version_compatibility("1.5.20", "1.6.10"): - if function.spec.node_selector: - nuclio_spec.set_config("spec.nodeSelector", function.spec.node_selector) - if function.spec.node_name: - nuclio_spec.set_config("spec.nodeName", function.spec.node_name) - if function.spec.affinity: - nuclio_spec.set_config( - "spec.affinity", - mlrun.runtimes.pod.get_sanitized_attribute(function.spec, "affinity"), - ) - - # don't send tolerations if nuclio is not compatible - if validate_nuclio_version_compatibility("1.7.5"): - if function.spec.tolerations: - nuclio_spec.set_config( - "spec.tolerations", - mlrun.runtimes.pod.get_sanitized_attribute( - function.spec, "tolerations" - ), - ) - # don't send preemption_mode if nuclio is not compatible - if validate_nuclio_version_compatibility("1.8.6"): - if function.spec.preemption_mode: - 
nuclio_spec.set_config( - "spec.PreemptionMode", - function.spec.preemption_mode, - ) - - # don't send default or any priority class name if nuclio is not compatible - if ( - function.spec.priority_class_name - and validate_nuclio_version_compatibility("1.6.18") - and len(mlconf.get_valid_function_priority_class_names()) - ): - nuclio_spec.set_config( - "spec.priorityClassName", function.spec.priority_class_name - ) - - if function.spec.replicas: - - nuclio_spec.set_config( - "spec.minReplicas", as_number("spec.Replicas", function.spec.replicas) - ) - nuclio_spec.set_config( - "spec.maxReplicas", as_number("spec.Replicas", function.spec.replicas) - ) - - else: - nuclio_spec.set_config( - "spec.minReplicas", - as_number("spec.minReplicas", function.spec.min_replicas), - ) - nuclio_spec.set_config( - "spec.maxReplicas", - as_number("spec.maxReplicas", function.spec.max_replicas), - ) - - if function.spec.service_account: - nuclio_spec.set_config("spec.serviceAccount", function.spec.service_account) - - if function.spec.security_context: - nuclio_spec.set_config( - "spec.securityContext", - mlrun.runtimes.pod.get_sanitized_attribute( - function.spec, "security_context" - ), - ) - - if ( - function.spec.base_spec - or function.spec.build.functionSourceCode - or function.spec.build.source - or function.kind == mlrun.runtimes.RuntimeKinds.serving # serving can be empty - ): - config = function.spec.base_spec - if not config: - # if base_spec was not set (when not using code_to_function) and we have base64 code - # we create the base spec with essential attributes - config = nuclio.config.new_config() - update_in(config, "spec.handler", handler or "main:handler") - - config = nuclio.config.extend_config( - config, nuclio_spec, tag, function.spec.build.code_origin - ) - - update_in(config, "metadata.name", function.metadata.name) - update_in(config, "spec.volumes", function.spec.generate_nuclio_volumes()) - base_image = ( - get_in(config, "spec.build.baseImage") - or 
function.spec.image - or function.spec.build.base_image - ) - if base_image: - update_in( - config, - "spec.build.baseImage", - enrich_image_url(base_image, client_version, client_python_version), - ) - - logger.info("deploy started") - name = get_fullname(function.metadata.name, project, tag) - function.status.nuclio_name = name - update_in(config, "metadata.name", name) - - if function.kind == mlrun.runtimes.RuntimeKinds.serving and not get_in( - config, "spec.build.functionSourceCode" - ): - if not function.spec.build.source: - # set the source to the mlrun serving wrapper - body = nuclio.build.mlrun_footer.format( - mlrun.runtimes.serving.serving_subkind - ) - update_in( - config, - "spec.build.functionSourceCode", - b64encode(body.encode("utf-8")).decode("utf-8"), - ) - elif not function.spec.function_handler: - # point the nuclio function handler to mlrun serving wrapper handlers - update_in( - config, - "spec.handler", - "mlrun.serving.serving_wrapper:handler", - ) - else: - # todo: should be deprecated (only work via MLRun service) - # this may also be called in case of using single file code_to_function(embed_code=False) - # this option need to be removed or be limited to using remote files (this code runs in server) - name, config, code = nuclio.build_file( - function.spec.source, - name=function.metadata.name, - project=project, - handler=handler, - tag=tag, - spec=nuclio_spec, - kind=function.spec.function_kind, - verbose=function.verbose, - ) - - update_in(config, "spec.volumes", function.spec.generate_nuclio_volumes()) - base_image = function.spec.image or function.spec.build.base_image - if base_image: - update_in( - config, - "spec.build.baseImage", - enrich_image_url(base_image, client_version, client_python_version), - ) - - name = get_fullname(name, project, tag) - function.status.nuclio_name = name - - update_in(config, "metadata.name", name) - - return name, project, config - - -def enrich_function_with_ingress(config, mode, service_type): - # 
do not enrich with an ingress - if mode == NuclioIngressAddTemplatedIngressModes.never: - return - - ingresses = resolve_function_ingresses(config["spec"]) - - # function has ingresses already, nothing to add / enrich - if ingresses: - return - - # if exists, get the http trigger the function has - # we would enrich it with an ingress - http_trigger = resolve_function_http_trigger(config["spec"]) - if not http_trigger: - # function has an HTTP trigger without an ingress - # TODO: read from nuclio-api frontend-spec - http_trigger = { - "kind": "http", - "name": "http", - "maxWorkers": 1, - "workerAvailabilityTimeoutMilliseconds": 10000, # 10 seconds - "attributes": {}, - } - - def enrich(): - http_trigger.setdefault("attributes", {}).setdefault("ingresses", {})["0"] = { - "paths": ["/"], - # this would tell Nuclio to use its default ingress host template - # and would auto assign a host for the ingress - "hostTemplate": "@nuclio.fromDefault", - } - http_trigger["attributes"]["serviceType"] = service_type - config["spec"].setdefault("triggers", {})[http_trigger["name"]] = http_trigger - - if mode == NuclioIngressAddTemplatedIngressModes.always: - enrich() - elif mode == NuclioIngressAddTemplatedIngressModes.on_cluster_ip: - - # service type is not cluster ip, bail out - if service_type and service_type.lower() != "clusterip": - return - - enrich() - - def get_nuclio_deploy_status( name, project, @@ -1683,163 +1227,3 @@ def get_nuclio_deploy_status( else: text = "\n".join(outputs) if outputs else "" return state, address, name, last_log_timestamp, text, function_status - - -def _compile_nuclio_archive_config( - nuclio_spec, - function: RemoteRuntime, - builder_env, - project=None, - auth_info=None, -): - secrets = {} - if project and get_k8s_helper(silent=True).is_running_inside_kubernetes_cluster(): - secrets = get_k8s_helper().get_project_secret_data(project) - - def get_secret(key): - return builder_env.get(key) or secrets.get(key, "") - - source = 
function.spec.build.source - parsed_url = urlparse(source) - code_entry_type = "" - if source.startswith("s3://"): - code_entry_type = "s3" - if source.startswith("git://"): - code_entry_type = "git" - for archive_prefix in ["http://", "https://", "v3io://", "v3ios://"]: - if source.startswith(archive_prefix): - code_entry_type = "archive" - - if code_entry_type == "": - raise mlrun.errors.MLRunInvalidArgumentError( - "Couldn't resolve code entry type from source" - ) - - code_entry_attributes = {} - - # resolve work_dir and handler - work_dir, handler = _resolve_work_dir_and_handler(function.spec.function_handler) - work_dir = function.spec.workdir or work_dir - if work_dir != "": - code_entry_attributes["workDir"] = work_dir - - # archive - if code_entry_type == "archive": - v3io_access_key = builder_env.get("V3IO_ACCESS_KEY", "") - if source.startswith("v3io"): - if not parsed_url.netloc: - source = mlrun.mlconf.v3io_api + parsed_url.path - else: - source = f"http{source[len('v3io'):]}" - if auth_info and not v3io_access_key: - v3io_access_key = auth_info.data_session or auth_info.access_key - - if v3io_access_key: - code_entry_attributes["headers"] = {"X-V3io-Session-Key": v3io_access_key} - - # s3 - if code_entry_type == "s3": - bucket, item_key = parse_s3_bucket_and_key(source) - - code_entry_attributes["s3Bucket"] = bucket - code_entry_attributes["s3ItemKey"] = item_key - - code_entry_attributes["s3AccessKeyId"] = get_secret("AWS_ACCESS_KEY_ID") - code_entry_attributes["s3SecretAccessKey"] = get_secret("AWS_SECRET_ACCESS_KEY") - code_entry_attributes["s3SessionToken"] = get_secret("AWS_SESSION_TOKEN") - - # git - if code_entry_type == "git": - - # change git:// to https:// as nuclio expects it to be - if source.startswith("git://"): - source = source.replace("git://", "https://") - - source, reference, branch = mlrun.utils.resolve_git_reference_from_source( - source - ) - if not branch and not reference: - raise mlrun.errors.MLRunInvalidArgumentError( - "git 
branch or refs must be specified in the source e.g.: " - "'git:///org/repo.git#'" - ) - if reference: - code_entry_attributes["reference"] = reference - if branch: - code_entry_attributes["branch"] = branch - - password = get_secret("GIT_PASSWORD") - username = get_secret("GIT_USERNAME") - - token = get_secret("GIT_TOKEN") - if token: - username, password = mlrun.utils.get_git_username_password_from_token(token) - - code_entry_attributes["username"] = username - code_entry_attributes["password"] = password - - # populate spec with relevant fields - nuclio_spec.set_config("spec.handler", handler) - nuclio_spec.set_config("spec.build.path", source) - nuclio_spec.set_config("spec.build.codeEntryType", code_entry_type) - nuclio_spec.set_config("spec.build.codeEntryAttributes", code_entry_attributes) - - -def _resolve_work_dir_and_handler(handler): - """ - Resolves a nuclio function working dir and handler inside an archive/git repo - :param handler: a path describing working dir and handler of a nuclio function - :return: (working_dir, handler) tuple, as nuclio expects to get it - - Example: ("a/b/c#main:Handler") -> ("a/b/c", "main:Handler") - """ - - def extend_handler(base_handler): - # return default handler and module if not specified - if not base_handler: - return "main:handler" - if ":" not in base_handler: - base_handler = f"{base_handler}:handler" - return base_handler - - if not handler: - return "", "main:handler" - - split_handler = handler.split("#") - if len(split_handler) == 1: - return "", extend_handler(handler) - - return split_handler[0], extend_handler(split_handler[1]) - - -def _resolve_nuclio_runtime_python_image( - mlrun_client_version: str = None, python_version: str = None -): - # if no python version or mlrun version is passed it means we use mlrun client older than 1.3.0 therefore need - # to use the previoud default runtime which is python 3.7 - if not python_version or not mlrun_client_version: - return "python:3.7" - - # If the mlrun 
version is 0.0.0-, it is a dev version, - # so we can't check if it is higher than 1.3.0, but if the python version was passed, - # it means it is 1.3.0-rc or higher, so use the image according to the python version - if mlrun_client_version.startswith("0.0.0-") or "unstable" in mlrun_client_version: - if python_version.startswith("3.7"): - return "python:3.7" - - return mlrun.mlconf.default_nuclio_runtime - - # if mlrun version is older than 1.3.0 we need to use the previous default runtime which is python 3.7 - if semver.VersionInfo.parse(mlrun_client_version) < semver.VersionInfo.parse( - "1.3.0-X" - ): - return "python:3.7" - - # if mlrun version is 1.3.0 or newer and python version is 3.7 we need to use python 3.7 image - if semver.VersionInfo.parse(mlrun_client_version) >= semver.VersionInfo.parse( - "1.3.0-X" - ) and python_version.startswith("3.7"): - return "python:3.7" - - # if none of the above conditions are met we use the default runtime which is python 3.9 - return mlrun.mlconf.default_nuclio_runtime diff --git a/mlrun/runtimes/kubejob.py b/mlrun/runtimes/kubejob.py index 745004ab4791..532b3be9c59e 100644 --- a/mlrun/runtimes/kubejob.py +++ b/mlrun/runtimes/kubejob.py @@ -14,15 +14,15 @@ import os import time +import warnings from kubernetes import client from kubernetes.client.rest import ApiException -import mlrun.api.schemas +import mlrun.common.schemas import mlrun.errors from mlrun.runtimes.base import BaseRuntimeHandler -from ..builder import build_runtime from ..db import RunDBError from ..errors import err_to_str from ..kfpops import build_op @@ -30,7 +30,7 @@ from ..utils import get_in, logger from .base import RunError, RuntimeClassMode from .pod import KubeResource, kube_resource_spec_to_pod_spec -from .utils import AsyncLogWriter +from .utils import get_k8s class KubejobRuntime(KubeResource): @@ -44,12 +44,13 @@ def is_deployed(self): if self.spec.image: return True - if self._is_remote_api(): - db = self._get_db() - try: - 
db.get_builder_status(self, logs=False) - except Exception: - pass + db = self._get_db() + try: + # getting builder status enriches the runtime when it needs to be fetched from the API, + # otherwise it's a no-op + db.get_builder_status(self, logs=False) + except Exception: + pass if self.spec.image: return True @@ -58,27 +59,30 @@ def is_deployed(self): return False def with_source_archive( - self, source, workdir=None, handler=None, pull_at_runtime=True + self, source, workdir=None, handler=None, pull_at_runtime=True, target_dir=None ): """load the code from git/tar/zip archive at runtime or build - :param source: valid path to git, zip, or tar file, e.g. - git://github.com/mlrun/something.git - http://some/url/file.zip - :param handler: default function handler - :param workdir: working dir relative to the archive root or absolute (e.g. './subdir') + :param source: valid absolute path or URL to git, zip, or tar file, e.g. + git://github.com/mlrun/something.git + http://some/url/file.zip + note path source must exist on the image or exist locally when run is local + (it is recommended to use 'workdir' when source is a filepath instead) + :param handler: default function handler + :param workdir: working dir relative to the archive root (e.g. 
'./subdir') or absolute to the image root :param pull_at_runtime: load the archive into the container at job runtime vs on build/deploy + :param target_dir: target dir on runtime pod or repo clone / archive extraction """ - if source.endswith(".zip") and not pull_at_runtime: - logger.warn( - "zip files are not natively extracted by docker, use tar.gz for faster loading during build" - ) + mlrun.utils.helpers.validate_builder_source(source, pull_at_runtime, workdir) self.spec.build.source = source if handler: self.spec.default_handler = handler if workdir: self.spec.workdir = workdir + if target_dir: + self.spec.clone_target_dir = target_dir + self.spec.build.load_source_on_run = pull_at_runtime if ( self.spec.build.base_image @@ -86,7 +90,7 @@ def with_source_archive( and pull_at_runtime and not self.spec.image ): - # if we load source from repo and dont need a full build use the base_image as the image + # if we load source from repo and don't need a full build use the base_image as the image self.spec.image = self.spec.build.base_image elif not pull_at_runtime: # clear the image so build will not be skipped @@ -106,7 +110,9 @@ def build_config( auto_build=None, requirements=None, overwrite=False, - verify_base_image=True, + verify_base_image=False, + prepare_image_for_deploy=True, + requirements_file=None, ): """specify builder configuration for the deploy operation @@ -121,43 +127,43 @@ def build_config( :param with_mlrun: add the current mlrun package to the container build :param auto_build: when set to True and the function require build it will be built on the first function run, use only if you dont plan on changing the build config between runs - :param requirements: requirements.txt file to install or list of packages to install + :param requirements: a list of packages to install + :param requirements_file: requirements file to install :param overwrite: overwrite existing build configuration * False: the new params are merged with the existing (currently 
merge is applied to requirements and commands) * True: the existing params are replaced by the new ones - :param verify_base_image: verify the base image is set + :param verify_base_image: verify that the base image is configured + (deprecated, use prepare_image_for_deploy) + :param prepare_image_for_deploy: prepare the image/base_image spec for deployment """ - if image: - self.spec.build.image = image - if base_image: - self.spec.build.base_image = base_image - # if overwrite and requirements or commands passed, clear the existing commands - # (requirements are added to the commands parameter) - if (requirements or commands) and overwrite: - self.spec.build.commands = None - if requirements: - self.with_requirements( - requirements, overwrite=False, verify_base_image=False - ) - if commands: - self.with_commands(commands, overwrite=False, verify_base_image=False) - if extra: - self.spec.build.extra = extra - if secret is not None: - self.spec.build.secret = secret - if source: - self.spec.build.source = source - if load_source_on_run: - self.spec.build.load_source_on_run = load_source_on_run - if with_mlrun is not None: - self.spec.build.with_mlrun = with_mlrun - if auto_build: - self.spec.build.auto_build = auto_build - - if verify_base_image: - self.verify_base_image() + + image = mlrun.utils.helpers.remove_image_protocol_prefix(image) + self.spec.build.build_config( + image, + base_image, + commands, + secret, + source, + extra, + load_source_on_run, + with_mlrun, + auto_build, + requirements, + requirements_file, + overwrite, + ) + + if verify_base_image or prepare_image_for_deploy: + if verify_base_image: + # TODO: remove verify_base_image in 1.6.0 + warnings.warn( + "verify_base_image is deprecated in 1.4.0 and will be removed in 1.6.0, " + "use prepare_image_for_deploy", + category=FutureWarning, + ) + self.prepare_image_for_deploy() def deploy( self, @@ -194,7 +200,13 @@ def deploy( or "/mlrun/" in build.base_image ) - if not build.source and not 
build.commands and not build.extra and with_mlrun: + if ( + not build.source + and not build.commands + and not build.requirements + and not build.extra + and with_mlrun + ): logger.info( "running build to add mlrun package, set " "with_mlrun=False to skip if its already in the image" @@ -209,6 +221,7 @@ def deploy( if is_kfp: watch = True + ready = False if self._is_remote_api(): db = self._get_db() data = db.remote_builder( @@ -223,7 +236,8 @@ def deploy( self.spec.build.base_image = self.spec.build.base_image or get_in( data, "data.spec.build.base_image" ) - self.spec.workdir = get_in(data, "data.spec.workdir") + # get the clone target dir in case it was enriched due to loading source + self.spec.clone_target_dir = get_in(data, "data.spec.clone_target_dir") ready = data.get("ready", False) if not ready: logger.info( @@ -233,17 +247,6 @@ def deploy( state = self._build_watch(watch, show_on_failure=show_on_failure) ready = state == "ready" self.status.state = state - else: - self.save(versioned=False) - ready = build_runtime( - mlrun.api.schemas.AuthInfo(), - self, - with_mlrun, - mlrun_version_specifier, - skip_deployed, - watch, - ) - self.save(versioned=False) if watch and not ready: raise mlrun.errors.MLRunRuntimeError("Deploy failed") @@ -280,36 +283,6 @@ def print_log(text): print() return self.status.state - def builder_status(self, watch=True, logs=True): - if self._is_remote_api(): - return self._build_watch(watch, logs) - - else: - pod = self.status.build_pod - if not self.status.state == "ready" and pod: - k8s = self._get_k8s() - status = k8s.get_pod_status(pod) - if logs: - if watch: - status = k8s.watch(pod) - else: - resp = k8s.logs(pod) - if resp: - print(resp.encode()) - - if status == "succeeded": - self.status.build_pod = None - self.status.state = "ready" - logger.info("build completed successfully") - return "ready" - if status in ["failed", "error"]: - self.status.state = status - logger.error(f" build {status}, watch the build pod logs: 
{pod}") - return status - - logger.info(f"builder status is: {status}, wait for it to complete") - return None - def deploy_step( self, image=None, @@ -336,23 +309,14 @@ def deploy_step( ) def _run(self, runobj: RunObject, execution): - command, args, extra_env = self._get_cmd_args(runobj) if runobj.metadata.iteration: self.store_run(runobj) - k8s = self._get_k8s() new_meta = self._get_meta(runobj) self._add_secrets_to_spec_before_running(runobj) - workdir = self.spec.workdir - if workdir: - if self.spec.build.source and self.spec.build.load_source_on_run: - # workdir will be set AFTER the clone - workdir = None - elif not workdir.startswith("/"): - # relative path mapped to real path in the job pod - workdir = os.path.join("/mlrun", workdir) + workdir = self._resolve_workdir() pod_spec = func_to_pod( self.full_image_path( @@ -369,23 +333,41 @@ def _run(self, runobj: RunObject, execution): ) pod = client.V1Pod(metadata=new_meta, spec=pod_spec) try: - pod_name, namespace = k8s.create_pod(pod) + pod_name, namespace = get_k8s().create_pod(pod) except ApiException as exc: raise RunError(err_to_str(exc)) - if pod_name and self.kfp: - writer = AsyncLogWriter(self._db_conn, runobj) - status = k8s.watch(pod_name, namespace, writer=writer) - - if status in ["failed", "error"]: - raise RunError(f"pod exited with {status}, check logs") - else: - txt = f"Job is running in the background, pod: {pod_name}" - logger.info(txt) - runobj.status.status_text = txt + txt = f"Job is running in the background, pod: {pod_name}" + logger.info(txt) + runobj.status.status_text = txt return None + def _resolve_workdir(self): + """ + The workdir is relative to the source root, if the source is not loaded on run then the workdir + is relative to the clone target dir (where the source was copied to). + Otherwise, if the source is loaded on run, the workdir is resolved on the run as well. + If the workdir is absolute, keep it as is. 
+ """ + workdir = self.spec.workdir + if self.spec.build.source and self.spec.build.load_source_on_run: + # workdir will be set AFTER the clone which is done in the pre-run of local runtime + return None + + if workdir and os.path.isabs(workdir): + return workdir + + if self.spec.clone_target_dir: + workdir = workdir or "" + if workdir.startswith("./"): + # TODO: use 'removeprefix' when we drop python 3.7 support + # workdir.removeprefix("./") + workdir = workdir[2:] + return os.path.join(self.spec.clone_target_dir, workdir) + + return workdir + def func_to_pod(image, runtime, extra_env, command, args, workdir): container = client.V1Container( diff --git a/mlrun/runtimes/local.py b/mlrun/runtimes/local.py index 2ea733f039d1..2caec47ad1f7 100644 --- a/mlrun/runtimes/local.py +++ b/mlrun/runtimes/local.py @@ -14,11 +14,13 @@ import importlib.util as imputil import inspect +import io import json import os import socket import sys import tempfile +import threading import traceback from contextlib import redirect_stdout from copy import copy @@ -39,7 +41,7 @@ from ..model import RunObject from ..utils import get_handler_extended, get_in, logger, set_paths from ..utils.clones import extract_source -from .base import BaseRuntime, FunctionSpec, spec_fields +from .base import BaseRuntime from .kubejob import KubejobRuntime from .remotesparkjob import RemoteSparkRuntime from .utils import RunError, global_context, log_std @@ -170,48 +172,10 @@ def _run(self, runobj: RunObject, execution: MLClientCtx): return context.to_dict() -class LocalFunctionSpec(FunctionSpec): - _dict_fields = spec_fields + ["clone_target_dir"] - - def __init__( - self, - command=None, - args=None, - mode=None, - default_handler=None, - pythonpath=None, - entry_points=None, - description=None, - workdir=None, - build=None, - clone_target_dir=None, - ): - super().__init__( - command=command, - args=args, - mode=mode, - build=build, - entry_points=entry_points, - description=description, - 
workdir=workdir, - default_handler=default_handler, - pythonpath=pythonpath, - ) - self.clone_target_dir = clone_target_dir - - class LocalRuntime(BaseRuntime, ParallelRunner): kind = "local" _is_remote = False - @property - def spec(self) -> LocalFunctionSpec: - return self._spec - - @spec.setter - def spec(self, spec): - self._spec = self._verify_dict(spec, "spec", LocalFunctionSpec) - def to_job(self, image=""): struct = self.to_dict() obj = KubejobRuntime.from_dict(struct) @@ -222,12 +186,12 @@ def to_job(self, image=""): def with_source_archive(self, source, workdir=None, handler=None, target_dir=None): """load the code from git/tar/zip archive at runtime or build - :param source: valid path to git, zip, or tar file, e.g. - git://github.com/mlrun/something.git - http://some/url/file.zip - :param handler: default function handler - :param workdir: working dir relative to the archive root or absolute (e.g. './subdir') - :param target_dir: local target dir for repo clone (by default its /code) + :param source: valid path to git, zip, or tar file, e.g. + git://github.com/mlrun/something.git + http://some/url/file.zip + :param handler: default function handler + :param workdir: working dir relative to the archive root (e.g. 
'./subdir') or absolute + :param target_dir: local target dir for repo clone (by default its /code) """ self.spec.build.source = source self.spec.build.load_source_on_run = True @@ -253,6 +217,8 @@ def _pre_run(self, runobj: RunObject, execution: MLClientCtx): execution._current_workdir = workdir execution._old_workdir = None + # _is_run_local is set when the user specifies local=True in run() + # in this case we don't want to extract the source code and contaminate the user's local dir if self.spec.build.source and not hasattr(self, "_is_run_local"): target_dir = extract_source( self.spec.build.source, @@ -396,21 +362,43 @@ def load_module(file_name, handler, context): def run_exec(cmd, args, env=None, cwd=None): if args: cmd += args - out = "" if env and "SYSTEMROOT" in os.environ: env["SYSTEMROOT"] = os.environ["SYSTEMROOT"] print("running:", cmd) - process = Popen(cmd, stdout=PIPE, stderr=PIPE, env=os.environ, cwd=cwd) - while True: - nextline = process.stdout.readline() - if not nextline and process.poll() is not None: - break - print(nextline.decode("utf-8"), end="") - sys.stdout.flush() - out += nextline.decode("utf-8") + process = Popen( + cmd, stdout=PIPE, stderr=PIPE, env=os.environ, cwd=cwd, universal_newlines=True + ) + + def read_stderr(stderr): + while True: + nextline = process.stderr.readline() + if not nextline: + break + stderr.write(nextline) + + # ML-3710. We must read stderr in a separate thread to drain the stderr pipe so that the spawned process won't + # hang if it tries to write more to stderr than the buffer size (default of approx 8kb). 
+ with io.StringIO() as stderr: + stderr_consumer_thread = threading.Thread(target=read_stderr, args=[stderr]) + stderr_consumer_thread.start() + + with io.StringIO() as stdout: + while True: + nextline = process.stdout.readline() + if not nextline: + break + print(nextline, end="") + sys.stdout.flush() + stdout.write(nextline) + out = stdout.getvalue() + + stderr_consumer_thread.join() + err = stderr.getvalue() + + # if we return anything for err, the caller will assume that the process failed code = process.poll() + err = "" if code == 0 else err - err = process.stderr.read().decode("utf-8") if code != 0 else "" return out, err @@ -447,20 +435,23 @@ def exec_from_params(handler, runobj: RunObject, context: MLClientCtx, cwd=None) if cwd: os.chdir(cwd) # Apply the MLRun handler decorator for parsing inputs using type hints and logging outputs using log hints - # (Expected behavior: inputs are being parsed when they have type hints in code or given by user. - # outputs are logged only if log hints are provided by the user): - val = mlrun.handler( - inputs=( - runobj.spec.inputs_type_hints - if runobj.spec.inputs_type_hints - else True # True will use type hints if provided in user's code. - ), - outputs=( - runobj.spec.returns - if runobj.spec.returns - else None # None will turn off outputs logging. - ), - )(handler)(**kwargs) + # (Expected behavior: inputs are being parsed when they have type hints in code or given by user. Outputs + # are logged only if log hints are provided by the user): + if mlrun.mlconf.packagers.enabled: + val = mlrun.handler( + inputs=( + runobj.spec.inputs_type_hints + if runobj.spec.inputs_type_hints + else True # True will use type hints if provided in user's code. + ), + outputs=( + runobj.spec.returns + if runobj.spec.returns + else None # None will turn off outputs logging. 
+ ), + )(handler)(**kwargs) + else: + val = handler(**kwargs) context.set_state("completed", commit=False) except Exception as exc: err = err_to_str(exc) @@ -487,6 +478,18 @@ def get_func_arg(handler, runobj: RunObject, context: MLClientCtx, is_nuclio=Fal kwargs = {} args = inspect.signature(handler).parameters + def _get_input_value(input_key: str): + input_obj = context.get_input(input_key, inputs[input_key]) + # If there is no type hint annotation but there is a default value and its type is string, point the data + # item to local downloaded file path (`local()` returns the downloaded temp path string): + if ( + args[input_key].annotation is inspect.Parameter.empty + and type(args[input_key].default) is str + ): + return input_obj.local() + else: + return input_obj + for key in args.keys(): if key == "context": kwargs[key] = context @@ -495,9 +498,23 @@ def get_func_arg(handler, runobj: RunObject, context: MLClientCtx, is_nuclio=Fal elif key in params: kwargs[key] = copy(params[key]) elif key in inputs: - obj = context.get_input(key, inputs[key]) - if type(args[key].default) is str or args[key].annotation == str: - kwargs[key] = obj.local() - else: - kwargs[key] = context.get_input(key, inputs[key]) + kwargs[key] = _get_input_value(key) + + list_of_params = list(args.values()) + if len(list_of_params) == 0: + return kwargs + + # get the last parameter, as **kwargs can only be last in the function's parameters list + last_param = list_of_params[-1] + # VAR_KEYWORD meaning : A dict of keyword arguments that aren’t bound to any other parameter. + # This corresponds to a **kwargs parameter in a Python function definition. 
+ if last_param.kind == last_param.VAR_KEYWORD: + # if handler has **kwargs, pass all parameters provided by the user to the handler which were not already set + # as part of the previous loop which handled all parameters which were explicitly defined in the handler + for key in params: + if key not in kwargs: + kwargs[key] = copy(params[key]) + for key in inputs: + if key not in kwargs: + kwargs[key] = _get_input_value(key) return kwargs diff --git a/mlrun/runtimes/mpijob/abstract.py b/mlrun/runtimes/mpijob/abstract.py index ec54df540e41..cde25bb948fd 100644 --- a/mlrun/runtimes/mpijob/abstract.py +++ b/mlrun/runtimes/mpijob/abstract.py @@ -24,7 +24,7 @@ from mlrun.model import RunObject from mlrun.runtimes.kubejob import KubejobRuntime from mlrun.runtimes.pod import KubeResourceSpec -from mlrun.runtimes.utils import RunError +from mlrun.runtimes.utils import RunError, get_k8s from mlrun.utils import get_in, logger @@ -60,6 +60,7 @@ def __init__( tolerations=None, preemption_mode=None, security_context=None, + clone_target_dir=None, ): super().__init__( command=command, @@ -88,6 +89,7 @@ def __init__( tolerations=tolerations, preemption_mode=preemption_mode, security_context=security_context, + clone_target_dir=clone_target_dir, ) self.mpi_args = mpi_args or [ "-x", @@ -189,10 +191,9 @@ def _run(self, runobj: RunObject, execution: MLClientCtx): def _submit_mpijob(self, job, namespace=None): mpi_group, mpi_version, mpi_plural = self._get_crd_info() - k8s = self._get_k8s() - namespace = k8s.resolve_namespace(namespace) + namespace = get_k8s().resolve_namespace(namespace) try: - resp = k8s.crdapi.create_namespaced_custom_object( + resp = get_k8s().crdapi.create_namespaced_custom_object( mpi_group, mpi_version, namespace=namespace, @@ -208,7 +209,7 @@ def _submit_mpijob(self, job, namespace=None): def delete_job(self, name, namespace=None): mpi_group, mpi_version, mpi_plural = self._get_crd_info() - k8s = self._get_k8s() + k8s = get_k8s() namespace = 
k8s.resolve_namespace(namespace) try: # delete the mpi job @@ -223,11 +224,10 @@ def delete_job(self, name, namespace=None): def list_jobs(self, namespace=None, selector="", show=True): mpi_group, mpi_version, mpi_plural = self._get_crd_info() - k8s = self._get_k8s() - namespace = k8s.resolve_namespace(namespace) + namespace = get_k8s().resolve_namespace(namespace) items = [] try: - resp = k8s.crdapi.list_namespaced_custom_object( + resp = get_k8s().crdapi.list_namespaced_custom_object( mpi_group, mpi_version, namespace, @@ -247,10 +247,9 @@ def list_jobs(self, namespace=None, selector="", show=True): def get_job(self, name, namespace=None): mpi_group, mpi_version, mpi_plural = self._get_crd_info() - k8s = self._get_k8s() - namespace = k8s.resolve_namespace(namespace) + namespace = get_k8s().resolve_namespace(namespace) try: - resp = k8s.crdapi.get_namespaced_custom_object( + resp = get_k8s().crdapi.get_namespaced_custom_object( mpi_group, mpi_version, namespace, mpi_plural, name ) except client.exceptions.ApiException as exc: @@ -259,12 +258,11 @@ def get_job(self, name, namespace=None): return resp def get_pods(self, name=None, namespace=None, launcher=False): - k8s = self._get_k8s() - namespace = k8s.resolve_namespace(namespace) + namespace = get_k8s().resolve_namespace(namespace) selector = self._generate_pods_selector(name, launcher) - pods = k8s.list_pods(selector=selector, namespace=namespace) + pods = get_k8s().list_pods(selector=selector, namespace=namespace) if pods: return {p.metadata.name: p.status.phase for p in pods} diff --git a/mlrun/runtimes/mpijob/v1.py b/mlrun/runtimes/mpijob/v1.py index 52115e05c378..e867b4a8859c 100644 --- a/mlrun/runtimes/mpijob/v1.py +++ b/mlrun/runtimes/mpijob/v1.py @@ -62,6 +62,7 @@ def __init__( tolerations=None, preemption_mode=None, security_context=None, + clone_target_dir=None, ): super().__init__( command=command, @@ -91,6 +92,7 @@ def __init__( tolerations=tolerations, preemption_mode=preemption_mode, 
security_context=security_context, + clone_target_dir=clone_target_dir, ) self.clean_pod_policy = clean_pod_policy or MPIJobV1CleanPodPolicies.default() diff --git a/mlrun/runtimes/package/context_handler.py b/mlrun/runtimes/package/context_handler.py deleted file mode 100644 index fe6a60277005..000000000000 --- a/mlrun/runtimes/package/context_handler.py +++ /dev/null @@ -1,711 +0,0 @@ -# Copyright 2018 Iguazio -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import inspect -import os -import shutil -from collections import OrderedDict -from enum import Enum -from pathlib import Path -from typing import Any, Callable, Dict, List, Type, Union - -import cloudpickle -import numpy as np -import pandas as pd - -from mlrun.datastore import DataItem -from mlrun.errors import MLRunInvalidArgumentError, MLRunRuntimeError -from mlrun.execution import MLClientCtx -from mlrun.utils import logger - - -# TODO: Move the `ArtifactType` to constants.py -class ArtifactType(Enum): - """ - Possible artifact types to log using the MLRun `context` decorator. - """ - - # Types: - DATASET = "dataset" - DIRECTORY = "directory" - FILE = "file" - OBJECT = "object" - PLOT = "plot" - RESULT = "result" - - # Constants: - DEFAULT = RESULT - - -class InputsParser: - """ - A static class to hold all the common parsing functions - functions for parsing MLRun DataItem to the user desired - type. 
- """ - - @staticmethod - def parse_pandas_dataframe(data_item: DataItem) -> pd.DataFrame: - """ - Parse an MLRun `DataItem` to a `pandas.DataFrame`. - - :param data_item: The `DataItem` to parse. - - :returns: The `DataItem` as a `pandas.DataFrame`. - """ - return data_item.as_df() - - @staticmethod - def parse_numpy_array(data_item: DataItem) -> np.ndarray: - """ - Parse an MLRun `DataItem` to a `numpy.ndarray`. - - :param data_item: The `DataItem` to parse. - - :returns: The `DataItem` as a `numpy.ndarray`. - """ - return data_item.as_df().to_numpy() - - @staticmethod - def parse_dict(data_item: DataItem) -> dict: - """ - Parse an MLRun `DataItem` to a `dict`. - - :param data_item: The `DataItem` to parse. - - :returns: The `DataItem` as a `dict`. - """ - return data_item.as_df().to_dict() - - @staticmethod - def parse_list(data_item: DataItem) -> list: - """ - Parse an MLRun `DataItem` to a `list`. - - :param data_item: The `DataItem` to parse. - - :returns: The `DataItem` as a `list`. - """ - return data_item.as_df().to_numpy().tolist() - - @staticmethod - def parse_object(data_item: DataItem) -> object: - """ - Parse an MLRun `DataItem` to its unpickled object. The pickle file will be downloaded to a local temp - directory and then loaded. - - :param data_item: The `DataItem` to parse. - - :returns: The `DataItem` as the original object that was pickled once it was logged. - """ - object_file = data_item.local() - with open(object_file, "rb") as pickle_file: - obj = cloudpickle.load(pickle_file) - return obj - - -class OutputsLogger: - """ - A static class to hold all the common logging functions - functions for logging different objects by artifact type - to MLRun. - """ - - @staticmethod - def log_dataset( - ctx: MLClientCtx, - obj: Union[pd.DataFrame, np.ndarray, pd.Series, dict, list], - key: str, - logging_kwargs: dict, - ): - """ - Log an object as a dataset. The dataset wil lbe cast to a `pandas.DataFrame`. 
Supporting casting from - `pandas.Series`, `numpy.ndarray`, `dict` and `list`. - - :param ctx: The MLRun context to log with. - :param obj: The data to log. - :param key: The key of the artifact. - :param logging_kwargs: Additional keyword arguments to pass to the `context.log_dataset` - - :raise MLRunInvalidArgumentError: If the type is not supported for being cast to `pandas.DataFrame`. - """ - # Check for the object type: - if not isinstance(obj, pd.DataFrame): - if isinstance(obj, (np.ndarray, pd.Series, dict, list)): - obj = pd.DataFrame(obj) - else: - raise MLRunInvalidArgumentError( - f"The value requested to be logged as a dataset artifact is of type '{type(obj)}' and it " - f"cannot be logged as a dataset. Please parse it in your code into one `numpy.ndarray`, " - f"`pandas.DataFrame`, `pandas.Series`, `dict`, `list` before returning it so we can log it." - ) - - # Log the DataFrame object as a dataset: - ctx.log_dataset(**logging_kwargs, key=key, df=obj) - - @staticmethod - def log_directory( - ctx: MLClientCtx, - obj: Union[str, Path], - key: str, - logging_kwargs: dict, - ): - """ - Log a directory as a zip file. The zip file will be created at the current working directory. Once logged, - it will be deleted. - - :param ctx: The MLRun context to log with. - :param obj: The directory to zip path. - :param key: The key of the artifact. - :param logging_kwargs: Additional keyword arguments to pass to the `context.log_artifact` method. - - :raises MLRunInvalidArgumentError: In case the given path is not of a directory or do not exist. 
- """ - # In case it is a `pathlib` path, parse to str: - obj = str(obj) - - # Verify the path is of an existing directory: - if not os.path.isdir(obj): - raise MLRunInvalidArgumentError( - f"The given path is not a directory: '{obj}'" - ) - if not os.path.exists(obj): - raise MLRunInvalidArgumentError( - f"The given directory path do not exist: '{obj}'" - ) - - # Zip the directory: - directory_zip_path = shutil.make_archive( - base_name=key, - format="zip", - root_dir=os.path.abspath(obj), - ) - - # Log the zip file: - ctx.log_artifact(**logging_kwargs, item=key, local_path=directory_zip_path) - - # Delete the zip file: - os.remove(directory_zip_path) - - @staticmethod - def log_file( - ctx: MLClientCtx, - obj: Union[str, Path], - key: str, - logging_kwargs: dict, - ): - """ - Log a file to MLRun. - - :param ctx: The MLRun context to log with. - :param obj: The path of the file to log. - :param key: The key of the artifact. - :param logging_kwargs: Additional keyword arguments to pass to the `context.log_artifact` method. - - :raises MLRunInvalidArgumentError: In case the given path is not of a file or do not exist. - """ - # In case it is a `pathlib` path, parse to str: - obj = str(obj) - - # Verify the path is of an existing directory: - if not os.path.isfile(obj): - raise MLRunInvalidArgumentError(f"The given path is not a file: '{obj}'") - if not os.path.exists(obj): - raise MLRunInvalidArgumentError( - f"The given directory path do not exist: '{obj}'" - ) - - # Log the zip file: - ctx.log_artifact(**logging_kwargs, item=key, local_path=os.path.abspath(obj)) - - @staticmethod - def log_object(ctx: MLClientCtx, obj, key: str, logging_kwargs: dict): - """ - Log an object as a pickle. - - :param ctx: The MLRun context to log with. - :param obj: The object to log. - :param key: The key of the artifact. - :param logging_kwargs: Additional keyword arguments to pass to the `context.log_artifact` method. 
- """ - ctx.log_artifact( - **logging_kwargs, - item=key, - body=obj if isinstance(obj, (bytes, bytearray)) else cloudpickle.dumps(obj), - format="pkl", - ) - - @staticmethod - def log_plot(ctx: MLClientCtx, obj, key: str, logging_kwargs: dict): - """ - Log an object as a plot. Currently, supporting plots produced by one the following modules: `matplotlib`, - `seaborn`, `plotly` and `bokeh`. - - :param ctx: The MLRun context to log with. - :param obj: The plot to log. - :param key: The key of the artifact. - :param logging_kwargs: Additional keyword arguments to pass to the `context.log_artifact`. - - :raise MLRunInvalidArgumentError: If the object type is not supported (meaning the plot was not produced by - one of the supported modules). - """ - # Create the plot artifact according to the module produced the object: - artifact = None - - # `matplotlib` and `seaborn`: - try: - import matplotlib.pyplot as plt - - from mlrun.artifacts import PlotArtifact - - # Get the figure: - figure = None - if isinstance(obj, plt.Figure): - figure = obj - elif isinstance(obj, plt.Axes): - if hasattr(obj, "get_figure"): - figure = obj.get_figure() - elif hasattr(obj, "figure"): - figure = obj.figure - elif hasattr(obj, "fig"): - figure = obj.fig - - # Create the artifact: - if figure is not None: - artifact = PlotArtifact(key=key, body=figure) - except ModuleNotFoundError: - pass - - # `plotly`: - if artifact is None: - try: - import plotly - - from mlrun.artifacts import PlotlyArtifact - - if isinstance(obj, plotly.graph_objs.Figure): - artifact = PlotlyArtifact(key=key, figure=obj) - except ModuleNotFoundError: - pass - - # `bokeh`: - if artifact is None: - try: - import bokeh.plotting as bokeh_plt - - from mlrun.artifacts import BokehArtifact - - if isinstance(obj, bokeh_plt.Figure): - artifact = BokehArtifact(key=key, figure=obj) - except ModuleNotFoundError: - pass - except ImportError: - logger.warn( - "Bokeh installation is ignored. 
If needed, " - "make sure you have the required version with `pip install mlrun[bokeh]`" - ) - - # Log the artifact: - if artifact is None: - raise MLRunInvalidArgumentError( - f"The given plot is of type `{type(obj)}`. We currently support logging plots produced by one of " - f"the following modules: `matplotlib`, `seaborn`, `plotly` and `bokeh`. You may try to save the " - f"plot to file and log it as a file instead." - ) - ctx.log_artifact(**logging_kwargs, item=artifact) - - @staticmethod - def log_result( - ctx: MLClientCtx, - obj: Union[int, float, str, list, tuple, dict, np.ndarray], - key: str, - logging_kwargs: dict, - ): - """ - Log an object as a result. The objects value will be cast to a serializable version of itself. Supporting: - int, float, str, list, tuple, dict, numpy.ndarray - - :param ctx: The MLRun context to log with. - :param obj: The value to log. - :param key: The key of the artifact. - :param logging_kwargs: Additional keyword arguments to pass to the `context.log_result` method. - """ - ctx.log_result(**logging_kwargs, key=key, value=obj) - - -class ContextHandler: - """ - Private class for handling an MLRun context of a function that is wrapped in MLRun's `handler` decorator. - - The context handler have 3 duties: - 1. Check if the user used MLRun to run the wrapped function and if so, get the MLRun context. - 2. Parse the user's inputs (MLRun `DataItem`) to the function. - 3. Log the function's outputs to MLRun. - - The context handler use dictionaries to map objects to their logging / parsing function. The maps can be edited - using the relevant `update_X` class method. If needed to add additional artifacts types, the `ArtifactType` class - can be inherited and replaced as well using the `update_artifact_type_class` class method. 
- """ - - # The artifact type enum class to use: - _ARTIFACT_TYPE_CLASS = ArtifactType - # The map to use to get default artifact types of objects: - _DEFAULT_OBJECTS_ARTIFACT_TYPES_MAP = None - # The map to use for logging an object by its type: - _OUTPUTS_LOGGING_MAP = None - # The map to use for parsing an object by its type: - _INPUTS_PARSING_MAP = None - - @classmethod - def update_artifact_type_class(cls, artifact_type_class: Type[ArtifactType]): - """ - Update the artifact type enum class that the handler will use to specify new artifact types to log and parse. - - :param artifact_type_class: An enum inheriting from the `ArtifactType` enum. - """ - cls._ARTIFACT_TYPE_CLASS = artifact_type_class - - @classmethod - def update_default_objects_artifact_types_map( - cls, updates: Dict[type, ArtifactType] - ): - """ - Enrich the default objects artifact types map with new objects types to support. - - :param updates: New objects types to artifact types to support. - """ - if cls._DEFAULT_OBJECTS_ARTIFACT_TYPES_MAP is None: - cls._init_default_objects_artifact_types_map() - cls._DEFAULT_OBJECTS_ARTIFACT_TYPES_MAP.update(updates) - - @classmethod - def update_outputs_logging_map( - cls, - updates: Dict[ArtifactType, Callable[[MLClientCtx, Any, str, dict], None]], - ): - """ - Enrich the outputs logging map with new artifact types to support. The outputs logging map is a dictionary of - artifact type enum as key, and a function that will handle the given output. The function must accept 4 keyword - arguments - - * ctx: `mlrun.MLClientCtx` - The MLRun context to log with. - * obj: `Any` - The value / object to log. - * key: `str` - The key of the artifact. - * logging_kwargs: `dict` - Keyword arguments the user can pass in the instructions tuple. - - :param updates: New artifact types to support - a dictionary of artifact type enum as key, and a function that - will handle the given output to update the current map. 
- """ - if cls._OUTPUTS_LOGGING_MAP is None: - cls._init_outputs_logging_map() - cls._OUTPUTS_LOGGING_MAP.update(updates) - - @classmethod - def update_inputs_parsing_map(cls, updates: Dict[type, Callable[[DataItem], Any]]): - """ - Enrich the inputs parsing map with new objects to support. The inputs parsing map is a dictionary of object - types as key, and a function that will handle the given input. The function must accept 1 keyword argument - (data_item: `mlrun.DataItem`) and return the relevant parsed object. - - :param updates: New object types to support - a dictionary of artifact type enum as key, and a function that - will handle the given input to update the current map. - """ - if cls._INPUTS_PARSING_MAP is None: - cls._init_inputs_parsing_map() - cls._INPUTS_PARSING_MAP.update(updates) - - def __init__(self): - """ - Initialize a context handler. - """ - # Initialize the maps: - if self._DEFAULT_OBJECTS_ARTIFACT_TYPES_MAP is None: - self._init_default_objects_artifact_types_map() - if self._OUTPUTS_LOGGING_MAP is None: - self._init_outputs_logging_map() - if self._INPUTS_PARSING_MAP is None: - self._init_inputs_parsing_map() - - # Set up a variable to hold the context: - self._context: MLClientCtx = None - - def look_for_context(self, args: tuple, kwargs: dict): - """ - Look for an MLRun context (`mlrun.MLClientCtx`). The handler will look for a context in the given order: - 1. The given arguments. - 2. The given keyword arguments. - 3. If an MLRun RunTime was used the context will be located via the `mlrun.get_or_create_ctx` method. - - :param args: The arguments tuple passed to the function. - :param kwargs: The keyword arguments dictionary passed to the function. 
- """ - # Search in the given arguments: - for argument in args: - if isinstance(argument, MLClientCtx): - self._context = argument - return - - # Search in the given keyword arguments: - for argument_name, argument_value in kwargs.items(): - if isinstance(argument_value, MLClientCtx): - self._context = argument_value - return - - # Search if the function was triggered from an MLRun RunTime object by looking at the call stack: - # Index 0: the current frame. - # Index 1: the decorator's frame. - # Index 2-...: If it is from mlrun.runtimes we can be sure it ran via MLRun, otherwise not. - for callstack_frame in inspect.getouterframes(inspect.currentframe()): - if os.path.join("mlrun", "runtimes", "local") in callstack_frame.filename: - import mlrun - - self._context = mlrun.get_or_create_ctx("context") - break - - def is_context_available(self) -> bool: - """ - Check if a context was found by the method `look_for_context`. - - :returns: True if a context was found and False otherwise. - """ - return self._context is not None - - def parse_inputs( - self, - args: tuple, - kwargs: dict, - type_hints: OrderedDict, - ) -> tuple: - """ - Parse the given arguments and keyword arguments data items to the expected types. - - :param args: The arguments tuple passed to the function. - :param kwargs: The keyword arguments dictionary passed to the function. - :param type_hints: An ordered dictionary of the expected types of arguments. - - :returns: The parsed args (kwargs are parsed inplace). - """ - # Parse the arguments: - parsed_args = [] - type_hints_keys = list(type_hints.keys()) - for i, argument in enumerate(args): - if ( - isinstance(argument, DataItem) - and type_hints[type_hints_keys[i]] != inspect._empty - ): - parsed_args.append( - self._parse_input( - data_item=argument, - type_hint=type_hints[type_hints_keys[i]], - ) - ) - continue - parsed_args.append(argument) - parsed_args = tuple(parsed_args) # `args` is expected to be a tuple. 
- - # Parse the keyword arguments: - for key in kwargs.keys(): - if isinstance(kwargs[key], DataItem) and type_hints[key] not in [ - inspect._empty, - DataItem, - ]: - kwargs[key] = self._parse_input( - data_item=kwargs[key], type_hint=type_hints[key] - ) - - return parsed_args - - def log_outputs( - self, - outputs: list, - log_hints: List[Union[Dict[str, str], None]], - ): - """ - Log the given outputs as artifacts with the stored context. - - :param outputs: List of outputs to log. - :param log_hints: List of logging configurations to use. - """ - for obj, log_hint in zip(outputs, log_hints): - # Check if needed to log (not None): - if log_hint is None: - continue - # Parse the instructions: - artifact_type = self._DEFAULT_OBJECTS_ARTIFACT_TYPES_MAP.get( - type(obj), self._ARTIFACT_TYPE_CLASS.DEFAULT - ).value - key = log_hint.pop("key") - artifact_type = log_hint.pop("artifact_type", artifact_type) - # Check if the object to log is None (None values are only logged if the artifact type is Result): - if obj is None and artifact_type != ArtifactType.RESULT.value: - continue - # Log: - self._log_output( - obj=obj, - artifact_type=artifact_type, - key=key, - logging_kwargs=log_hint, - ) - - def set_labels(self, labels: Dict[str, str]): - """ - Set the given labels with the stored context. - - :param labels: The labels to set. - """ - for key, value in labels.items(): - self._context.set_label(key=key, value=value) - - @classmethod - def _init_default_objects_artifact_types_map(cls): - """ - Initialize the default objects artifact types map with the basic classes supported by MLRun. In addition, it - will try to support further common packages that are not required in MLRun. 
- """ - # Initialize the map with the default classes: - cls._DEFAULT_OBJECTS_ARTIFACT_TYPES_MAP = { - pd.DataFrame: ArtifactType.DATASET, - pd.Series: ArtifactType.DATASET, - np.ndarray: ArtifactType.DATASET, - dict: ArtifactType.RESULT, - list: ArtifactType.RESULT, - tuple: ArtifactType.RESULT, - str: ArtifactType.RESULT, - int: ArtifactType.RESULT, - float: ArtifactType.RESULT, - bytes: ArtifactType.OBJECT, - bytearray: ArtifactType.OBJECT, - } - - # Try to enrich it with further classes according ot the user's environment: - try: - import matplotlib.pyplot as plt - - cls._DEFAULT_OBJECTS_ARTIFACT_TYPES_MAP[plt.Figure] = ArtifactType.PLOT - cls._DEFAULT_OBJECTS_ARTIFACT_TYPES_MAP[plt.Axes] = ArtifactType.PLOT - except ModuleNotFoundError: - pass - try: - import plotly - - cls._DEFAULT_OBJECTS_ARTIFACT_TYPES_MAP[ - plotly.graph_objs.Figure - ] = ArtifactType.PLOT - except ModuleNotFoundError: - pass - try: - import bokeh.plotting as bokeh_plt - - cls._DEFAULT_OBJECTS_ARTIFACT_TYPES_MAP[ - bokeh_plt.Figure - ] = ArtifactType.PLOT - except ModuleNotFoundError: - pass - except ImportError: - logger.warn( - "Bokeh installation is ignored. If needed, " - "make sure you have the required version with `pip install mlrun[bokeh]`" - ) - - @classmethod - def _init_outputs_logging_map(cls): - """ - Initialize the outputs logging map for the basic artifact types supported by MLRun. - """ - cls._OUTPUTS_LOGGING_MAP = { - ArtifactType.DATASET: OutputsLogger.log_dataset, - ArtifactType.DIRECTORY: OutputsLogger.log_directory, - ArtifactType.FILE: OutputsLogger.log_file, - ArtifactType.OBJECT: OutputsLogger.log_object, - ArtifactType.PLOT: OutputsLogger.log_plot, - ArtifactType.RESULT: OutputsLogger.log_result, - } - - @classmethod - def _init_inputs_parsing_map(cls): - """ - Initialize the inputs parsing map with the basic classes supported by MLRun. 
- """ - cls._INPUTS_PARSING_MAP = { - pd.DataFrame: InputsParser.parse_pandas_dataframe, - np.ndarray: InputsParser.parse_numpy_array, - dict: InputsParser.parse_dict, - list: InputsParser.parse_list, - object: InputsParser.parse_object, - } - - def _parse_input(self, data_item: DataItem, type_hint: type) -> Any: - """ - Parse the given data frame to the expected type. By default, it will be parsed to an object (will be treated as - a pickle). - - :param data_item: The data item to parse. - :param type_hint: The expected type to parse to. - - :returns: The parsed data item. - - :raises MLRunRuntimeError: If an error was raised during the parsing function. - """ - if str(type_hint).startswith("typing."): - return data_item - try: - return self._INPUTS_PARSING_MAP.get( - type_hint, self._INPUTS_PARSING_MAP[object] - )(data_item=data_item) - except Exception as exception: - raise MLRunRuntimeError( - f"MLRun tried to parse a `DataItem` of type '{type_hint}' but failed. Be sure the item was " - f"logged correctly - as the type you are trying to parse it back to. In general, python objects should " - f"be logged as pickles." - ) from exception - - def _log_output( - self, - obj, - artifact_type: Union[ArtifactType, str], - key: str, - logging_kwargs: Dict[str, Any], - ): - """ - Log the given object to MLRun as the given artifact type with the provided key. The key can be part of a - logging keyword arguments to pass to the relevant context logging function. - - :param obj: The object to log. - :param artifact_type: The artifact type to log the object as. - :param key: The key (name) of the artifact or a logging kwargs to use when logging the artifact. - - :raises MLRunInvalidArgumentError: If a key was provided in the logging kwargs. - :raises MLRunRuntimeError: If an error was raised during the logging function. 
- """ - # Get the artifact type (will also verify the artifact type is valid): - artifact_type = self._ARTIFACT_TYPE_CLASS(artifact_type) - - # Check if 'key' or 'item' were given the logging kwargs: - if "key" in logging_kwargs or "item" in logging_kwargs: - raise MLRunInvalidArgumentError( - "When passing logging keyword arguments, both 'key' and 'item' (according to the context method) " - "cannot be added to the dictionary as the key is given on its own." - ) - - # Use the logging map to log the object: - try: - self._OUTPUTS_LOGGING_MAP[artifact_type]( - ctx=self._context, - obj=obj, - key=key, - logging_kwargs=logging_kwargs, - ) - except Exception as exception: - raise MLRunRuntimeError( - f"MLRun tried to log '{key}' as '{artifact_type.value}' but failed. If you didn't provide the artifact " - f"type and the default one does not fit, try to select the correct type from the enum `ArtifactType`." - ) from exception diff --git a/mlrun/runtimes/pod.py b/mlrun/runtimes/pod.py index 97cde8af0f10..271d130298d1 100644 --- a/mlrun/runtimes/pod.py +++ b/mlrun/runtimes/pod.py @@ -24,12 +24,12 @@ import mlrun.errors import mlrun.utils.regex - -from ..api.schemas import ( +from mlrun.common.schemas import ( NodeSelectorOperator, PreemptionModes, SecurityContextEnrichmentModes, ) + from ..config import config as mlconf from ..k8s_utils import ( generate_preemptible_node_selector_requirements, @@ -44,6 +44,7 @@ apply_kfp, get_gpu_from_resource_requirement, get_item_name, + get_k8s, get_resource_labels, set_named_item, verify_limits, @@ -135,6 +136,7 @@ def __init__( tolerations=None, preemption_mode=None, security_context=None, + clone_target_dir=None, ): super().__init__( command=command, @@ -148,6 +150,7 @@ def __init__( default_handler=default_handler, pythonpath=pythonpath, disable_auto_mount=disable_auto_mount, + clone_target_dir=clone_target_dir, ) self._volumes = {} self._volume_mounts = {} @@ -989,12 +992,14 @@ def set_envs(self, env_vars: dict = None, file_path: 
str = None): "must specify env_vars OR file_path" ) if file_path: - env_vars = dotenv.dotenv_values(file_path) - if None in env_vars.values(): - raise mlrun.errors.MLRunInvalidArgumentError( - "env file lines must be in the form key=value" - ) - + if os.path.isfile(file_path): + env_vars = dotenv.dotenv_values(file_path) + if None in env_vars.values(): + raise mlrun.errors.MLRunInvalidArgumentError( + "env file lines must be in the form key=value" + ) + else: + raise mlrun.errors.MLRunNotFoundError(f"{file_path} does not exist") for name, value in env_vars.items(): self.set_env(name, value) return self @@ -1113,7 +1118,7 @@ def with_preemption_mode(self, mode: typing.Union[PreemptionModes, str]): The default preemption mode is configurable in mlrun.mlconf.function_defaults.preemption_mode, by default it's set to **prevent** - :param mode: allow | constrain | prevent | none defined in :py:class:`~mlrun.api.schemas.PreemptionModes` + :param mode: allow | constrain | prevent | none defined in :py:class:`~mlrun.common.schemas.PreemptionModes` """ preemptible_mode = PreemptionModes(mode) self.spec.preemption_mode = preemptible_mode.value @@ -1122,7 +1127,7 @@ def with_security_context(self, security_context: k8s_client.V1SecurityContext): """ Set security context for the pod. 
For Iguazio we handle security context internally - - see mlrun.api.schemas.function.SecurityContextEnrichmentModes + see mlrun.common.schemas.function.SecurityContextEnrichmentModes Example: @@ -1155,7 +1160,7 @@ def get_default_priority_class_name(self): return mlconf.default_function_priority_class_name def _get_meta(self, runobj, unique=False): - namespace = self._get_k8s().resolve_namespace() + namespace = get_k8s().resolve_namespace() labels = get_resource_labels(self, runobj, runobj.spec.scrape_metrics) new_meta = k8s_client.V1ObjectMeta( @@ -1223,7 +1228,7 @@ def _add_k8s_secrets_to_spec( mlconf.secret_stores.kubernetes.global_function_env_secret_name ) if mlrun.config.is_running_as_api() and global_secret_name: - global_secrets = self._get_k8s().get_secret_data(global_secret_name) + global_secrets = get_k8s().get_secret_data(global_secret_name) for key, value in global_secrets.items(): env_var_name = ( SecretsStore.k8s_env_variable_name_for_secret(key) @@ -1245,10 +1250,10 @@ def _add_k8s_secrets_to_spec( logger.warning("No project provided. Cannot add k8s secrets") return - secret_name = self._get_k8s().get_project_secret_name(project_name) + secret_name = get_k8s().get_project_secret_name(project_name) # Not utilizing the same functionality from the Secrets crud object because this code also runs client-side # in the nuclio remote-dashboard flow, which causes dependency problems. 
- existing_secret_keys = self._get_k8s().get_project_secret_keys( + existing_secret_keys = get_k8s().get_project_secret_keys( project_name, filter_internal=True ) @@ -1282,7 +1287,7 @@ def _add_vault_params_to_spec(self, runobj=None, project=None): ) ) - project_vault_secret_name = self._get_k8s().get_project_vault_secret_name( + project_vault_secret_name = get_k8s().get_project_vault_secret_name( project_name, service_account_name ) if project_vault_secret_name is None: @@ -1453,7 +1458,7 @@ def get_sanitized_attribute(spec, attribute_name: str): if isinstance(attribute, dict): if attribute_config["not_sanitized_class"] != dict: raise mlrun.errors.MLRunInvalidArgumentTypeError( - f"expected to to be of type {attribute_config.get('not_sanitized_class')} but got dict" + f"expected to be of type {attribute_config.get('not_sanitized_class')} but got dict" ) if _resolve_if_type_sanitized(attribute_name, attribute): return attribute @@ -1463,7 +1468,7 @@ def get_sanitized_attribute(spec, attribute_name: str): ): if attribute_config["not_sanitized_class"] != list: raise mlrun.errors.MLRunInvalidArgumentTypeError( - f"expected to to be of type {attribute_config.get('not_sanitized_class')} but got list" + f"expected to be of type {attribute_config.get('not_sanitized_class')} but got list" ) if _resolve_if_type_sanitized(attribute_name, attribute[0]): return attribute diff --git a/mlrun/runtimes/remotesparkjob.py b/mlrun/runtimes/remotesparkjob.py index 46731f0c47d3..c262b6059f72 100644 --- a/mlrun/runtimes/remotesparkjob.py +++ b/mlrun/runtimes/remotesparkjob.py @@ -58,6 +58,7 @@ def __init__( tolerations=None, preemption_mode=None, security_context=None, + clone_target_dir=None, ): super().__init__( command=command, @@ -86,6 +87,7 @@ def __init__( tolerations=tolerations, preemption_mode=preemption_mode, security_context=security_context, + clone_target_dir=clone_target_dir, ) self.provider = provider diff --git a/mlrun/runtimes/serving.py b/mlrun/runtimes/serving.py 
index 15b4e09ed724..ddc76331fabe 100644 --- a/mlrun/runtimes/serving.py +++ b/mlrun/runtimes/serving.py @@ -21,7 +21,7 @@ from nuclio import KafkaTrigger import mlrun -import mlrun.api.schemas +import mlrun.common.schemas from ..datastore import parse_kafka_url from ..model import ObjectList @@ -143,6 +143,7 @@ def __init__( security_context=None, service_type=None, add_templated_ingress_host_mode=None, + clone_target_dir=None, ): super().__init__( @@ -182,6 +183,7 @@ def __init__( security_context=security_context, service_type=service_type, add_templated_ingress_host_mode=add_templated_ingress_host_mode, + clone_target_dir=clone_target_dir, ) self.models = models or {} @@ -317,7 +319,7 @@ def set_tracking( example:: # initialize a new serving function - serving_fn = mlrun.import_function("hub://v2_model_server", new_name="serving") + serving_fn = mlrun.import_function("hub://v2-model-server", new_name="serving") # apply model monitoring and set monitoring batch job to run every 3 hours tracking_policy = {'default_batch_intervals':"0 */3 * * *"} serving_fn.set_tracking(tracking_policy=tracking_policy) @@ -583,7 +585,7 @@ def deploy( project="", tag="", verbose=False, - auth_info: mlrun.api.schemas.AuthInfo = None, + auth_info: mlrun.common.schemas.AuthInfo = None, builder_env: dict = None, ): """deploy model serving function to a local/remote cluster diff --git a/mlrun/runtimes/sparkjob/abstract.py b/mlrun/runtimes/sparkjob/abstract.py index 4aeca02bcf13..269915b7251c 100644 --- a/mlrun/runtimes/sparkjob/abstract.py +++ b/mlrun/runtimes/sparkjob/abstract.py @@ -31,7 +31,6 @@ from mlrun.runtimes.constants import RunStates, SparkApplicationStates from ...execution import MLClientCtx -from ...k8s_utils import get_k8s_helper from ...model import RunObject from ...platforms.iguazio import mount_v3io, mount_v3iod from ...utils import ( @@ -45,7 +44,7 @@ from ..base import RunError, RuntimeClassMode from ..kubejob import KubejobRuntime from ..pod import KubeResourceSpec 
-from ..utils import get_item_name +from ..utils import get_item_name, get_k8s _service_account = "sparkapp" _sparkjob_template = { @@ -143,6 +142,7 @@ def __init__( tolerations=None, preemption_mode=None, security_context=None, + clone_target_dir=None, ): super().__init__( @@ -172,6 +172,7 @@ def __init__( tolerations=tolerations, preemption_mode=preemption_mode, security_context=security_context, + clone_target_dir=clone_target_dir, ) self._driver_resources = self.enrich_resources_with_default_pod_resources( @@ -568,8 +569,10 @@ def _run(self, runobj: RunObject, execution: MLClientCtx): if self.spec.command: if "://" not in self.spec.command: + workdir = self._resolve_workdir() self.spec.command = "local://" + os.path.join( - self.spec.workdir or "", self.spec.command + workdir or "", + self.spec.command, ) update_in(job, "spec.mainApplicationFile", self.spec.command) @@ -588,7 +591,7 @@ def _submit_spark_job( code=None, ): namespace = meta.namespace - k8s = self._get_k8s() + k8s = get_k8s() namespace = k8s.resolve_namespace(namespace) if code: k8s_config_map = client.V1ConfigMap() @@ -632,7 +635,7 @@ def _submit_spark_job( raise RunError("Exception when creating SparkJob") from exc def get_job(self, name, namespace=None): - k8s = self._get_k8s() + k8s = get_k8s() namespace = k8s.resolve_namespace(namespace) try: resp = k8s.crdapi.get_namespaced_custom_object( @@ -801,43 +804,26 @@ def with_restart_policy( ) def with_source_archive( - self, source, workdir=None, handler=None, pull_at_runtime=True + self, source, workdir=None, handler=None, pull_at_runtime=True, target_dir=None ): """load the code from git/tar/zip archive at runtime or build - :param source: valid path to git, zip, or tar file, e.g. - git://github.com/mlrun/something.git - http://some/url/file.zip - :param handler: default function handler - :param workdir: working dir relative to the archive root or absolute (e.g. './subdir') + :param source: valid path to git, zip, or tar file, e.g. 
+ git://github.com/mlrun/something.git + http://some/url/file.zip + :param handler: default function handler + :param workdir: working dir relative to the archive root (e.g. './subdir') or absolute to the image root :param pull_at_runtime: not supported for spark runtime, must be False + :param target_dir: target dir on runtime pod for repo clone / archive extraction """ if pull_at_runtime: raise mlrun.errors.MLRunInvalidArgumentError( "pull_at_runtime is not supported for spark runtime, use pull_at_runtime=False" ) - super().with_source_archive(source, workdir, handler, pull_at_runtime) - - def get_pods(self, name=None, namespace=None, driver=False): - k8s = self._get_k8s() - namespace = k8s.resolve_namespace(namespace) - selector = "mlrun/class=spark" - if name: - selector += f",sparkoperator.k8s.io/app-name={name}" - if driver: - selector += ",spark-role=driver" - pods = k8s.list_pods(selector=selector, namespace=namespace) - if pods: - return {p.metadata.name: p.status.phase for p in pods} - - def _get_driver(self, name, namespace=None): - pods = self.get_pods(name, namespace, driver=True) - if not pods: - logger.error("no pod matches that job name") - return - _ = self._get_k8s() - return list(pods.items())[0] + super().with_source_archive( + source, workdir, handler, pull_at_runtime, target_dir + ) def is_deployed(self): if ( @@ -961,15 +947,14 @@ def _delete_extra_resources( uid = crd_dict["metadata"].get("labels", {}).get("mlrun/uid", None) uids.append(uid) - k8s_helper = get_k8s_helper() - config_maps = k8s_helper.v1api.list_namespaced_config_map( + config_maps = get_k8s().v1api.list_namespaced_config_map( namespace, label_selector=label_selector ) for config_map in config_maps.items: try: uid = config_map.metadata.labels.get("mlrun/uid", None) if force or uid in uids: - k8s_helper.v1api.delete_namespaced_config_map( + get_k8s().v1api.delete_namespaced_config_map( config_map.metadata.name, namespace ) logger.info(f"Deleted config map: 
{config_map.metadata.name}") diff --git a/mlrun/runtimes/sparkjob/spark3job.py b/mlrun/runtimes/sparkjob/spark3job.py index 0f9e9c5a6588..0500937e1244 100644 --- a/mlrun/runtimes/sparkjob/spark3job.py +++ b/mlrun/runtimes/sparkjob/spark3job.py @@ -16,7 +16,7 @@ import kubernetes.client -import mlrun.api.schemas.function +import mlrun.common.schemas.function import mlrun.errors import mlrun.runtimes.pod @@ -100,6 +100,7 @@ def __init__( driver_cores=None, executor_cores=None, security_context=None, + clone_target_dir=None, ): super().__init__( @@ -129,6 +130,7 @@ def __init__( tolerations=tolerations, preemption_mode=preemption_mode, security_context=security_context, + clone_target_dir=clone_target_dir, ) self.driver_resources = driver_resources or {} @@ -518,7 +520,7 @@ def with_executor_node_selection( self.spec.executor_tolerations = tolerations def with_preemption_mode( - self, mode: typing.Union[mlrun.api.schemas.function.PreemptionModes, str] + self, mode: typing.Union[mlrun.common.schemas.function.PreemptionModes, str] ): """ Use with_driver_preemption_mode / with_executor_preemption_mode to setup preemption_mode for spark operator @@ -529,7 +531,7 @@ def with_preemption_mode( ) def with_driver_preemption_mode( - self, mode: typing.Union[mlrun.api.schemas.function.PreemptionModes, str] + self, mode: typing.Union[mlrun.common.schemas.function.PreemptionModes, str] ): """ Preemption mode controls whether the spark driver can be scheduled on preemptible nodes. 
@@ -545,13 +547,13 @@ def with_driver_preemption_mode( The default preemption mode is configurable in mlrun.mlconf.function_defaults.preemption_mode, by default it's set to **prevent** - :param mode: allow | constrain | prevent | none defined in :py:class:`~mlrun.api.schemas.PreemptionModes` + :param mode: allow | constrain | prevent | none defined in :py:class:`~mlrun.common.schemas.PreemptionModes` """ - preemption_mode = mlrun.api.schemas.function.PreemptionModes(mode) + preemption_mode = mlrun.common.schemas.function.PreemptionModes(mode) self.spec.driver_preemption_mode = preemption_mode.value def with_executor_preemption_mode( - self, mode: typing.Union[mlrun.api.schemas.function.PreemptionModes, str] + self, mode: typing.Union[mlrun.common.schemas.function.PreemptionModes, str] ): """ Preemption mode controls whether the spark executor can be scheduled on preemptible nodes. @@ -567,9 +569,9 @@ def with_executor_preemption_mode( The default preemption mode is configurable in mlrun.mlconf.function_defaults.preemption_mode, by default it's set to **prevent** - :param mode: allow | constrain | prevent | none defined in :py:class:`~mlrun.api.schemas.PreemptionModes` + :param mode: allow | constrain | prevent | none defined in :py:class:`~mlrun.common.schemas.PreemptionModes` """ - preemption_mode = mlrun.api.schemas.function.PreemptionModes(mode) + preemption_mode = mlrun.common.schemas.function.PreemptionModes(mode) self.spec.executor_preemption_mode = preemption_mode.value def with_security_context( diff --git a/mlrun/runtimes/utils.py b/mlrun/runtimes/utils.py index 2a02125b1372..372d7a365c8e 100644 --- a/mlrun/runtimes/utils.py +++ b/mlrun/runtimes/utils.py @@ -24,17 +24,17 @@ from kubernetes import client import mlrun -import mlrun.builder +import mlrun.api.utils.builder +import mlrun.common.constants import mlrun.utils.regex from mlrun.api.utils.clients import nuclio -from mlrun.db import get_run_db from mlrun.errors import err_to_str from 
mlrun.frameworks.parallel_coordinates import gen_pcp_plot -from mlrun.k8s_utils import get_k8s_helper from mlrun.runtimes.constants import MPIJobCRDVersions from ..artifacts import TableArtifact -from ..config import config +from ..config import config, is_running_as_api +from ..k8s_utils import is_running_inside_kubernetes_cluster from ..utils import get_in, helpers, logger, verify_field_regex from .generators import selector @@ -69,19 +69,21 @@ def set(self, context): # if not specified, try resolving it according to the mpi-operator, otherwise set to default # since this is a heavy operation (sending requests to k8s/API), and it's unlikely that the crd version # will change in any context - cache it -def resolve_mpijob_crd_version(api_context=False): +def resolve_mpijob_crd_version(): global cached_mpijob_crd_version if not cached_mpijob_crd_version: # config override everything + # on client side, expecting it to get enriched from the API through the client-spec mpijob_crd_version = config.mpijob_crd_version if not mpijob_crd_version: - in_k8s_cluster = get_k8s_helper( - silent=True - ).is_running_inside_kubernetes_cluster() - if in_k8s_cluster: - k8s_helper = get_k8s_helper() + in_k8s_cluster = is_running_inside_kubernetes_cluster() + + if in_k8s_cluster and is_running_as_api(): + import mlrun.api.utils.singletons.k8s + + k8s_helper = mlrun.api.utils.singletons.k8s.get_k8s_helper() namespace = k8s_helper.resolve_namespace() # try resolving according to mpi-operator that's running @@ -93,13 +95,8 @@ def resolve_mpijob_crd_version(api_context=False): mpijob_crd_version = mpi_operator_pod.metadata.labels.get( "crd-version" ) - elif not in_k8s_cluster and not api_context: - # connect will populate the config from the server config - # TODO: something nicer - get_run_db() - mpijob_crd_version = config.mpijob_crd_version - # If resolution failed simply use default + # backoff to use default if wasn't resolved in API if not mpijob_crd_version: mpijob_crd_version = 
MPIJobCRDVersions.default() @@ -182,22 +179,6 @@ def log_std(db, runobj, out, err="", skip=False, show=True, silent=False): raise RunError(err) -class AsyncLogWriter: - def __init__(self, db, runobj): - self.db = db - self.uid = runobj.metadata.uid - self.project = runobj.metadata.project or "" - self.iter = runobj.metadata.iteration - - def write(self, data): - if self.db: - self.db.store_log(self.uid, self.project, data, append=True) - - def flush(self): - # todo: verify writes are large enough, if not cache and use flush - pass - - def add_code_metadata(path=""): if path: if "://" in path: @@ -232,6 +213,19 @@ def add_code_metadata(path=""): return None +def get_k8s(): + """ + Get the k8s helper object + :return: k8s helper object or None if not running as API + """ + if is_running_as_api(): + import mlrun.api.utils.singletons.k8s + + return mlrun.api.utils.singletons.k8s.get_k8s_helper() + + return None + + def set_if_none(struct, key, value): if not struct.get(key): struct[key] = value @@ -348,7 +342,11 @@ def generate_function_image_name(project: str, name: str, tag: str) -> str: _, repository = helpers.get_parsed_docker_registry() repository = helpers.get_docker_repository_or_default(repository) return fill_function_image_name_template( - mlrun.builder.IMAGE_NAME_ENRICH_REGISTRY_PREFIX, repository, project, name, tag + mlrun.common.constants.IMAGE_NAME_ENRICH_REGISTRY_PREFIX, + repository, + project, + name, + tag, ) @@ -373,7 +371,7 @@ def resolve_function_target_image_registries_to_enforce_prefix(): registry, repository = helpers.get_parsed_docker_registry() repository = helpers.get_docker_repository_or_default(repository) return [ - f"{mlrun.builder.IMAGE_NAME_ENRICH_REGISTRY_PREFIX}{repository}/", + f"{mlrun.common.constants.IMAGE_NAME_ENRICH_REGISTRY_PREFIX}{repository}/", f"{registry}/{repository}/", ] diff --git a/mlrun/secrets.py b/mlrun/secrets.py index 89237010ad41..eff375ff0650 100644 --- a/mlrun/secrets.py +++ b/mlrun/secrets.py @@ -16,7 +16,7 @@ 
from os import environ, getenv from typing import Callable, Dict, Optional, Union -from .utils import AzureVaultStore, VaultStore, list2dict +from .utils import AzureVaultStore, list2dict class SecretsStore: @@ -26,7 +26,6 @@ def __init__(self): # for example from Vault, and when adding their source they will be retrieved from the external source. self._hidden_sources = [] self._hidden_secrets = {} - self.vault = VaultStore() @classmethod def from_list(cls, src_list: list): @@ -60,21 +59,20 @@ def add_source(self, kind, source="", prefix=""): for key in source.split(","): k = key.strip() self._secrets[prefix + k] = environ.get(k) - - elif kind == "vault": - if isinstance(source, str): - source = literal_eval(source) - if not isinstance(source, dict): - raise ValueError("vault secrets must be of type dict") - - for key, value in self.vault.get_secrets( - source["secrets"], - user=source.get("user"), - project=source.get("project"), - ).items(): - self._hidden_secrets[prefix + key] = value - self._hidden_sources.append({"kind": kind, "source": source}) - + # TODO: Vault: uncomment when vault returns to be relevant + # elif kind == "vault": + # if isinstance(source, str): + # source = literal_eval(source) + # if not isinstance(source, dict): + # raise ValueError("vault secrets must be of type dict") + # + # for key, value in self.vault.get_secrets( + # source["secrets"], + # user=source.get("user"), + # project=source.get("project"), + # ).items(): + # self._hidden_secrets[prefix + key] = value + # self._hidden_sources.append({"kind": kind, "source": source}) elif kind == "azure_vault": if isinstance(source, str): source = literal_eval(source) diff --git a/mlrun/serving/__init__.py b/mlrun/serving/__init__.py index 5a9f14f22f7d..24f02af98f21 100644 --- a/mlrun/serving/__init__.py +++ b/mlrun/serving/__init__.py @@ -21,10 +21,11 @@ "TaskStep", "RouterStep", "QueueStep", + "ErrorStep", ] from .routers import ModelRouter, VotingEnsemble # noqa from .server import 
GraphContext, GraphServer, create_graph_server # noqa -from .states import QueueStep, RouterStep, TaskStep # noqa +from .states import ErrorStep, QueueStep, RouterStep, TaskStep # noqa from .v1_serving import MLModelServer, new_v1_model_server # noqa from .v2_serving import V2ModelServer # noqa diff --git a/mlrun/serving/routers.py b/mlrun/serving/routers.py index 1685387b0062..b870e4bda203 100644 --- a/mlrun/serving/routers.py +++ b/mlrun/serving/routers.py @@ -24,18 +24,12 @@ import numpy as np import mlrun +import mlrun.common.model_monitoring +import mlrun.common.schemas import mlrun.utils.model_monitoring from mlrun.utils import logger, now_date, parse_versioned_object_uri -from ..api.schemas import ( - ModelEndpoint, - ModelEndpointMetadata, - ModelEndpointSpec, - ModelEndpointStatus, - ModelMonitoringMode, -) from ..config import config -from ..utils.model_monitoring import EndpointType from .server import GraphServer from .utils import RouterToDict, _extract_input_data, _update_result_body from .v2_serving import _ModelLogPusher @@ -402,12 +396,14 @@ def _init_pool( step._parent = None if step._object: step._object.context = None + if hasattr(step._object, "_kwargs"): + step._object._kwargs["graph_step"] = None routes[key] = step executor_class = concurrent.futures.ProcessPoolExecutor self._pool = executor_class( max_workers=len(self.routes), initializer=ParallelRun.init_pool, - initargs=(server, routes, id(self)), + initargs=(server, routes), ) elif self.executor_type == ParallelRunnerModes.thread: executor_class = concurrent.futures.ThreadPoolExecutor @@ -422,7 +418,7 @@ def _shutdown_pool(self): if self._pool is not None: if self.executor_type == ParallelRunnerModes.process: global local_routes - local_routes.pop(id(self)) + del local_routes self._pool.shutdown() self._pool = None @@ -446,7 +442,7 @@ def _parallel_run(self, event: dict): for route in self.routes.keys(): if self.executor_type == ParallelRunnerModes.process: future = executor.submit( - 
ParallelRun._wrap_step, route, id(self), copy.copy(event) + ParallelRun._wrap_step, route, copy.copy(event) ) elif self.executor_type == ParallelRunnerModes.thread: step = self.routes[route] @@ -470,25 +466,22 @@ def _parallel_run(self, event: dict): return results @staticmethod - def init_pool(server_spec, routes, object_id): + def init_pool(server_spec, routes): server = mlrun.serving.GraphServer.from_dict(server_spec) server.init_states(None, None) global local_routes - if object_id in local_routes: - return for route in routes.values(): route.context = server.context if route._object: route._object.context = server.context - local_routes[object_id] = routes + local_routes = routes @staticmethod - def _wrap_step(route, object_id, event): + def _wrap_step(route, event): global local_routes - routes = local_routes.get(object_id, None).copy() - if routes is None: + if local_routes is None: return None, None - return route, routes[route].run(event) + return route, local_routes[route].run(event) @staticmethod def _wrap_method(route, handler, event): @@ -1043,7 +1036,7 @@ def _init_endpoint_record( versioned_model_name = f"{voting_ensemble.name}:latest" # Generating model endpoint ID based on function uri and model version - endpoint_uid = mlrun.utils.model_monitoring.create_model_endpoint_id( + endpoint_uid = mlrun.common.model_monitoring.create_model_endpoint_uid( function_uri=graph_server.function_uri, versioned_model=versioned_model_name ).uid @@ -1061,33 +1054,35 @@ def _init_endpoint_record( if hasattr(c, "endpoint_uid"): children_uids.append(c.endpoint_uid) - model_endpoint = ModelEndpoint( - metadata=ModelEndpointMetadata(project=project, uid=endpoint_uid), - spec=ModelEndpointSpec( - function_uri=graph_server.function_uri, - model=versioned_model_name, - model_class=voting_ensemble.__class__.__name__, - stream_path=config.model_endpoint_monitoring.store_prefixes.default.format( - project=project, kind="stream" - ), - active=True, - 
monitoring_mode=ModelMonitoringMode.enabled - if voting_ensemble.context.server.track_models - else ModelMonitoringMode.disabled, - ), - status=ModelEndpointStatus( - children=list(voting_ensemble.routes.keys()), - endpoint_type=EndpointType.ROUTER, - children_uids=children_uids, + model_endpoint = mlrun.common.schemas.ModelEndpoint( + metadata=mlrun.common.schemas.ModelEndpointMetadata( + project=project, uid=endpoint_uid + ), + spec=mlrun.common.schemas.ModelEndpointSpec( + function_uri=graph_server.function_uri, + model=versioned_model_name, + model_class=voting_ensemble.__class__.__name__, + stream_path=config.model_endpoint_monitoring.store_prefixes.default.format( + project=project, kind="stream" ), - ) + active=True, + monitoring_mode=mlrun.common.model_monitoring.ModelMonitoringMode.enabled + if voting_ensemble.context.server.track_models + else mlrun.common.model_monitoring.ModelMonitoringMode.disabled, + ), + status=mlrun.common.schemas.ModelEndpointStatus( + children=list(voting_ensemble.routes.keys()), + endpoint_type=mlrun.common.model_monitoring.EndpointType.ROUTER, + children_uids=children_uids, + ), + ) db = mlrun.get_run_db() db.create_model_endpoint( project=project, endpoint_id=model_endpoint.metadata.uid, - model_endpoint=model_endpoint, + model_endpoint=model_endpoint.dict(), ) # Update model endpoint children type @@ -1095,7 +1090,9 @@ def _init_endpoint_record( current_endpoint = db.get_model_endpoint( project=project, endpoint_id=model_endpoint ) - current_endpoint.status.endpoint_type = EndpointType.LEAF_EP + current_endpoint.status.endpoint_type = ( + mlrun.common.model_monitoring.EndpointType.LEAF_EP + ) db.create_model_endpoint( project=project, endpoint_id=model_endpoint, diff --git a/mlrun/serving/server.py b/mlrun/serving/server.py index 1bd6e048881b..0f76335cd741 100644 --- a/mlrun/serving/server.py +++ b/mlrun/serving/server.py @@ -18,12 +18,13 @@ import json import os import socket -import sys import traceback import uuid from 
typing import Optional, Union import mlrun +import mlrun.utils.model_monitoring +from mlrun.common.model_monitoring import FileTargetKind from mlrun.config import config from mlrun.errors import err_to_str from mlrun.secrets import SecretsStore @@ -32,38 +33,52 @@ from ..datastore.store_resources import ResourceCache from ..errors import MLRunInvalidArgumentError from ..model import ModelObj -from ..utils import create_logger, get_caller_globals, parse_versioned_object_uri +from ..utils import get_caller_globals, parse_versioned_object_uri from .states import RootFlowStep, RouterStep, get_function, graph_root_setter -from .utils import event_id_key, event_path_key +from .utils import ( + event_id_key, + event_path_key, + legacy_event_id_key, + legacy_event_path_key, +) class _StreamContext: - def __init__(self, enabled, parameters, function_uri): + """Handles the stream context for the events stream process. Includes the configuration for the output stream + that will be used for pushing the events from the nuclio model serving function""" + + def __init__(self, enabled: bool, parameters: dict, function_uri: str): + + """ + Initialize _StreamContext object. + :param enabled: A boolean indication for applying the stream context + :param parameters: Dictionary of optional parameters, such as `log_stream` and `stream_args`. Note that these + parameters might be relevant to the output source such as `kafka_bootstrap_servers` if + the output source is from type Kafka. 
+ :param function_uri: Full value of the function uri, usually it's / + """ + self.enabled = False self.hostname = socket.gethostname() self.function_uri = function_uri self.output_stream = None self.stream_uri = None + log_stream = parameters.get(FileTargetKind.LOG_STREAM, "") - log_stream = parameters.get("log_stream", "") - stream_uri = config.model_endpoint_monitoring.store_prefixes.default - - if ((enabled and stream_uri) or log_stream) and function_uri: + if (enabled or log_stream) and function_uri: self.enabled = True - project, _, _, _ = parse_versioned_object_uri( function_uri, config.default_project ) - stream_uri = stream_uri.format(project=project, kind="stream") + stream_uri = mlrun.utils.model_monitoring.get_stream_path(project=project) if log_stream: + # Update the stream path to the log stream value stream_uri = log_stream.format(project=project) stream_args = parameters.get("stream_args", {}) - self.stream_uri = stream_uri - self.output_stream = get_stream_pusher(stream_uri, **stream_args) @@ -241,10 +256,18 @@ def run(self, event, context=None, get_body=False, extra_args=None): context = context or server_context event.content_type = event.content_type or self.default_content_type or "" if event.headers: - if event_id_key in event.headers: - event.id = event.headers.get(event_id_key) - if event_path_key in event.headers: - event.path = event.headers.get(event_path_key) + # TODO: remove old event id and path keys in 1.6.0 + if event_id_key in event.headers or legacy_event_id_key in event.headers: + event.id = event.headers.get(event_id_key) or event.headers.get( + legacy_event_id_key + ) + if ( + event_path_key in event.headers + or legacy_event_path_key in event.headers + ): + event.path = event.headers.get(event_path_key) or event.headers.get( + legacy_event_path_key + ) if isinstance(event.body, (str, bytes)) and ( not event.content_type or event.content_type in ["json", "application/json"] @@ -445,7 +468,7 @@ def __init__(self, level="info", 
logger=None, server=None, nuclio_context=None): self.Response = nuclio_context.Response self.worker_id = nuclio_context.worker_id elif not logger: - self.logger = create_logger(level, "human", "flow", sys.stdout) + self.logger = mlrun.utils.helpers.logger self._server = server self.current_function = None diff --git a/mlrun/serving/states.py b/mlrun/serving/states.py index 2cef1a6f00b8..46fd01a84816 100644 --- a/mlrun/serving/states.py +++ b/mlrun/serving/states.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ["TaskStep", "RouterStep", "RootFlowStep"] +__all__ = ["TaskStep", "RouterStep", "RootFlowStep", "ErrorStep"] import os import pathlib @@ -49,6 +49,7 @@ class StepKinds: queue = "queue" choice = "choice" root = "root" + error_step = "error_step" _task_step_fields = [ @@ -134,11 +135,82 @@ def after_step(self, *after, append=True): self.after.append(name) return self - def error_handler(self, step_name: str = None): - """set error handler step (on failure/raise of this step)""" - if not step_name: - raise MLRunInvalidArgumentError("Must specify step_name") - self.on_error = step_name + def error_handler( + self, + name: str = None, + class_name=None, + handler=None, + before=None, + function=None, + full_event: bool = None, + input_path: str = None, + result_path: str = None, + **class_args, + ): + """set error handler on a step or the entire graph (to be executed on failure/raise) + + When setting the error_handler on the graph object, the graph completes after the error handler execution. + + example: + in the below example, an 'error_catcher' step is set as the error_handler of the 'raise' step: + in case of error/raise in 'raise' step, the handle_error will be run. after that, + the 'echo' step will be run. 
+ graph = function.set_topology('flow', engine='async') + graph.to(name='raise', handler='raising_step')\ + .error_handler(name='error_catcher', handler='handle_error', full_event=True, before='echo') + graph.add_step(name="echo", handler='echo', after="raise").respond() + + :param name: unique name (and path) for the error handler step, default is class name + :param class_name: class name or step object to build the step from + the error handler step is derived from task step (ie no router/queue functionally) + :param handler: class/function handler to invoke on run/event + :param before: string or list of next step(s) names that will run after this step. + the `before` param must not specify upstream steps as it will cause a loop. + if `before` is not specified, the graph will complete after the error handler execution. + :param function: function this step should run in + :param full_event: this step accepts the full event (not just the body) + :param input_path: selects the key/path in the event to use as input to the step + this requires that the event body will behave like a dict, for example: + event: {"data": {"a": 5, "b": 7}}, input_path="data.b" means the step will + receive 7 as input + :param result_path: selects the key/path in the event to write the results to + this requires that the event body will behave like a dict, for example: + event: {"x": 5} , result_path="y" means the output of the step will be written + to event["y"] resulting in {"x": 5, "y": } + :param class_args: class init arguments + + """ + if not (class_name or handler): + raise MLRunInvalidArgumentError("class_name or handler must be provided") + if isinstance(self, RootFlowStep) and before: + raise MLRunInvalidArgumentError( + "`before` arg can't be specified for graph error handler" + ) + + name = get_name(name, class_name) + step = ErrorStep( + class_name, + class_args, + handler, + name=name, + function=function, + full_event=full_event, + input_path=input_path, + 
result_path=result_path, + ) + self.on_error = name + before = [before] if isinstance(before, str) else before + step.before = before or [] + step.base_step = self.name + if hasattr(self, "_parent") and self._parent: + # when self is a step + step = self._parent._steps.update(name, step) + step.set_parent(self._parent) + else: + # when self is the graph + step = self._steps.update(name, step) + step.set_parent(self) + return self def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs): @@ -186,10 +258,11 @@ def _log_error(self, event, err, **kwargs): def _call_error_handler(self, event, err, **kwargs): """call the error handler if exist""" - if self._on_error_handler: - event.error = err_to_str(err) - event.origin_state = self.fullname - return self._on_error_handler(event) + if not event.error: + event.error = {} + event.error[self.name] = err_to_str(err) + event.origin_state = self.fullname + return self._on_error_handler(event) def path_to_step(self, path: str): """return step object from step relative/fullname""" @@ -327,6 +400,7 @@ def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwar args = signature(self._handler).parameters if args and "context" in list(args.keys()): self._inject_context = True + self._set_error_handler() return self._class_object, self.class_name = self.get_step_class_object( @@ -464,14 +538,23 @@ def run(self, event, *args, **kwargs): ) event.body = _update_result_body(self.result_path, event.body, result) except Exception as exc: - self._log_error(event, exc) - handled = self._call_error_handler(event, exc) - if not handled: + if self._on_error_handler: + self._log_error(event, exc) + result = self._call_error_handler(event, exc) + event.body = _update_result_body(self.result_path, event.body, result) + else: raise exc - event.terminated = True return event +class ErrorStep(TaskStep): + """error execution step, runs a class or handler""" + + kind = "error_step" + _dict_fields = 
_task_step_fields + ["before", "base_step"] + _default_class = "" + + class RouterStep(TaskStep): """router step, implement routing logic for running child routes""" @@ -824,6 +907,7 @@ def __iter__(self): def init_object(self, context, namespace, mode="sync", reset=False, **extra_kwargs): """initialize graph objects and classes""" self.context = context + self._insert_all_error_handlers() self.check_and_process_graph() for step in self._steps.values(): @@ -866,7 +950,11 @@ def has_loop(step, previous): responders = [] for step in self._steps.values(): - if hasattr(step, "responder") and step.responder: + if ( + hasattr(step, "responder") + and step.responder + and step.kind != "error_step" + ): responders.append(step.name) if step.on_error and step.on_error in start_steps: start_steps.remove(step.on_error) @@ -979,6 +1067,10 @@ def process_step(state, step, root): # never set a step as its own error handler if step != error_step: step.async_object.set_recovery_step(error_step.async_object) + for next_step in error_step.next or []: + next_state = self[next_step] + if next_state.async_object and error_step.async_object: + error_step.async_object.to(next_state.async_object) self._controller = source.run() @@ -1059,15 +1151,22 @@ def run(self, event, *args, **kwargs): try: event = next_obj.run(event, *args, **kwargs) except Exception as exc: - self._log_error(event, exc, failed_step=next_obj.name) - handled = self._call_error_handler(event, exc) - if not handled: + if self._on_error_handler: + self._log_error(event, exc, failed_step=next_obj.name) + event.body = self._call_error_handler(event, exc) + event.terminated = True + return event + else: raise exc - event.terminated = True - return event if hasattr(event, "terminated") and event.terminated: return event + if ( + hasattr(event, "error") + and isinstance(event.error, dict) + and next_obj.name in event.error + ): + next_obj = self._steps[next_obj.on_error] next = next_obj.next if next and len(next) > 1: raise 
GraphError( @@ -1103,6 +1202,33 @@ def plot(self, filename=None, format=None, source=None, targets=None, **kw): **kw, ) + def _insert_all_error_handlers(self): + """ + insert all error steps to the graph + run after deployment + """ + for name, step in self._steps.items(): + if step.kind == "error_step": + self._insert_error_step(name, step) + + def _insert_error_step(self, name, step): + """ + insert error step to the graph + run after deployment + """ + if not step.before and not any( + [step.name in other_step.after for other_step in self._steps.values()] + ): + step.responder = True + return + + for step_name in step.before: + if step_name not in self._steps.keys(): + raise MLRunInvalidArgumentError( + f"cant set before, there is no step named {step_name}" + ) + self[step_name].after_step(name) + class RootFlowStep(FlowStep): """root flow step""" @@ -1116,6 +1242,7 @@ class RootFlowStep(FlowStep): "router": RouterStep, "flow": FlowStep, "queue": QueueStep, + "error_step": ErrorStep, } @@ -1155,15 +1282,8 @@ def _add_graphviz_flow( _add_graphviz_router(sg, child) else: graph.node(child.fullname, label=child.name, shape=child.get_shape()) - after = child.after or [] - for item in after: - previous_object = step[item] - kw = ( - {"ltail": "cluster_" + previous_object.fullname} - if previous_object.kind == StepKinds.router - else {} - ) - graph.edge(previous_object.fullname, child.fullname, **kw) + _add_edges(child.after or [], step, graph, child) + _add_edges(getattr(child, "before", []), step, graph, child, after=False) if child.on_error: graph.edge(child.fullname, child.on_error, style="dashed") @@ -1183,6 +1303,18 @@ def _add_graphviz_flow( graph.edge(last_step, target.fullname) +def _add_edges(items, step, graph, child, after=True): + for item in items: + next_or_prev_object = step[item] + kw = {} + if next_or_prev_object.kind == StepKinds.router: + kw["ltail"] = f"cluster_{next_or_prev_object.fullname}" + if after: + graph.edge(next_or_prev_object.fullname, 
child.fullname, **kw) + else: + graph.edge(child.fullname, next_or_prev_object.fullname, **kw) + + def _generate_graphviz( step, renderer, @@ -1355,7 +1487,7 @@ def _init_async_objects(context, steps): endpoint, stream_path = parse_path(step.path) stream_path = stream_path.strip("/") step._async_object = storey.StreamTarget( - storey.V3ioDriver(endpoint), + storey.V3ioDriver(endpoint or config.v3io_api), stream_path, context=context, **options, diff --git a/mlrun/serving/utils.py b/mlrun/serving/utils.py index 6f4917f79fca..ee8b9034f91f 100644 --- a/mlrun/serving/utils.py +++ b/mlrun/serving/utils.py @@ -16,8 +16,14 @@ from mlrun.utils import get_in, update_in -event_id_key = "MLRUN_EVENT_ID" -event_path_key = "MLRUN_EVENT_PATH" +# headers keys with underscore are getting ignored by werkzeug https://github.com/pallets/werkzeug/pull/2622 +# to avoid conflicts with WGSI which converts all header keys to uppercase with underscores. +# more info https://github.com/benoitc/gunicorn/issues/2799, this comment can be removed once old keys are removed +event_id_key = "MLRUN-EVENT-ID" +event_path_key = "MLRUN-EVENT-PATH" +# TODO: remove these keys in 1.6.0 +legacy_event_id_key = "MLRUN_EVENT_ID" +legacy_event_path_key = "MLRUN_EVENT_PATH" def _extract_input_data(input_path, body): diff --git a/mlrun/serving/v2_serving.py b/mlrun/serving/v2_serving.py index 14e79a336ef0..468a521afdc8 100644 --- a/mlrun/serving/v2_serving.py +++ b/mlrun/serving/v2_serving.py @@ -11,23 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ import threading import time import traceback from typing import Dict, Union import mlrun -from mlrun.api.schemas import ( - ModelEndpoint, - ModelEndpointMetadata, - ModelEndpointSpec, - ModelEndpointStatus, - ModelMonitoringMode, -) +import mlrun.common.model_monitoring +import mlrun.common.schemas from mlrun.artifacts import ModelArtifact # noqa: F401 from mlrun.config import config from mlrun.utils import logger, now_date, parse_versioned_object_uri -from mlrun.utils.model_monitoring import EndpointType from .server import GraphServer from .utils import StepToDict, _extract_input_data, _update_result_body @@ -265,11 +260,20 @@ def do_event(self, event, *args, **kwargs): # get model health operation setattr(event, "terminated", True) if self.ready: - event.body = self.context.Response() + # Generate a response, confirming that the model is ready + event.body = self.context.Response( + status_code=200, + body=bytes( + f"Model {self.name} is ready (event_id = {event_id})", + encoding="utf-8", + ), + ) + else: event.body = self.context.Response( status_code=408, body=b"model not ready" ) + return event elif op == "" and event.method == "GET": @@ -487,7 +491,7 @@ def _init_endpoint_record( versioned_model_name = f"{model.name}:latest" # Generating model endpoint ID based on function uri and model version - uid = mlrun.utils.model_monitoring.create_model_endpoint_id( + uid = mlrun.common.model_monitoring.create_model_endpoint_uid( function_uri=graph_server.function_uri, versioned_model=versioned_model_name ).uid @@ -499,11 +503,11 @@ def _init_endpoint_record( logger.info("Creating a new model endpoint record", endpoint_id=uid) try: - model_endpoint = ModelEndpoint( - metadata=ModelEndpointMetadata( + model_endpoint = mlrun.common.schemas.ModelEndpoint( + metadata=mlrun.common.schemas.ModelEndpointMetadata( project=project, labels=model.labels, uid=uid ), - spec=ModelEndpointSpec( + spec=mlrun.common.schemas.ModelEndpointSpec( 
function_uri=graph_server.function_uri, model=versioned_model_name, model_class=model.__class__.__name__, @@ -512,18 +516,21 @@ def _init_endpoint_record( project=project, kind="stream" ), active=True, - monitoring_mode=ModelMonitoringMode.enabled + monitoring_mode=mlrun.common.model_monitoring.ModelMonitoringMode.enabled if model.context.server.track_models - else ModelMonitoringMode.disabled, + else mlrun.common.model_monitoring.ModelMonitoringMode.disabled, + ), + status=mlrun.common.schemas.ModelEndpointStatus( + endpoint_type=mlrun.common.model_monitoring.EndpointType.NODE_EP ), - status=ModelEndpointStatus(endpoint_type=EndpointType.NODE_EP), ) db = mlrun.get_run_db() + db.create_model_endpoint( project=project, - endpoint_id=model_endpoint.metadata.uid, - model_endpoint=model_endpoint, + endpoint_id=uid, + model_endpoint=model_endpoint.dict(), ) except Exception as e: diff --git a/mlrun/utils/__init__.py b/mlrun/utils/__init__.py index ff2582c48454..854be7e0e41e 100644 --- a/mlrun/utils/__init__.py +++ b/mlrun/utils/__init__.py @@ -18,4 +18,3 @@ from .helpers import * # noqa from .http import * # noqa from .logger import * # noqa -from .vault import * # noqa diff --git a/mlrun/utils/condition_evaluator.py b/mlrun/utils/condition_evaluator.py new file mode 100644 index 000000000000..2dbf620f8f0e --- /dev/null +++ b/mlrun/utils/condition_evaluator.py @@ -0,0 +1,65 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import multiprocessing +import typing + +from mlrun.utils import logger + + +def evaluate_condition_in_separate_process( + condition: str, context: typing.Dict[str, typing.Any], timeout: int = 5 +): + + if not condition: + return True + + receiver, sender = multiprocessing.Pipe() + p = multiprocessing.Process( + target=_evaluate_condition_wrapper, + args=(sender, condition, context), + ) + p.start() + if receiver.poll(timeout): + result = receiver.recv() + p.join() + return result + else: + p.kill() + logger.warning( + f"Condition evaluation timed out after {timeout} seconds. Ignoring condition", + condition=condition, + ) + return True + + +def _evaluate_condition_wrapper( + connection, condition: str, context: typing.Dict[str, typing.Any] +): + connection.send(_evaluate_condition(condition, context)) + return connection.close() + + +def _evaluate_condition(condition: str, context: typing.Dict[str, typing.Any]): + + import jinja2.sandbox + + jinja_env = jinja2.sandbox.SandboxedEnvironment() + template = jinja_env.from_string(condition) + result = template.render(**context) + if result.lower() in ["0", "no", "n", "f", "false", "off"]: + return False + + # if the condition is not a boolean, we ignore the condition + return True diff --git a/mlrun/utils/db.py b/mlrun/utils/db.py new file mode 100644 index 000000000000..e66940e99825 --- /dev/null +++ b/mlrun/utils/db.py @@ -0,0 +1,52 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import pickle +from datetime import datetime + +from sqlalchemy.orm import class_mapper + + +class BaseModel: + def to_dict(self, exclude=None): + """ + NOTE - this function (currently) does not handle serializing relationships + """ + exclude = exclude or [] + mapper = class_mapper(self.__class__) + columns = [column.key for column in mapper.columns if column.key not in exclude] + get_key_value = ( + lambda c: (c, getattr(self, c).isoformat()) + if isinstance(getattr(self, c), datetime) + else (c, getattr(self, c)) + ) + return dict(map(get_key_value, columns)) + + +class HasStruct(BaseModel): + @property + def struct(self): + return pickle.loads(self.body) + + @struct.setter + def struct(self, value): + self.body = pickle.dumps(value) + + def to_dict(self, exclude=None): + """ + NOTE - this function (currently) does not handle serializing relationships + """ + exclude = exclude or [] + exclude.append("body") + return super().to_dict(exclude) diff --git a/mlrun/utils/helpers.py b/mlrun/utils/helpers.py index 7d8dfe40196c..037b302067b1 100644 --- a/mlrun/utils/helpers.py +++ b/mlrun/utils/helpers.py @@ -16,6 +16,7 @@ import hashlib import inspect import json +import os import re import sys import time @@ -96,8 +97,11 @@ def get_artifact_target(item: dict, project=None): tree = item["metadata"].get("tree") kind = item.get("kind") - if kind in ["dataset", "model"] and db_key: - return f"{DB_SCHEMA}://{StorePrefix.Artifact}/{project_str}/{db_key}:{tree}" + if kind in ["dataset", "model", "artifact"] and db_key: + target = f"{DB_SCHEMA}://{StorePrefix.Artifact}/{project_str}/{db_key}" + if tree: + target = f"{target}:{tree}" + return target return ( item.get("target_path") @@ -163,6 +167,35 @@ def verify_field_regex( return True +def validate_builder_source( + source: str, pull_at_runtime: bool = False, workdir: str = None +): + if pull_at_runtime or not source: + return + + if "://" not in source: + if not path.isabs(source): + raise 
mlrun.errors.MLRunInvalidArgumentError( + f"Source '{source}' must be a valid URL or absolute path when 'pull_at_runtime' is False " + "set 'source' to a remote URL to clone/copy the source to the base image, " + "or set 'pull_at_runtime' to True to pull the source at runtime." + ) + + else: + logger.warn( + "Loading local source at build time requires the source to be on the base image, " + "in which case it is recommended to use 'workdir' instead", + source=source, + workdir=workdir, + ) + + if source.endswith(".zip"): + logger.warn( + "zip files are not natively extracted by docker, use tar.gz for faster loading during build", + source=source, + ) + + def validate_tag_name( tag_name: str, field_name: str, raise_on_failure: bool = True ) -> bool: @@ -208,6 +241,23 @@ def is_yaml_path(url): return url.endswith(".yaml") or url.endswith(".yml") +def remove_image_protocol_prefix(image: str) -> str: + if not image: + return image + + prefixes = ["https://", "https://"] + if any(prefix in image for prefix in prefixes): + image = image.removeprefix("https://").removeprefix("http://") + logger.warning( + "The image has an unexpected protocol prefix ('http://' or 'https://'). " + "If you wish to use the default configured registry, no protocol prefix is required " + "(note that you can also use '.' instead of the full URL where is a placeholder). " + "Removing protocol prefix from image.", + image=image, + ) + return image + + # Verifying that a field input is of the expected type. 
If not the method raises a detailed MLRunInvalidArgumentError def verify_field_of_type(field_name: str, field_value, expected_type: type): if not isinstance(field_value, expected_type): @@ -954,7 +1004,7 @@ def retry_until_successful( f" last_exception: {last_exception}," f" function_name: {_function.__name__}," f" timeout: {timeout}" - ) + ) from last_exception def get_ui_url(project, uid=None): @@ -1000,7 +1050,7 @@ def create_class(pkg_class: str): return class_ -def create_function(pkg_func: list): +def create_function(pkg_func: str): """Create a function from a package.module.function string :param pkg_func: full function location, @@ -1014,9 +1064,16 @@ def create_function(pkg_func: list): return function_ -def get_caller_globals(level=2): +def get_caller_globals(): + """Returns a dictionary containing the first non-mlrun caller function's namespace.""" try: - return inspect.stack()[level][0].f_globals + stack = inspect.stack() + # If an API function called this function directly, the first non-mlrun caller will be 2 levels up the stack. + # Otherwise, we keep going up the stack until we find it. 
+ for level in range(2, len(stack)): + namespace = stack[level][0].f_globals + if not namespace["__name__"].startswith("mlrun."): + return namespace except Exception: return None @@ -1295,3 +1352,31 @@ def ensure_git_branch(url: str, repo: git.Repo) -> str: if not branch and not reference: url = f"{url}#refs/heads/{repo.active_branch}" return url + + +def is_file_path(filepath): + root, ext = os.path.splitext(filepath) + return os.path.isfile(filepath) and ext + + +class DeprecationHelper(object): + """A helper class to deprecate old schemas""" + + def __init__(self, new_target, version="1.4.0"): + self._new_target = new_target + self._version = version + + def _warn(self): + warnings.warn( + f"mlrun.api.schemas.{self._new_target.__name__} is deprecated in version {self._version}, " + f"Please use mlrun.common.schemas.{self._new_target.__name__} instead.", + FutureWarning, + ) + + def __call__(self, *args, **kwargs): + self._warn() + return self._new_target(*args, **kwargs) + + def __getattr__(self, attr): + self._warn() + return getattr(self._new_target, attr) diff --git a/mlrun/utils/http.py b/mlrun/utils/http.py index a40e764941c5..9022b7044745 100644 --- a/mlrun/utils/http.py +++ b/mlrun/utils/http.py @@ -79,6 +79,7 @@ def __init__( self.retry_backoff_factor = retry_backoff_factor self.retry_on_exception = retry_on_exception self.verbose = verbose + self._logger = logger.get_child("http-client") if retry_on_status: http_adapter = requests.adapters.HTTPAdapter( @@ -142,6 +143,13 @@ def request(self, method, url, **kwargs): ) raise exc + self._logger.warning( + "Error during request handling, retrying", + exc=str(exc), + retry_count=retry_count, + url=url, + method=method, + ) if self.verbose: self._log_exception( "debug", @@ -159,11 +167,11 @@ def _get_retry_methods(retry_on_post=False): # setting to False in order to retry on all methods, otherwise every method except POST. 
False if retry_on_post - else urllib3.util.retry.Retry.DEFAULT_METHOD_WHITELIST + else urllib3.util.retry.Retry.DEFAULT_ALLOWED_METHODS ) def _log_exception(self, level, exc, message, retry_count): - getattr(logger, level)( + getattr(self._logger, level)( message, exception_type=type(exc), exception_message=err_to_str(exc), diff --git a/mlrun/utils/logger.py b/mlrun/utils/logger.py index 266968da8ca9..87a76f81fef5 100644 --- a/mlrun/utils/logger.py +++ b/mlrun/utils/logger.py @@ -17,7 +17,7 @@ from enum import Enum from sys import stdout from traceback import format_exception -from typing import IO, Union +from typing import IO, Optional, Union from mlrun.config import config @@ -42,24 +42,42 @@ def format(self, record): class HumanReadableFormatter(logging.Formatter): - def __init__(self): - super(HumanReadableFormatter, self).__init__() - def format(self, record): + record_with = self._record_with(record) + more = f": {record_with}" if record_with else "" + return f"> {self.formatTime(record, self.datefmt)} [{record.levelname.lower()}] {record.getMessage()}{more}" + + def _record_with(self, record): record_with = getattr(record, "with", {}) if record.exc_info: record_with.update(exc_info=format_exception(*record.exc_info)) + return record_with + + +class HumanReadableExtendedFormatter(HumanReadableFormatter): + def format(self, record): + record_with = self._record_with(record) more = f": {record_with}" if record_with else "" - return f"> {self.formatTime(record, self.datefmt)} [{record.levelname.lower()}] {record.getMessage()}{more}" + return ( + "> " + f"{self.formatTime(record, self.datefmt)} " + f"[{record.name}:{record.levelname.lower()}] " + f"{record.getMessage()}{more}" + ) class Logger(object): - def __init__(self, level, name="mlrun", propagate=True): - self._logger = logging.getLogger(name) + def __init__( + self, + level, + name="mlrun", + propagate=True, + logger: Optional[logging.Logger] = None, + ): + self._logger = logger or 
logging.getLogger(name) self._logger.propagate = propagate self._logger.setLevel(level) self._bound_variables = {} - self._handlers = {} for log_level_func in [ self.exception, @@ -76,14 +94,14 @@ def set_handler( ): # check if there's a handler by this name - if handler_name in self._handlers: - # log that we're removing it - self.info("Replacing logger output", handler_name=handler_name) - - self._logger.removeHandler(self._handlers[handler_name]) + for handler in self._logger.handlers: + if handler.name == handler_name: + self._logger.removeHandler(handler) + break # create a stream handler from the file stream_handler = logging.StreamHandler(file) + stream_handler.name = handler_name # set the formatter stream_handler.setFormatter(formatter) @@ -91,8 +109,24 @@ def set_handler( # add the handler to the logger self._logger.addHandler(stream_handler) - # save as the named output - self._handlers[handler_name] = stream_handler + def get_child(self, suffix): + """ + Get a child logger with the given suffix. + This is useful for when you want to have a logger for a specific component. + Once the formatter will support logger name, it will be easier to understand + which component logged the message. + + :param suffix: The suffix to add to the logger name. 
+ """ + return Logger( + self.level, + # name is not set as it is provided by the "getChild" + name="", + # allowing child to delegate events logged to ancestor logger + # not doing so, will leave log lines not being handled + propagate=True, + logger=self._logger.getChild(suffix), + ) @property def level(self): @@ -102,7 +136,11 @@ def set_logger_level(self, level: Union[str, int]): self._logger.setLevel(level) def replace_handler_stream(self, handler_name: str, file: IO[str]): - self._handlers[handler_name].stream = file + for handler in self._logger.handlers: + if handler.name == handler_name: + handler.stream = file + return + raise ValueError(f"Logger does not have a handler named '{handler_name}'") def debug(self, message, *args, **kw_args): self._update_bound_vars_and_log(logging.DEBUG, message, *args, **kw_args) @@ -143,12 +181,14 @@ def _update_bound_vars_and_log( class FormatterKinds(Enum): HUMAN = "human" + HUMAN_EXTENDED = "human_extended" JSON = "json" def _create_formatter_instance(formatter_kind: FormatterKinds) -> logging.Formatter: return { FormatterKinds.HUMAN: HumanReadableFormatter(), + FormatterKinds.HUMAN_EXTENDED: HumanReadableExtendedFormatter(), FormatterKinds.JSON: JSONFormatter(), }[formatter_kind] diff --git a/mlrun/utils/model_monitoring.py b/mlrun/utils/model_monitoring.py index 361f938b2b0f..e6349a744910 100644 --- a/mlrun/utils/model_monitoring.py +++ b/mlrun/utils/model_monitoring.py @@ -13,97 +13,16 @@ # limitations under the License. 
# -import enum -import hashlib -from dataclasses import dataclass -from typing import Optional, Union +import json +import warnings +from typing import Union import mlrun +import mlrun.common.model_monitoring as model_monitoring_constants import mlrun.model -import mlrun.model_monitoring.constants as model_monitoring_constants import mlrun.platforms.iguazio -import mlrun.utils -from mlrun.api.schemas.schedule import ScheduleCronTrigger - - -@dataclass -class FunctionURI: - project: str - function: str - tag: Optional[str] = None - hash_key: Optional[str] = None - - @classmethod - def from_string(cls, function_uri): - project, uri, tag, hash_key = mlrun.utils.parse_versioned_object_uri( - function_uri - ) - return cls( - project=project, - function=uri, - tag=tag or None, - hash_key=hash_key or None, - ) - - -@dataclass -class VersionedModel: - model: str - version: Optional[str] - - @classmethod - def from_string(cls, model): - try: - model, version = model.split(":") - except ValueError: - model, version = model, None - - return cls(model, version) - - -@dataclass -class EndpointUID: - project: str - function: str - function_tag: str - function_hash_key: str - model: str - model_version: str - uid: Optional[str] = None - - def __post_init__(self): - function_ref = ( - f"{self.function}_{self.function_tag or self.function_hash_key or 'N/A'}" - ) - versioned_model = f"{self.model}_{self.model_version or 'N/A'}" - unique_string = f"{self.project}_{function_ref}_{versioned_model}" - self.uid = hashlib.sha1(unique_string.encode("utf-8")).hexdigest() - - def __str__(self): - return self.uid - - -def create_model_endpoint_id(function_uri: str, versioned_model: str): - function_uri = FunctionURI.from_string(function_uri) - versioned_model = VersionedModel.from_string(versioned_model) - - if ( - not function_uri.project - or not function_uri.function - or not versioned_model.model - ): - raise ValueError("Both function_uri and versioned_model have to be initialized") - - 
uid = EndpointUID( - function_uri.project, - function_uri.function, - function_uri.tag, - function_uri.hash_key, - versioned_model.model, - versioned_model.version, - ) - - return uid +from mlrun.common.schemas.schedule import ScheduleCronTrigger +from mlrun.config import is_running_as_api def parse_model_endpoint_project_prefix(path: str, project_name: str): @@ -116,29 +35,20 @@ def parse_model_endpoint_store_prefix(store_prefix: str): return endpoint, container, path -def set_project_model_monitoring_credentials( - access_key: str, project: Optional[str] = None -): +def set_project_model_monitoring_credentials(access_key: str, project: str = None): """Set the credentials that will be used by the project's model monitoring infrastructure functions. The supplied credentials must have data access - :param access_key: Model Monitoring access key for managing user permissions. :param project: The name of the model monitoring project. """ mlrun.get_run_db().create_project_secrets( project=project or mlrun.mlconf.default_project, - provider=mlrun.api.schemas.SecretProviderName.kubernetes, - secrets={"MODEL_MONITORING_ACCESS_KEY": access_key}, + provider=mlrun.common.schemas.SecretProviderName.kubernetes, + secrets={model_monitoring_constants.ProjectSecretKeys.ACCESS_KEY: access_key}, ) -class EndpointType(enum.IntEnum): - NODE_EP = 1 # end point that is not a child of a router - ROUTER = 2 # endpoint that is router - LEAF_EP = 3 # end point that is a child of a router - - class TrackingPolicy(mlrun.model.ModelObj): """ Modified model monitoring configurations. By using TrackingPolicy, the user can apply his model monitoring @@ -215,3 +125,125 @@ def to_dict(self, fields=None, exclude=None): model_monitoring_constants.EventFieldType.DEFAULT_BATCH_INTERVALS ] = self.default_batch_intervals.dict() return struct + + +def get_connection_string(project: str = None): + """Get endpoint store connection string from the project secret. 
+ If wasn't set, take it from the system configurations""" + if is_running_as_api(): + # Running on API server side + import mlrun.api.crud.secrets + import mlrun.common.schemas + + return ( + mlrun.api.crud.secrets.Secrets().get_project_secret( + project=project, + provider=mlrun.common.schemas.secret.SecretProviderName.kubernetes, + allow_secrets_from_k8s=True, + secret_key=model_monitoring_constants.ProjectSecretKeys.ENDPOINT_STORE_CONNECTION, + ) + or mlrun.mlconf.model_endpoint_monitoring.endpoint_store_connection + ) + else: + # Running on stream server side + import mlrun + + return ( + mlrun.get_secret_or_env( + model_monitoring_constants.ProjectSecretKeys.ENDPOINT_STORE_CONNECTION + ) + or mlrun.mlconf.model_endpoint_monitoring.endpoint_store_connection + ) + + +def get_stream_path(project: str = None): + # TODO: This function (as well as other methods in this file) includes both client and server side code. We will + # need to refactor and adjust this file in the future. + """Get stream path from the project secret. 
If wasn't set, take it from the system configurations""" + + if is_running_as_api(): + # Running on API server side + import mlrun.api.crud.secrets + import mlrun.common.schemas + + stream_uri = mlrun.api.crud.secrets.Secrets().get_project_secret( + project=project, + provider=mlrun.common.schemas.secret.SecretProviderName.kubernetes, + allow_secrets_from_k8s=True, + secret_key=model_monitoring_constants.ProjectSecretKeys.STREAM_PATH, + ) or mlrun.mlconf.get_model_monitoring_file_target_path( + project=project, + kind=model_monitoring_constants.FileTargetKind.STREAM, + target="online", + ) + + else: + import mlrun + + stream_uri = mlrun.get_secret_or_env( + model_monitoring_constants.ProjectSecretKeys.STREAM_PATH + ) or mlrun.mlconf.get_model_monitoring_file_target_path( + project=project, + kind=model_monitoring_constants.FileTargetKind.STREAM, + target="online", + ) + + if stream_uri.startswith("kafka://"): + if "?topic" in stream_uri: + raise mlrun.errors.MLRunInvalidArgumentError( + "Custom kafka topic is not allowed" + ) + # Add topic to stream kafka uri + stream_uri += f"?topic=monitoring_stream_{project}" + + elif stream_uri.startswith("v3io://") and mlrun.mlconf.is_ce_mode(): + # V3IO is not supported in CE mode, generating a default http stream path + stream_uri = mlrun.mlconf.model_endpoint_monitoring.default_http_sink + + return stream_uri + + +def validate_old_schema_fields(endpoint: dict): + """ + Replace default null values for `error_count` and `metrics` for users that logged a model endpoint before 1.3.0. + In addition, this function also validates that the key name of the endpoint unique id is `uid` and not + `endpoint_id` that has been used before 1.3.0. + + Leaving here for backwards compatibility which related to the model endpoint schema. + + :param endpoint: An endpoint flattened dictionary. 
+ """ + warnings.warn( + "This will be deprecated in 1.3.0, and will be removed in 1.5.0", + # TODO: In 1.3.0 do changes in examples & demos In 1.5.0 remove + FutureWarning, + ) + + # Validate default value for `error_count` + # For backwards compatibility reasons, we validate that the model endpoint includes the `error_count` key + if ( + model_monitoring_constants.EventFieldType.ERROR_COUNT in endpoint + and endpoint[model_monitoring_constants.EventFieldType.ERROR_COUNT] == "null" + ): + endpoint[model_monitoring_constants.EventFieldType.ERROR_COUNT] = "0" + + # Validate default value for `metrics` + # For backwards compatibility reasons, we validate that the model endpoint includes the `metrics` key + if ( + model_monitoring_constants.EventFieldType.METRICS in endpoint + and endpoint[model_monitoring_constants.EventFieldType.METRICS] == "null" + ): + endpoint[model_monitoring_constants.EventFieldType.METRICS] = json.dumps( + { + model_monitoring_constants.EventKeyMetrics.GENERIC: { + model_monitoring_constants.EventLiveStats.LATENCY_AVG_1H: 0, + model_monitoring_constants.EventLiveStats.PREDICTIONS_PER_SECOND: 0, + } + } + ) + # Validate key `uid` instead of `endpoint_id` + # For backwards compatibility reasons, we replace the `endpoint_id` with `uid` which is the updated key name + if model_monitoring_constants.EventFieldType.ENDPOINT_ID in endpoint: + endpoint[model_monitoring_constants.EventFieldType.UID] = endpoint[ + model_monitoring_constants.EventFieldType.ENDPOINT_ID + ] diff --git a/mlrun/utils/notifications/notification/__init__.py b/mlrun/utils/notifications/notification/__init__.py index f0e7435a3080..f8ef41cbc6bb 100644 --- a/mlrun/utils/notifications/notification/__init__.py +++ b/mlrun/utils/notifications/notification/__init__.py @@ -15,6 +15,8 @@ import enum import typing +from mlrun.common.schemas.notification import NotificationKind + from .base import NotificationBase from .console import ConsoleNotification from .git import GitNotification @@ 
-23,10 +25,10 @@ class NotificationTypes(str, enum.Enum): - console = "console" - git = "git" - ipython = "ipython" - slack = "slack" + console = NotificationKind.console.value + git = NotificationKind.git.value + ipython = NotificationKind.ipython.value + slack = NotificationKind.slack.value def get_notification(self) -> typing.Type[NotificationBase]: return { diff --git a/mlrun/utils/notifications/notification/base.py b/mlrun/utils/notifications/notification/base.py index 4668f15c615c..eb587cdf5ce3 100644 --- a/mlrun/utils/notifications/notification/base.py +++ b/mlrun/utils/notifications/notification/base.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import typing -import mlrun.api.schemas +import mlrun.common.schemas import mlrun.lists @@ -31,12 +32,16 @@ def __init__( def active(self) -> bool: return True + @property + def is_async(self) -> bool: + return asyncio.iscoroutinefunction(self.push) + def push( self, message: str, severity: typing.Union[ - mlrun.api.schemas.NotificationSeverity, str - ] = mlrun.api.schemas.NotificationSeverity.INFO, + mlrun.common.schemas.NotificationSeverity, str + ] = mlrun.common.schemas.NotificationSeverity.INFO, runs: typing.Union[mlrun.lists.RunList, list] = None, custom_html: str = None, ): @@ -52,8 +57,8 @@ def _get_html( self, message: str, severity: typing.Union[ - mlrun.api.schemas.NotificationSeverity, str - ] = mlrun.api.schemas.NotificationSeverity.INFO, + mlrun.common.schemas.NotificationSeverity, str + ] = mlrun.common.schemas.NotificationSeverity.INFO, runs: typing.Union[mlrun.lists.RunList, list] = None, custom_html: str = None, ) -> str: diff --git a/mlrun/utils/notifications/notification/console.py b/mlrun/utils/notifications/notification/console.py index 4f56c34fa9ad..3b6aacbb8f1e 100644 --- a/mlrun/utils/notifications/notification/console.py +++ b/mlrun/utils/notifications/notification/console.py @@ -16,7 +16,7 @@ 
import tabulate -import mlrun.api.schemas +import mlrun.common.schemas import mlrun.lists import mlrun.utils.helpers @@ -32,8 +32,8 @@ def push( self, message: str, severity: typing.Union[ - mlrun.api.schemas.NotificationSeverity, str - ] = mlrun.api.schemas.NotificationSeverity.INFO, + mlrun.common.schemas.NotificationSeverity, str + ] = mlrun.common.schemas.NotificationSeverity.INFO, runs: typing.Union[mlrun.lists.RunList, list] = None, custom_html: str = None, ): diff --git a/mlrun/utils/notifications/notification/git.py b/mlrun/utils/notifications/notification/git.py index 401a39f6dc75..49beea8cce14 100644 --- a/mlrun/utils/notifications/notification/git.py +++ b/mlrun/utils/notifications/notification/git.py @@ -18,7 +18,7 @@ import aiohttp -import mlrun.api.schemas +import mlrun.common.schemas import mlrun.errors import mlrun.lists @@ -34,13 +34,14 @@ async def push( self, message: str, severity: typing.Union[ - mlrun.api.schemas.NotificationSeverity, str - ] = mlrun.api.schemas.NotificationSeverity.INFO, + mlrun.common.schemas.NotificationSeverity, str + ] = mlrun.common.schemas.NotificationSeverity.INFO, runs: typing.Union[mlrun.lists.RunList, list] = None, custom_html: str = None, ): git_repo = self.params.get("repo", None) git_issue = self.params.get("issue", None) + git_merge_request = self.params.get("merge_request", None) token = ( self.params.get("token", None) or self.params.get("GIT_TOKEN", None) @@ -52,6 +53,7 @@ async def push( self._get_html(message, severity, runs, custom_html), git_repo, git_issue, + merge_request=git_merge_request, token=token, server=server, gitlab=gitlab, @@ -62,6 +64,7 @@ async def _pr_comment( message: str, repo: str = None, issue: int = None, + merge_request: int = None, token: str = None, server: str = None, gitlab: bool = False, @@ -89,12 +92,19 @@ async def _pr_comment( headers = {"PRIVATE-TOKEN": token} repo = repo or os.environ.get("CI_PROJECT_ID") # auto detect GitLab pr id from the environment - issue = issue or 
os.environ.get("CI_MERGE_REQUEST_IID") + issue = issue or os.environ.get("CI_ISSUE_IID") + merge_request = merge_request or os.environ.get("CI_MERGE_REQUEST_IID") # replace slash with url encoded slash for GitLab to accept a repo name with slash repo = repo.replace("/", "%2F") - url = ( - f"https://{server}/api/v4/projects/{repo}/merge_requests/{issue}/notes" - ) + + if merge_request: + url = f"https://{server}/api/v4/projects/{repo}/merge_requests/{merge_request}/notes" + elif issue: + url = f"https://{server}/api/v4/projects/{repo}/issues/{issue}/notes" + else: + raise mlrun.errors.MLRunInvalidArgumentError( + "GitLab issue or merge request id not specified" + ) else: server = server or "api.github.com" repo = repo or os.environ.get("GITHUB_REPOSITORY") @@ -120,7 +130,7 @@ async def _pr_comment( if not resp.ok: resp_text = await resp.text() raise mlrun.errors.MLRunBadRequestError( - f"Failed commenting on PR: {resp_text}", status=resp.status + f"Failed commenting on PR: {resp_text}" ) data = await resp.json() return data.get("id") diff --git a/mlrun/utils/notifications/notification/ipython.py b/mlrun/utils/notifications/notification/ipython.py index 7fc7f2fcc666..3871fe1a0fbd 100644 --- a/mlrun/utils/notifications/notification/ipython.py +++ b/mlrun/utils/notifications/notification/ipython.py @@ -14,7 +14,7 @@ import typing -import mlrun.api.schemas +import mlrun.common.schemas import mlrun.lists import mlrun.utils.helpers @@ -28,9 +28,10 @@ class IPythonNotification(NotificationBase): def __init__( self, + name: str = None, params: typing.Dict[str, str] = None, ): - super().__init__(params) + super().__init__(name, params) self._ipython = None try: import IPython @@ -50,8 +51,8 @@ def push( self, message: str, severity: typing.Union[ - mlrun.api.schemas.NotificationSeverity, str - ] = mlrun.api.schemas.NotificationSeverity.INFO, + mlrun.common.schemas.NotificationSeverity, str + ] = mlrun.common.schemas.NotificationSeverity.INFO, runs: 
typing.Union[mlrun.lists.RunList, list] = None, custom_html: str = None, ): diff --git a/mlrun/utils/notifications/notification/slack.py b/mlrun/utils/notifications/notification/slack.py index 683b8c68857a..3ad897a4d663 100644 --- a/mlrun/utils/notifications/notification/slack.py +++ b/mlrun/utils/notifications/notification/slack.py @@ -16,7 +16,7 @@ import aiohttp -import mlrun.api.schemas +import mlrun.common.schemas import mlrun.lists import mlrun.utils.helpers @@ -38,8 +38,8 @@ async def push( self, message: str, severity: typing.Union[ - mlrun.api.schemas.NotificationSeverity, str - ] = mlrun.api.schemas.NotificationSeverity.INFO, + mlrun.common.schemas.NotificationSeverity, str + ] = mlrun.common.schemas.NotificationSeverity.INFO, runs: typing.Union[mlrun.lists.RunList, list] = None, custom_html: str = None, ): @@ -63,8 +63,8 @@ def _generate_slack_data( self, message: str, severity: typing.Union[ - mlrun.api.schemas.NotificationSeverity, str - ] = mlrun.api.schemas.NotificationSeverity.INFO, + mlrun.common.schemas.NotificationSeverity, str + ] = mlrun.common.schemas.NotificationSeverity.INFO, runs: typing.Union[mlrun.lists.RunList, list] = None, ) -> dict: data = { @@ -75,6 +75,10 @@ def _generate_slack_data( }, ] } + if self.name: + data["blocks"].append( + {"type": "section", "text": self._get_slack_row(self.name)} + ) if not runs: return data diff --git a/mlrun/utils/notifications/notification_pusher.py b/mlrun/utils/notifications/notification_pusher.py index b0becaddbd6b..066bb40fb3ef 100644 --- a/mlrun/utils/notifications/notification_pusher.py +++ b/mlrun/utils/notifications/notification_pusher.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import ast import asyncio import datetime import os @@ -22,13 +21,13 @@ import mlrun.api.db.base import mlrun.api.db.session -import mlrun.api.schemas -import mlrun.api.utils.singletons.k8s +import mlrun.common.schemas import mlrun.config import mlrun.lists import mlrun.model import mlrun.utils.helpers from mlrun.utils import logger +from mlrun.utils.condition_evaluator import evaluate_condition_in_separate_process from .notification import NotificationBase, NotificationTypes @@ -38,20 +37,30 @@ class NotificationPusher(object): messages = { "completed": "Run completed", "error": "Run failed", + "aborted": "Run aborted", } def __init__(self, runs: typing.Union[mlrun.lists.RunList, list]): self._runs = runs - self._notification_data = [] - self._notifications = {} + self._sync_notifications = [] + self._async_notifications = [] for run in self._runs: if isinstance(run, dict): run = mlrun.model.RunObject.from_dict(run) for notification in run.spec.notifications: + try: + notification.status = run.status.notifications.get( + notification.name + ).get("status", mlrun.common.schemas.NotificationStatus.PENDING) + except (AttributeError, KeyError): + notification.status = ( + mlrun.common.schemas.NotificationStatus.PENDING + ) + if self._should_notify(run, notification): - self._notification_data.append((run, notification)) + self._load_notification(run, notification) def push( self, @@ -63,33 +72,55 @@ def push( wait for all notifications to be pushed before returning. 
""" - async def _push(): + if not len(self._sync_notifications) and not len(self._async_notifications): + return + + def _sync_push(): + for notification_data in self._sync_notifications: + self._push_notification_sync( + notification_data[0], + notification_data[1], + notification_data[2], + db, + ) + + async def _async_push(): tasks = [] - for notification_data in self._notification_data: + for notification_data in self._async_notifications: tasks.append( - self._push_notification( - self._load_notification(*notification_data), + self._push_notification_async( notification_data[0], notification_data[1], + notification_data[2], db, ) ) - await asyncio.gather(*tasks) + + # return exceptions to "best-effort" fire all notifications + await asyncio.gather(*tasks, return_exceptions=True) logger.debug( - "Pushing notifications", notifications_amount=len(self._notification_data) + "Pushing notifications", + notifications_amount=len(self._sync_notifications) + + len(self._async_notifications), ) + + # first push async notifications main_event_loop = asyncio.get_event_loop() if main_event_loop.is_running(): # If running from the api or from jupyter notebook, we are already in an event loop. # We add the async push function to the loop and run it. - asyncio.run_coroutine_threadsafe(_push(), main_event_loop) + asyncio.run_coroutine_threadsafe(_async_push(), main_event_loop) else: # If running mlrun SDK locally (not from jupyter), there isn't necessarily an event loop. # We create a new event loop and run the async push function in it. 
- main_event_loop.run_until_complete(_push()) + main_event_loop.run_until_complete(_async_push()) + + # then push sync notifications + if not mlrun.config.is_running_as_api(): + _sync_push() @staticmethod def _should_notify( @@ -97,104 +128,161 @@ def _should_notify( notification: mlrun.model.Notification, ) -> bool: when_states = notification.when - condition = notification.condition run_state = run.state() # if the notification isn't pending, don't push it if ( notification.status - and notification.status != mlrun.api.schemas.NotificationStatus.PENDING + and notification.status != mlrun.common.schemas.NotificationStatus.PENDING ): return False # if at least one condition is met, notify for when_state in when_states: - if ( - when_state == run_state == "completed" - and (not condition or ast.literal_eval(condition)) - ) or when_state == run_state == "error": - return True + if when_state == run_state: + if ( + run_state == "completed" + and evaluate_condition_in_separate_process( + notification.condition, + context={ + "run": run.to_dict(), + "notification": notification.to_dict(), + }, + ) + ) or run_state in ["error", "aborted"]: + return True return False def _load_notification( - self, run: mlrun.model.RunObject, notification: mlrun.model.Notification + self, run: mlrun.model.RunObject, notification_object: mlrun.model.Notification ) -> NotificationBase: - name = notification.name + name = notification_object.name notification_type = NotificationTypes( - notification.kind or NotificationTypes.console + notification_object.kind or NotificationTypes.console + ) + notification = notification_type.get_notification()( + name, notification_object.params ) - notification_key = f"{run.metadata.uid}-{name or notification_type}" - if notification_key not in self._notifications: - self._notifications[ - notification_key - ] = notification_type.get_notification()(name, notification.params) + if notification.is_async: + self._async_notifications.append((notification, run, 
notification_object)) else: - self._notifications[notification_key].load_notification(notification.params) + self._sync_notifications.append((notification, run, notification_object)) logger.debug( - "Loaded notification", notification=self._notifications[notification_key] + "Loaded notification", notification=name, type=notification_type.value + ) + return notification + + def _prepare_notification_args( + self, run: mlrun.model.RunObject, notification_object: mlrun.model.Notification + ): + custom_message = ( + f": {notification_object.message}" if notification_object.message else "" + ) + message = self.messages.get(run.state(), "") + custom_message + + severity = ( + notification_object.severity + or mlrun.common.schemas.NotificationSeverity.INFO ) - return self._notifications[notification_key] + return message, severity, [run.to_dict()] - async def _push_notification( + def _push_notification_sync( self, notification: NotificationBase, run: mlrun.model.RunObject, notification_object: mlrun.model.Notification, db: mlrun.api.db.base.DBInterface, ): - message = self.messages.get(run.state(), "") - severity = ( - notification_object.severity or mlrun.api.schemas.NotificationSeverity.INFO + message, severity, runs = self._prepare_notification_args( + run, notification_object ) logger.debug( - "Pushing notification", - notification=notification_object.to_dict(), + "Pushing sync notification", + notification=_sanitize_notification(notification_object), run_uid=run.metadata.uid, ) try: - if asyncio.iscoroutinefunction(notification.push): - await notification.push(message, severity, [run.to_dict()]) - else: - notification.push(message, severity, [run.to_dict()]) + notification.push(message, severity, runs) + self._update_notification_status( + db, + run.metadata.uid, + run.metadata.project, + notification_object, + status=mlrun.common.schemas.NotificationStatus.SENT, + sent_time=datetime.datetime.now(tz=datetime.timezone.utc), + ) + except Exception as exc: + 
self._update_notification_status( + db, + run.metadata.uid, + run.metadata.project, + notification_object, + status=mlrun.common.schemas.NotificationStatus.ERROR, + ) + raise exc - if mlrun.config.is_running_as_api(): - await self._update_notification_status( - db, - run.metadata.uid, - run.metadata.project, - notification_object, - status=mlrun.api.schemas.NotificationStatus.SENT, - sent_time=datetime.datetime.now(tz=datetime.timezone.utc), - ) + async def _push_notification_async( + self, + notification: NotificationBase, + run: mlrun.model.RunObject, + notification_object: mlrun.model.Notification, + db: mlrun.api.db.base.DBInterface, + ): + message, severity, runs = self._prepare_notification_args( + run, notification_object + ) + logger.debug( + "Pushing async notification", + notification=_sanitize_notification(notification_object), + run_uid=run.metadata.uid, + ) + try: + await notification.push(message, severity, runs) + + await run_in_threadpool( + self._update_notification_status, + db, + run.metadata.uid, + run.metadata.project, + notification_object, + status=mlrun.common.schemas.NotificationStatus.SENT, + sent_time=datetime.datetime.now(tz=datetime.timezone.utc), + ) except Exception as exc: - if mlrun.config.is_running_as_api(): - await self._update_notification_status( - db, - run.metadata.uid, - run.metadata.project, - notification_object, - status=mlrun.api.schemas.NotificationStatus.ERROR, - ) + await run_in_threadpool( + self._update_notification_status, + db, + run.metadata.uid, + run.metadata.project, + notification_object, + status=mlrun.common.schemas.NotificationStatus.ERROR, + ) raise exc @staticmethod - async def _update_notification_status( + def _update_notification_status( db: mlrun.api.db.base.DBInterface, run_uid: str, project: str, notification: mlrun.model.Notification, status: str = None, - sent_time: datetime.datetime = None, + sent_time: typing.Optional[datetime.datetime] = None, ): + + # nothing to update if not running as api + 
# note, the notification mechanism may run "locally" for certain runtimes + if not mlrun.config.is_running_as_api(): + return + + # TODO: move to api side db_session = mlrun.api.db.session.create_session() notification.status = status or notification.status notification.sent_time = sent_time or notification.sent_time # store directly in db, no need to use crud as the secrets are already loaded - await run_in_threadpool( - db.store_run_notifications, + db.store_run_notifications( db_session, [notification], run_uid, @@ -204,64 +292,78 @@ async def _update_notification_status( class CustomNotificationPusher(object): def __init__(self, notification_types: typing.List[str] = None): - self._notifications = { + notifications = { notification_type: NotificationTypes(notification_type).get_notification()() for notification_type in notification_types } + self._sync_notifications = { + notification_type: notification + for notification_type, notification in notifications.items() + if not notification.is_async + } + self._async_notifications = { + notification_type: notification + for notification_type, notification in notifications.items() + if notification.is_async + } def push( self, message: str, severity: typing.Union[ - mlrun.api.schemas.NotificationSeverity, str - ] = mlrun.api.schemas.NotificationSeverity.INFO, + mlrun.common.schemas.NotificationSeverity, str + ] = mlrun.common.schemas.NotificationSeverity.INFO, runs: typing.Union[mlrun.lists.RunList, list] = None, custom_html: str = None, ): - async def _push(): + def _sync_push(): + for notification_type, notification in self._sync_notifications.items(): + if self.should_push_notification(notification_type): + notification.push(message, severity, runs, custom_html) + + async def _async_push(): tasks = [] - for notification_type, notification in self._notifications.items(): + for notification_type, notification in self._async_notifications.items(): if self.should_push_notification(notification_type): tasks.append( - 
self._push_notification( - notification, message, severity, runs, custom_html - ) + notification.push(message, severity, runs, custom_html) ) - await asyncio.gather(*tasks) + # return exceptions to "best-effort" fire all notifications + await asyncio.gather(*tasks, return_exceptions=True) + + # first push async notifications main_event_loop = asyncio.get_event_loop() if main_event_loop.is_running(): - asyncio.run_coroutine_threadsafe(_push(), main_event_loop) + asyncio.run_coroutine_threadsafe(_async_push(), main_event_loop) else: - main_event_loop.run_until_complete(_push()) + main_event_loop.run_until_complete(_async_push()) - @staticmethod - async def _push_notification( - notification: NotificationBase, - message: str, - severity: typing.Union[ - mlrun.api.schemas.NotificationSeverity, str - ] = mlrun.api.schemas.NotificationSeverity.INFO, - runs: typing.Union[mlrun.lists.RunList, list] = None, - custom_html: str = None, - ): - if asyncio.iscoroutinefunction(notification.push): - await notification.push(message, severity, runs, custom_html) - else: - notification.push(message, severity, runs, custom_html) + # then push sync notifications + if not mlrun.config.is_running_as_api(): + _sync_push() def add_notification( self, notification_type: str, params: typing.Dict[str, str] = None ): - if notification_type in self._notifications: - self._notifications[notification_type].load_notification(params) + if notification_type in self._async_notifications: + self._async_notifications[notification_type].load_notification(params) + elif notification_type in self._sync_notifications: + self._sync_notifications[notification_type].load_notification(params) else: - self._notifications[notification_type] = NotificationTypes( - notification_type - ).get_notification()(params) + notification = NotificationTypes(notification_type).get_notification()( + params + ) + if notification.is_async: + self._async_notifications[notification_type] = notification + else: + 
self._sync_notifications[notification_type] = notification def should_push_notification(self, notification_type): - notification = self._notifications.get(notification_type) + notifications = {} + notifications.update(self._sync_notifications) + notifications.update(self._async_notifications) + notification = notifications.get(notification_type) if not notification or not notification.active: return False @@ -271,9 +373,7 @@ def should_push_notification(self, notification_type): notification_type ).inverse_dependencies() for inverse_dependency in inverse_dependencies: - inverse_dependency_notification = self._notifications.get( - inverse_dependency - ) + inverse_dependency_notification = notifications.get(inverse_dependency) if ( inverse_dependency_notification and inverse_dependency_notification.active @@ -339,3 +439,22 @@ def push_pipeline_run_results( if state: text += f", state={state}" self.push(text, "info", runs=runs_list) + + +def _sanitize_notification(notification: mlrun.model.Notification): + notification_dict = notification.to_dict() + notification_dict.pop("params", None) + return notification_dict + + +def _separate_sync_notifications( + notifications: typing.List[NotificationBase], +) -> typing.Tuple[typing.List[NotificationBase], typing.List[NotificationBase]]: + sync_notifications = [] + async_notifications = [] + for notification in notifications: + if notification.is_async: + async_notifications.append(notification) + else: + sync_notifications.append(notification) + return sync_notifications, async_notifications diff --git a/mlrun/utils/vault.py b/mlrun/utils/vault.py index 1a679466e5f3..5ca3e82230e8 100644 --- a/mlrun/utils/vault.py +++ b/mlrun/utils/vault.py @@ -11,271 +11,272 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -import json -import os -from os.path import expanduser - -import requests - -from mlrun.errors import MLRunInvalidArgumentError - -from ..config import config as mlconf -from ..k8s_utils import get_k8s_helper -from .helpers import logger - -vault_default_prefix = "v1/secret/data" - - -class VaultStore: - def __init__(self, token=None): - self._token = token - self.url = mlconf.secret_stores.vault.url - - @property - def token(self): - if not self._token: - self._login() - - return self._token - - def _login(self): - if self._token: - return - - if mlconf.secret_stores.vault.user_token != "": - logger.warning( - "Using a user-token from configuration. This should only be done in test/debug!" - ) - self._token = mlconf.secret_stores.vault.user_token - return - - config_role = mlconf.secret_stores.vault.role - if config_role != "": - role_type, role_val = config_role.split(":", 1) - vault_role = f"mlrun-role-{role_type}-{role_val}" - self._safe_login_with_jwt_token(vault_role) - - if self._token is None: - logger.warning( - "Vault login: no vault token is available. 
No secrets will be accessible" - ) - - @staticmethod - def _generate_path( - prefix=vault_default_prefix, - user=None, - project=None, - user_prefix="users", - project_prefix="projects", - ): - if user and project: - raise MLRunInvalidArgumentError( - "Both user and project were provided for Vault operations" - ) - - if user: - return prefix + f"/mlrun/{user_prefix}/{user}" - elif project: - return prefix + f"/mlrun/{project_prefix}/{project}" - else: - raise MLRunInvalidArgumentError( - "To generate a vault secret path, either user or project must be specified" - ) - - @staticmethod - def _read_jwt_token(): - # if for some reason the path to the token is not in conf, then attempt to get the SA token (works on k8s pods) - token_path = "/var/run/secrets/kubernetes.io/serviceaccount/token" - if mlconf.secret_stores.vault.token_path: - # Override the default SA token in case a specific token is installed in the mlconf-specified path - secret_token_path = expanduser( - mlconf.secret_stores.vault.token_path + "/token" - ) - if os.path.isfile(secret_token_path): - token_path = secret_token_path - - with open(token_path, "r") as token_file: - jwt_token = token_file.read() - - return jwt_token - - def _api_call(self, method, url, data=None): - self._login() - - headers = {"X-Vault-Token": self._token} - full_url = self.url + "/" + url - - response = requests.request(method, full_url, headers=headers, json=data) - - if not response: - logger.error( - "Vault failed the API call", - status_code=response.status_code, - reason=response.reason, - url=url, - ) - return response - - # This method logins to the vault, assuming the container has a JWT token mounted as part of its assigned service - # account. - def _safe_login_with_jwt_token(self, role): - - if role is None: - logger.warning( - "login_with_token: Role passed is None. 
Will not attempt login" - ) - return - - jwt_token = self._read_jwt_token() - - login_url = f"{self.url}/v1/auth/kubernetes/login" - data = {"jwt": jwt_token, "role": role} - - response = requests.post(login_url, data=json.dumps(data)) - if not response: - logger.error( - "login_with_token: Vault failed the login request", - role=role, - status_code=response.status_code, - reason=response.reason, - ) - return - self._token = response.json()["auth"]["client_token"] - - def get_secrets(self, keys, user=None, project=None): - secret_path = VaultStore._generate_path(user=user, project=project) - secrets = {} - - # Since this method is called both on the client side (when constructing VaultStore before persisting to - # pod configuration) and on server side and in execution pods, we let this method fail gracefully in this case. - # Should replace with something that will explode on server-side, once we have a way to do that. - if not self.url: - return secrets - - response = self._api_call("GET", secret_path) - - if not response: - return secrets - - values = response.json()["data"]["data"] - - # if no specific keys were asked for, return all the values available - if not keys: - return values - - for key in keys: - if key in values: - secrets[key] = values[key] - return secrets - - def add_vault_secrets(self, items, project=None, user=None): - data_object = {"data": items} - url = VaultStore._generate_path(project=project, user=user) - - response = self._api_call("POST", url, data_object) - if not response: - raise MLRunInvalidArgumentError( - f"Vault failed the API call to create secrets. project={project}/user={user}" - ) - - def delete_vault_secrets(self, project=None, user=None): - self._login() - # Using the API to delete all versions + metadata of the given secret. 
- url = "v1/secret/metadata" + VaultStore._generate_path( - prefix="", project=project, user=user - ) - - response = self._api_call("DELETE", url) - if not response: - raise MLRunInvalidArgumentError( - f"Vault failed the API call to delete secrets. project={project}/user={user}" - ) - - def create_project_policy(self, project): - policy_name = f"mlrun-project-{project}" - # TODO - need to make sure name is escaped properly and invalid chars are stripped - url = "v1/sys/policies/acl/" + policy_name - - policy_str = ( - f'path "secret/data/mlrun/projects/{project}" {{\n' - + ' capabilities = ["read", "list", "create", "delete", "update"]\n' - + "}\n" - + f'path "secret/data/mlrun/projects/{project}/*" {{\n' - + ' capabilities = ["read", "list", "create", "delete", "update"]\n' - + "}" - ) - - data_object = {"policy": policy_str} - - response = self._api_call("PUT", url, data_object) - if not response: - raise MLRunInvalidArgumentError( - f"Vault failed the API call to create a policy. " - f"Response code: ({response.status_code}) - {response.reason}" - ) - return policy_name - - def create_project_role(self, project, sa, policy, namespace="default-tenant"): - role_name = f"mlrun-role-project-{project}" - # TODO - need to make sure name is escaped properly and invalid chars are stripped - url = "v1/auth/kubernetes/role/" + role_name - - role_object = { - "bound_service_account_names": sa, - "bound_service_account_namespaces": namespace, - "policies": [policy], - "token_ttl": mlconf.secret_stores.vault.token_ttl, - } - - response = self._api_call("POST", url, role_object) - if not response: - raise MLRunInvalidArgumentError( - f"Vault failed the API call to create a secret. 
" - f"Response code: ({response.status_code}) - {response.reason}" - ) - return role_name - - -def store_vault_project_secrets(project, items): - return VaultStore().add_vault_secrets(items, project=project) - - -def add_vault_user_secrets(user, items): - return VaultStore().add_vault_secrets(items, user=user) - - -def init_project_vault_configuration(project): - """Create needed configurations for this new project: - - Create a k8s service account with the name sa_vault_{proj name} - - Create a Vault policy with the name proj_{proj name} - - Create a Vault k8s auth role with the name role_proj_{proj name} - These constructs will enable any pod created as part of this project to access the project's secrets - in Vault, assuming that the secret which is part of the SA created is mounted to the pod. - - :param project: Project name - """ - logger.info("Initializing project vault configuration", project=project) - - namespace = mlconf.namespace - k8s = get_k8s_helper(silent=True) - service_account_name = ( - mlconf.secret_stores.vault.project_service_account_name.format(project=project) - ) - - secret_name = k8s.get_project_vault_secret_name( - project, service_account_name, namespace=namespace - ) - - if not secret_name: - k8s.create_project_service_account( - project, service_account_name, namespace=namespace - ) - - vault = VaultStore() - policy_name = vault.create_project_policy(project) - role_name = vault.create_project_role( - project, namespace=namespace, sa=service_account_name, policy=policy_name - ) - - logger.info("Created Vault policy. 
", policy=policy_name, role=role_name) +# +# import json +# import os +# from os.path import expanduser +# +# import requests +# +# from mlrun.errors import MLRunInvalidArgumentError +# +# from ..config import config as mlconf +# from ..k8s_utils import get_k8s_helper +# from .helpers import logger +# +# vault_default_prefix = "v1/secret/data" +# +# +# class VaultStore: +# def __init__(self, token=None): +# self._token = token +# self.url = mlconf.secret_stores.vault.url +# +# @property +# def token(self): +# if not self._token: +# self._login() +# +# return self._token +# +# def _login(self): +# if self._token: +# return +# +# if mlconf.secret_stores.vault.user_token != "": +# logger.warning( +# "Using a user-token from configuration. This should only be done in test/debug!" +# ) +# self._token = mlconf.secret_stores.vault.user_token +# return +# +# config_role = mlconf.secret_stores.vault.role +# if config_role != "": +# role_type, role_val = config_role.split(":", 1) +# vault_role = f"mlrun-role-{role_type}-{role_val}" +# self._safe_login_with_jwt_token(vault_role) +# +# if self._token is None: +# logger.warning( +# "Vault login: no vault token is available. 
No secrets will be accessible" +# ) +# +# @staticmethod +# def _generate_path( +# prefix=vault_default_prefix, +# user=None, +# project=None, +# user_prefix="users", +# project_prefix="projects", +# ): +# if user and project: +# raise MLRunInvalidArgumentError( +# "Both user and project were provided for Vault operations" +# ) +# +# if user: +# return prefix + f"/mlrun/{user_prefix}/{user}" +# elif project: +# return prefix + f"/mlrun/{project_prefix}/{project}" +# else: +# raise MLRunInvalidArgumentError( +# "To generate a vault secret path, either user or project must be specified" +# ) +# +# @staticmethod +# def _read_jwt_token(): +# # if for some reason the path to the token is not in conf, then attempt to get the SA token +# # (works on k8s pods) +# token_path = "/var/run/secrets/kubernetes.io/serviceaccount/token" +# if mlconf.secret_stores.vault.token_path: +# # Override the default SA token in case a specific token is installed in the mlconf-specified path +# secret_token_path = expanduser( +# mlconf.secret_stores.vault.token_path + "/token" +# ) +# if os.path.isfile(secret_token_path): +# token_path = secret_token_path +# +# with open(token_path, "r") as token_file: +# jwt_token = token_file.read() +# +# return jwt_token +# +# def _api_call(self, method, url, data=None): +# self._login() +# +# headers = {"X-Vault-Token": self._token} +# full_url = self.url + "/" + url +# +# response = requests.request(method, full_url, headers=headers, json=data) +# +# if not response: +# logger.error( +# "Vault failed the API call", +# status_code=response.status_code, +# reason=response.reason, +# url=url, +# ) +# return response +# +# # This method logins to the vault, assuming the container has a JWT token mounted as part of its assigned service +# # account. +# def _safe_login_with_jwt_token(self, role): +# +# if role is None: +# logger.warning( +# "login_with_token: Role passed is None. 
Will not attempt login" +# ) +# return +# +# jwt_token = self._read_jwt_token() +# +# login_url = f"{self.url}/v1/auth/kubernetes/login" +# data = {"jwt": jwt_token, "role": role} +# +# response = requests.post(login_url, data=json.dumps(data)) +# if not response: +# logger.error( +# "login_with_token: Vault failed the login request", +# role=role, +# status_code=response.status_code, +# reason=response.reason, +# ) +# return +# self._token = response.json()["auth"]["client_token"] +# +# def get_secrets(self, keys, user=None, project=None): +# secret_path = VaultStore._generate_path(user=user, project=project) +# secrets = {} +# +# # Since this method is called both on the client side (when constructing VaultStore before persisting to +# # pod configuration) and on server side and in execution pods, we let this method fail gracefully in this case +# # Should replace with something that will explode on server-side, once we have a way to do that. +# if not self.url: +# return secrets +# +# response = self._api_call("GET", secret_path) +# +# if not response: +# return secrets +# +# values = response.json()["data"]["data"] +# +# # if no specific keys were asked for, return all the values available +# if not keys: +# return values +# +# for key in keys: +# if key in values: +# secrets[key] = values[key] +# return secrets +# +# def add_vault_secrets(self, items, project=None, user=None): +# data_object = {"data": items} +# url = VaultStore._generate_path(project=project, user=user) +# +# response = self._api_call("POST", url, data_object) +# if not response: +# raise MLRunInvalidArgumentError( +# f"Vault failed the API call to create secrets. project={project}/user={user}" +# ) +# +# def delete_vault_secrets(self, project=None, user=None): +# self._login() +# # Using the API to delete all versions + metadata of the given secret. 
+# url = "v1/secret/metadata" + VaultStore._generate_path( +# prefix="", project=project, user=user +# ) +# +# response = self._api_call("DELETE", url) +# if not response: +# raise MLRunInvalidArgumentError( +# f"Vault failed the API call to delete secrets. project={project}/user={user}" +# ) +# +# def create_project_policy(self, project): +# policy_name = f"mlrun-project-{project}" +# # TODO - need to make sure name is escaped properly and invalid chars are stripped +# url = "v1/sys/policies/acl/" + policy_name +# +# policy_str = ( +# f'path "secret/data/mlrun/projects/{project}" {{\n' +# + ' capabilities = ["read", "list", "create", "delete", "update"]\n' +# + "}\n" +# + f'path "secret/data/mlrun/projects/{project}/*" {{\n' +# + ' capabilities = ["read", "list", "create", "delete", "update"]\n' +# + "}" +# ) +# +# data_object = {"policy": policy_str} +# +# response = self._api_call("PUT", url, data_object) +# if not response: +# raise MLRunInvalidArgumentError( +# f"Vault failed the API call to create a policy. " +# f"Response code: ({response.status_code}) - {response.reason}" +# ) +# return policy_name +# +# def create_project_role(self, project, sa, policy, namespace="default-tenant"): +# role_name = f"mlrun-role-project-{project}" +# # TODO - need to make sure name is escaped properly and invalid chars are stripped +# url = "v1/auth/kubernetes/role/" + role_name +# +# role_object = { +# "bound_service_account_names": sa, +# "bound_service_account_namespaces": namespace, +# "policies": [policy], +# "token_ttl": mlconf.secret_stores.vault.token_ttl, +# } +# +# response = self._api_call("POST", url, role_object) +# if not response: +# raise MLRunInvalidArgumentError( +# f"Vault failed the API call to create a secret. 
" +# f"Response code: ({response.status_code}) - {response.reason}" +# ) +# return role_name +# +# +# def store_vault_project_secrets(project, items): +# return VaultStore().add_vault_secrets(items, project=project) +# +# +# def add_vault_user_secrets(user, items): +# return VaultStore().add_vault_secrets(items, user=user) +# +# +# def init_project_vault_configuration(project): +# """Create needed configurations for this new project: +# - Create a k8s service account with the name sa_vault_{proj name} +# - Create a Vault policy with the name proj_{proj name} +# - Create a Vault k8s auth role with the name role_proj_{proj name} +# These constructs will enable any pod created as part of this project to access the project's secrets +# in Vault, assuming that the secret which is part of the SA created is mounted to the pod. +# +# :param project: Project name +# """ +# logger.info("Initializing project vault configuration", project=project) +# +# namespace = mlconf.namespace +# k8s = get_k8s_helper(silent=True) +# service_account_name = ( +# mlconf.secret_stores.vault.project_service_account_name.format(project=project) +# ) +# +# secret_name = k8s.get_project_vault_secret_name( +# project, service_account_name, namespace=namespace +# ) +# +# if not secret_name: +# k8s.create_project_service_account( +# project, service_account_name, namespace=namespace +# ) +# +# vault = VaultStore() +# policy_name = vault.create_project_policy(project) +# role_name = vault.create_project_role( +# project, namespace=namespace, sa=service_account_name, policy=policy_name +# ) +# +# logger.info("Created Vault policy. 
", policy=policy_name, role=role_name) diff --git a/requirements.txt b/requirements.txt index 9b1f0ec99f8e..115084a1bcc9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -# >=1.25.4, <1.27 from botocore 1.19.28 inside boto3 1.16.28 inside nuclio-jupyter 0.8.8 -urllib3>=1.25.4, <1.27 +# >=1.26.9, <1.27 from botocore 1.19.28 inside boto3 1.16.28 inside nuclio-jupyter 0.8.8 +urllib3>=1.26.9, <1.27 # >=3.0.2 from requests 2.25.1 <4.0 from aiohttp 3.7.3, requests is <5, so without the upbound there's a conflict chardet>=3.0.2, <4.0 GitPython~=3.1, >= 3.1.30 @@ -7,11 +7,6 @@ aiohttp~=3.8 aiohttp-retry~=2.8 # 8.1.0+ breaks dask/distributed versions older than 2022.04.0, see here - https://github.com/dask/distributed/pull/6018 click~=8.0.0 -# when installing google-cloud-storage which required >=3.20.1, <5 it was upgrading the protobuf version to the latest -# version and because kfp 1.8.13 requires protobuf>=3.13, <4 it resulted incompatibility between kfp and protobuf -# this can be removed once kfp will support protobuf > 4 -# since google-cloud blacklisted 3.20.0 and 3.20.1 we start from 3.20.2 -protobuf>=3.13, <3.20 # 3.0/3.2 iguazio system uses 1.0.1, but we needed >=1.6.0 to be compatible with k8s>=12.0 to fix scurity issue # since the sdk is still mark as beta (and not stable) I'm limiting to only patch changes # 1.8.14 introduced new features related to ParallelFor, while our actual kfp server is 1.8.1, which isn't compatible @@ -21,15 +16,15 @@ nest-asyncio~=1.0 # ipython 8.0 + only supports python3.8 +, so to keep backwards compatibility with python 3.7 we support 7.x # we rely on pip and nuclio-jupyter requirements to install the right package per python version ipython>=7.0, <9.0 -nuclio-jupyter~=0.9.9 +nuclio-jupyter~=0.9.10 # >=1.16.5 from pandas 1.2.1 and <1.23.0 from storey numpy>=1.16.5, <1.23.0 # limiting pandas to <1.5.0 since 1.5.0 causes exception in storey on casting from ns to us pandas~=1.2, <1.5.0 # used as a the engine for 
parquet files by pandas # >=10 to resolve https://issues.apache.org/jira/browse/ARROW-16838 bug that is triggered by ingest (ML-3299) -# < 11 since starting from 11 ParquetDataset is deprecated and ParquetDatasetV2 is used instead -pyarrow>=10,<11 +# <12 to prevent bugs due to major upgrading +pyarrow>=10.0, <12 pyyaml~=5.1 requests~=2.22 # in sqlalchemy>=2.0 there is breaking changes (such as in Table class autoload argument is removed) @@ -38,7 +33,8 @@ sqlalchemy~=1.4 tabulate~=0.8.6 v3io~=0.5.20 pydantic~=1.5 -orjson~=3.3 +# blacklist 3.8.12 due to a bug not being able to collect traceback of exceptions +orjson~=3.3, <3.8.12 alembic~=1.9 mergedeep~=1.3 v3io-frames~=0.10.4 @@ -49,11 +45,11 @@ distributed~=2021.11.2 kubernetes~=12.0 # TODO: move to API requirements (shouldn't really be here, the sql run db using the API sqldb is preventing us from # separating the SDK and API code) (referring to humanfriendly and fastapi) -humanfriendly~=8.2 -fastapi~=0.92.0 -fsspec~=2021.8.1 +humanfriendly~=9.2 +fastapi~=0.95.2 +fsspec~=2023.1.0 v3iofs~=0.1.15 -storey~=1.3.15 +storey~=1.4.3 deepdiff~=5.0 pymysql~=1.0 inflection~=0.5.0 @@ -61,3 +57,4 @@ python-dotenv~=0.17.0 # older version of setuptools contains vulnerabilities, see `GHSA-r9hx-vwmv-q579`, so we limit to 65.5 and above setuptools~=65.5 deprecated~=1.2 +jinja2~=3.1 diff --git a/tests/api/api/feature_store/base.py b/tests/api/api/feature_store/base.py index 840aa1323ce0..32008d6f94cd 100644 --- a/tests/api/api/feature_store/base.py +++ b/tests/api/api/feature_store/base.py @@ -17,7 +17,7 @@ from deepdiff import DeepDiff from fastapi.testclient import TestClient -import mlrun.api.schemas +import mlrun.common.schemas def _list_and_assert_objects( @@ -68,7 +68,7 @@ def _patch_object( patch_mode = "replace" if additive: patch_mode = "additive" - headers = {mlrun.api.schemas.HeaderNames.patch_mode: patch_mode} + headers = {mlrun.common.schemas.HeaderNames.patch_mode: patch_mode} response = client.patch( 
f"projects/{project_name}/{object_url_path}/{name}/references/{reference}", json=object_update, diff --git a/tests/api/api/feature_store/test_feature_vectors.py b/tests/api/api/feature_store/test_feature_vectors.py index 136e94d9e40f..1402a8602e1a 100644 --- a/tests/api/api/feature_store/test_feature_vectors.py +++ b/tests/api/api/feature_store/test_feature_vectors.py @@ -23,8 +23,8 @@ from sqlalchemy.orm import Session import mlrun.api.api.endpoints.feature_store -import mlrun.api.schemas import mlrun.api.utils.auth.verifier +import mlrun.common.schemas import tests.api.api.utils from .base import ( @@ -483,11 +483,11 @@ async def test_verify_feature_vector_features_permissions( label_feature = "some-feature-set.some-feature" def _verify_queried_resources( - resource_type: mlrun.api.schemas.AuthorizationResourceTypes, + resource_type: mlrun.common.schemas.AuthorizationResourceTypes, resources: typing.List, project_and_resource_name_extractor: typing.Callable, - action: mlrun.api.schemas.AuthorizationAction, - auth_info: mlrun.api.schemas.AuthInfo, + action: mlrun.common.schemas.AuthorizationAction, + auth_info: mlrun.common.schemas.AuthInfo, raise_on_forbidden: bool = True, ): expected_resources = [ @@ -508,7 +508,7 @@ def _verify_queried_resources( unittest.mock.AsyncMock(side_effect=_verify_queried_resources) ) await mlrun.api.api.endpoints.feature_store._verify_feature_vector_features_permissions( - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project, {"spec": {"features": features, "label_feature": label_feature}}, ) diff --git a/tests/api/api/framework/test_middlewares.py b/tests/api/api/framework/test_middlewares.py index 4fdcea3e75c9..2e98030df931 100644 --- a/tests/api/api/framework/test_middlewares.py +++ b/tests/api/api/framework/test_middlewares.py @@ -19,7 +19,7 @@ import pytest import sqlalchemy.orm -import mlrun.api.schemas.constants +import mlrun.common.schemas.constants import mlrun.utils.version @@ -51,20 +51,20 @@ def 
test_ui_clear_cache_middleware( response = client.get( "client-spec", headers={ - mlrun.api.schemas.constants.HeaderNames.ui_version: ui_version, + mlrun.common.schemas.constants.HeaderNames.ui_version: ui_version, }, ) if clear_cache: assert response.headers["Clear-Site-Data"] == '"cache"' assert ( - response.headers[mlrun.api.schemas.constants.HeaderNames.ui_clear_cache] + response.headers[mlrun.common.schemas.constants.HeaderNames.ui_clear_cache] == "true" ) else: assert "Clear-Site-Data" not in response.headers assert ( - mlrun.api.schemas.constants.HeaderNames.ui_clear_cache + mlrun.common.schemas.constants.HeaderNames.ui_clear_cache not in response.headers ) @@ -77,6 +77,6 @@ def test_ensure_be_version_middleware( ) as mock_version_get: response = client.get("client-spec") assert ( - response.headers[mlrun.api.schemas.constants.HeaderNames.backend_version] + response.headers[mlrun.common.schemas.constants.HeaderNames.backend_version] == mock_version_get.return_value["version"] ) diff --git a/tests/api/api/hub/__init__.py b/tests/api/api/hub/__init__.py new file mode 100644 index 000000000000..b3085be1eb56 --- /dev/null +++ b/tests/api/api/hub/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# diff --git a/tests/api/api/marketplace/functions/channel/catalog.json b/tests/api/api/hub/functions/channel/catalog.json similarity index 98% rename from tests/api/api/marketplace/functions/channel/catalog.json rename to tests/api/api/hub/functions/channel/catalog.json index 0407fc5acae5..18b99bcd6e9b 100644 --- a/tests/api/api/marketplace/functions/channel/catalog.json +++ b/tests/api/api/hub/functions/channel/catalog.json @@ -27,6 +27,9 @@ "pandas_profiling" ] }, + "assets": { + "html_asset": "static/my_html.html" + }, "url": "", "version": "0.0.1" } diff --git a/tests/api/api/hub/functions/channel/dev_function/latest/static/my_html.html b/tests/api/api/hub/functions/channel/dev_function/latest/static/my_html.html new file mode 100644 index 000000000000..2a53fedf8ac0 --- /dev/null +++ b/tests/api/api/hub/functions/channel/dev_function/latest/static/my_html.html @@ -0,0 +1,6 @@ + + + +

Example HTML File

+ + diff --git a/tests/api/api/marketplace/test_marketplace.py b/tests/api/api/hub/test_hub.py similarity index 67% rename from tests/api/api/marketplace/test_marketplace.py rename to tests/api/api/hub/test_hub.py index eb8bb76917aa..7459d1e2aaa5 100644 --- a/tests/api/api/marketplace/test_marketplace.py +++ b/tests/api/api/hub/test_hub.py @@ -12,17 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import http import pathlib import random from http import HTTPStatus import deepdiff +import pytest import yaml from fastapi.testclient import TestClient from sqlalchemy.orm import Session import mlrun.api.crud -import mlrun.api.schemas +import mlrun.common.schemas import tests.api.conftest from mlrun.config import config @@ -33,7 +35,7 @@ def _generate_source_dict(index, name, credentials=None): return { "index": index, "source": { - "kind": "MarketplaceSource", + "kind": "HubSource", "metadata": {"name": name, "description": "A test", "labels": None}, "spec": { "path": path, @@ -51,7 +53,7 @@ def _assert_sources_in_correct_order(client, expected_order, exclude_paths=None) "root['metadata']['created']", "root['spec']['object_type']", ] - response = client.get("marketplace/sources") + response = client.get("hub/sources") assert response.status_code == HTTPStatus.OK.value json_response = response.json() # Default source is not in the expected data @@ -68,29 +70,29 @@ def _assert_sources_in_correct_order(client, expected_order, exclude_paths=None) ) -def test_marketplace_source_apis( +def test_hub_source_apis( db: Session, client: TestClient, k8s_secrets_mock: tests.api.conftest.K8sSecretsMock, ) -> None: # Make sure the default source is there. 
- response = client.get("marketplace/sources") + response = client.get("hub/sources") assert response.status_code == HTTPStatus.OK.value json_response = response.json() assert ( len(json_response) == 1 and json_response[0]["index"] == -1 and json_response[0]["source"]["metadata"]["name"] - == config.marketplace.default_source.name + == config.hub.default_source.name ) source_1 = _generate_source_dict(1, "source_1") - response = client.post("marketplace/sources", json=source_1) + response = client.post("hub/sources", json=source_1) assert response.status_code == HTTPStatus.CREATED.value # Modify existing source with a new field source_1["source"]["metadata"]["something_new"] = 42 - response = client.put("marketplace/sources/source_1", json=source_1) + response = client.put("hub/sources/source_1", json=source_1) assert response.status_code == HTTPStatus.OK.value exclude_paths = [ "root['metadata']['updated']", @@ -106,12 +108,12 @@ def test_marketplace_source_apis( # Insert in 1st place, pushing source_1 to be #2 source_2 = _generate_source_dict(1, "source_2") - response = client.put("marketplace/sources/source_2", json=source_2) + response = client.put("hub/sources/source_2", json=source_2) assert response.status_code == HTTPStatus.OK.value # Insert last, making it #3 source_3 = _generate_source_dict(-1, "source_3") - response = client.post("marketplace/sources", json=source_3) + response = client.post("hub/sources", json=source_3) assert response.status_code == HTTPStatus.CREATED.value expected_response = { @@ -123,7 +125,7 @@ def test_marketplace_source_apis( # Change order for existing source (3->1) source_3["index"] = 1 - response = client.put("marketplace/sources/source_3", json=source_3) + response = client.put("hub/sources/source_3", json=source_3) assert response.status_code == HTTPStatus.OK.value expected_response = { 1: source_3, @@ -132,7 +134,7 @@ def test_marketplace_source_apis( } _assert_sources_in_correct_order(client, expected_response) - response = 
client.delete("marketplace/sources/source_2") + response = client.delete("hub/sources/source_2") assert response.status_code == HTTPStatus.NO_CONTENT.value expected_response = { @@ -143,27 +145,25 @@ def test_marketplace_source_apis( # Negative tests # Try to delete the default source. - response = client.delete( - f"marketplace/sources/{config.marketplace.default_source.name}" - ) + response = client.delete(f"hub/sources/{config.hub.default_source.name}") assert response.status_code == HTTPStatus.BAD_REQUEST.value # Try to store an object with invalid order source_2["index"] = 42 - response = client.post("marketplace/sources", json=source_2) + response = client.post("hub/sources", json=source_2) assert response.status_code == HTTPStatus.BAD_REQUEST.value -def test_marketplace_credentials_removed_from_db( +def test_hub_credentials_removed_from_db( db: Session, client: TestClient, k8s_secrets_mock: tests.api.conftest.K8sSecretsMock ) -> None: # Validate that a source with credentials is stored (and retrieved back) without them, while the creds # are stored in the k8s secret. 
credentials = {"secret1": "value1", "another-secret": "42"} source_1 = _generate_source_dict(-1, "source_1", credentials) - response = client.post("marketplace/sources", json=source_1) + response = client.post("hub/sources", json=source_1) assert response.status_code == HTTPStatus.CREATED.value - response = client.get("marketplace/sources/source_1") + response = client.get("hub/sources/source_1") assert response.status_code == HTTPStatus.OK.value object_dict = response.json() @@ -181,20 +181,18 @@ def test_marketplace_credentials_removed_from_db( == {} ) expected_credentials = { - mlrun.api.crud.Marketplace()._generate_credentials_secret_key( - "source_1", key - ): value + mlrun.api.crud.Hub()._generate_credentials_secret_key("source_1", key): value for key, value in credentials.items() } k8s_secrets_mock.assert_project_secrets( - config.marketplace.k8s_secrets_project_name, expected_credentials + config.hub.k8s_secrets_project_name, expected_credentials ) -def test_marketplace_source_manager( +def test_hub_source_manager( k8s_secrets_mock: tests.api.conftest.K8sSecretsMock, ) -> None: - manager = mlrun.api.crud.Marketplace() + manager = mlrun.api.crud.Hub() credentials = {"secret1": "value1", "secret2": "value2"} expected_credentials = {} @@ -202,28 +200,26 @@ def test_marketplace_source_manager( source_dict = _generate_source_dict(i, f"source_{i}", credentials) expected_credentials.update( { - mlrun.api.crud.Marketplace()._generate_credentials_secret_key( + mlrun.api.crud.Hub()._generate_credentials_secret_key( f"source_{i}", key ): value for key, value in credentials.items() } ) - source_object = mlrun.api.schemas.MarketplaceSource(**source_dict["source"]) + source_object = mlrun.common.schemas.HubSource(**source_dict["source"]) manager.add_source(source_object) k8s_secrets_mock.assert_project_secrets( - config.marketplace.k8s_secrets_project_name, expected_credentials + config.hub.k8s_secrets_project_name, expected_credentials ) 
manager.remove_source("source_1") for key in credentials: expected_credentials.pop( - mlrun.api.crud.Marketplace()._generate_credentials_secret_key( - "source_1", key - ) + mlrun.api.crud.Hub()._generate_credentials_secret_key("source_1", key) ) k8s_secrets_mock.assert_project_secrets( - config.marketplace.k8s_secrets_project_name, expected_credentials + config.hub.k8s_secrets_project_name, expected_credentials ) # Test catalog retrieval, with various filters @@ -250,13 +246,12 @@ def test_marketplace_source_manager( assert item.metadata.name == "prod_function" and item.metadata.version == "1.0.0" -def test_marketplace_default_source( +def test_hub_default_source( k8s_secrets_mock: tests.api.conftest.K8sSecretsMock, ) -> None: # This test validates that the default source is valid is its catalog and objects can be retrieved. - manager = mlrun.api.crud.Marketplace() - - source_object = mlrun.api.schemas.MarketplaceSource.generate_default_source() + manager = mlrun.api.crud.Hub() + source_object = mlrun.common.schemas.HubSource.generate_default_source() catalog = manager.get_source_catalog(source_object) assert len(catalog.catalog) > 0 print(f"Retrieved function catalog. 
Has {len(catalog.catalog)} functions in it.") @@ -281,19 +276,19 @@ def test_marketplace_default_source( assert yaml_function_name == function_modified_name -def test_marketplace_catalog_apis( +def test_hub_catalog_apis( db: Session, client: TestClient, k8s_secrets_mock: tests.api.conftest.K8sSecretsMock ) -> None: # Get the global hub source-name - sources = client.get("marketplace/sources").json() + sources = client.get("hub/sources").json() source_name = sources[0]["source"]["metadata"]["name"] - catalog = client.get(f"marketplace/sources/{source_name}/items").json() + catalog = client.get(f"hub/sources/{source_name}/items").json() item = random.choice(catalog["catalog"]) url = item["spec"]["item_uri"] + "src/function.yaml" function_yaml = client.get( - f"marketplace/sources/{source_name}/item-object", params={"url": url} + f"hub/sources/{source_name}/item-object", params={"url": url} ) function_dict = yaml.safe_load(function_yaml.content) @@ -304,3 +299,63 @@ def test_marketplace_catalog_apis( function_modified_name = item["metadata"]["name"].replace("_", "-") assert yaml_function_name == function_modified_name + + +def test_hub_get_asset_from_default_source( + db: Session, client: TestClient, k8s_secrets_mock: tests.api.conftest.K8sSecretsMock +) -> None: + possible_assets = [ + ("docs", "text/html; charset=utf-8"), + ("source", "text/x-python; charset=utf-8"), + ("example", "application/octet-stream"), + ("function", "application/octet-stream"), + ] + sources = client.get("hub/sources").json() + source_name = sources[0]["source"]["metadata"]["name"] + catalog = client.get(f"hub/sources/{source_name}/items").json() + for _ in range(10): + item = random.choice(catalog["catalog"]) + asset_name, expected_content_type = random.choice(possible_assets) + response = client.get( + f"hub/sources/{source_name}/items/{item['metadata']['name']}/assets/{asset_name}" + ) + assert response.status_code == http.HTTPStatus.OK.value + assert response.headers["content-type"] == 
expected_content_type + + +def test_hub_get_asset( + k8s_secrets_mock: tests.api.conftest.K8sSecretsMock, +) -> None: + manager = mlrun.api.crud.Hub() + + # Adding hub source with credentials: + credentials = {"secret": "value"} + + source_dict = _generate_source_dict(1, "source", credentials) + expected_credentials = { + mlrun.api.crud.Hub()._generate_credentials_secret_key( + "source", "secret" + ): credentials["secret"] + } + source_object = mlrun.common.schemas.HubSource(**source_dict["source"]) + manager.add_source(source_object) + k8s_secrets_mock.assert_project_secrets( + config.hub.k8s_secrets_project_name, expected_credentials + ) + # getting asset: + catalog = manager.get_source_catalog(source_object) + item = catalog.catalog[0] + # verifying item contain the asset: + assert item.spec.assets.get("html_asset", "") == "static/my_html.html" + + asset_object, url = manager.get_asset(source_object, item, "html_asset") + relative_asset_path = "functions/channel/dev_function/latest/static/my_html.html" + asset_path = pathlib.Path(__file__).absolute().parent / relative_asset_path + with open(asset_path, "r") as f: + expected_content = f.read() + # Validating content and url: + assert expected_content == asset_object.decode("utf-8") and url == str(asset_path) + + # Verify not-found assets are handled properly + with pytest.raises(mlrun.errors.MLRunNotFoundError): + manager.get_asset(source_object, item, "not-found") diff --git a/tests/api/api/test_artifacts.py b/tests/api/api/test_artifacts.py index e2172db10586..8c1121ace628 100644 --- a/tests/api/api/test_artifacts.py +++ b/tests/api/api/test_artifacts.py @@ -19,8 +19,8 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session -import mlrun.api.schemas import mlrun.artifacts +import mlrun.common.schemas from mlrun.utils.helpers import is_legacy_artifact PROJECT = "prj" @@ -45,9 +45,9 @@ def test_list_artifact_tags(db: Session, client: TestClient) -> None: def _create_project(client: 
TestClient, project_name: str = PROJECT): - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), - spec=mlrun.api.schemas.ProjectSpec( + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), + spec=mlrun.common.schemas.ProjectSpec( description="banana", source="source", goals="some goals" ), ) @@ -134,7 +134,7 @@ def test_store_artifact_with_invalid_tag(db: Session, client: TestClient): "projects/{project}/tags/{tag}".format(project=PROJECT, tag=tag), json={ "kind": "artifact", - "identifiers": [(mlrun.api.schemas.ArtifactIdentifier(key=KEY).dict())], + "identifiers": [(mlrun.common.schemas.ArtifactIdentifier(key=KEY).dict())], }, ) @@ -145,7 +145,7 @@ def test_store_artifact_with_invalid_tag(db: Session, client: TestClient): "projects/{project}/tags/{tag}".format(project=PROJECT, tag=tag), json={ "kind": "artifact", - "identifiers": [(mlrun.api.schemas.ArtifactIdentifier(key=KEY).dict())], + "identifiers": [(mlrun.common.schemas.ArtifactIdentifier(key=KEY).dict())], }, ) assert resp.status_code == HTTPStatus.UNPROCESSABLE_ENTITY.value @@ -368,7 +368,7 @@ def test_list_artifact_with_multiple_tags(db: Session, client: TestClient): "projects/{project}/tags/{tag}".format(project=PROJECT, tag=new_tag), json={ "kind": "artifact", - "identifiers": [(mlrun.api.schemas.ArtifactIdentifier(key=KEY).dict())], + "identifiers": [(mlrun.common.schemas.ArtifactIdentifier(key=KEY).dict())], }, ) diff --git a/tests/api/api/test_auth.py b/tests/api/api/test_auth.py index 0c2c2625791f..3180c9133bd2 100644 --- a/tests/api/api/test_auth.py +++ b/tests/api/api/test_auth.py @@ -17,15 +17,18 @@ import fastapi.testclient import sqlalchemy.orm -import mlrun.api.schemas import mlrun.api.utils.auth.verifier +import mlrun.common.schemas def test_verify_authorization( db: sqlalchemy.orm.Session, client: fastapi.testclient.TestClient ) -> None: - authorization_verification_input = 
mlrun.api.schemas.AuthorizationVerificationInput( - resource="/some-resource", action=mlrun.api.schemas.AuthorizationAction.create + authorization_verification_input = ( + mlrun.common.schemas.AuthorizationVerificationInput( + resource="/some-resource", + action=mlrun.common.schemas.AuthorizationAction.create, + ) ) async def _mock_successful_query_permissions(resource, action, *args): diff --git a/tests/api/api/test_background_tasks.py b/tests/api/api/test_background_tasks.py index 68ccd8b98a09..897104e708c1 100644 --- a/tests/api/api/test_background_tasks.py +++ b/tests/api/api/test_background_tasks.py @@ -25,10 +25,10 @@ import mlrun.api.api.deps import mlrun.api.main -import mlrun.api.schemas import mlrun.api.utils.auth.verifier import mlrun.api.utils.background_tasks import mlrun.api.utils.clients.chief +import mlrun.common.schemas test_router = fastapi.APIRouter() @@ -37,7 +37,7 @@ # and to get this class, we must trigger an endpoint @test_router.post( "/projects/{project}/background-tasks", - response_model=mlrun.api.schemas.BackgroundTask, + response_model=mlrun.common.schemas.BackgroundTask, ) def create_project_background_task( project: str, @@ -57,7 +57,7 @@ def create_project_background_task( @test_router.post( "/internal-background-tasks", - response_model=mlrun.api.schemas.BackgroundTask, + response_model=mlrun.common.schemas.BackgroundTask, ) def create_internal_background_task( background_tasks: fastapi.BackgroundTasks, @@ -174,9 +174,10 @@ def test_create_project_background_task_in_chief_success( f"{ORIGINAL_VERSIONED_API_PREFIX}/projects/{project}/background-tasks/{background_task.metadata.name}" ) assert response.status_code == http.HTTPStatus.OK.value - background_task = mlrun.api.schemas.BackgroundTask(**response.json()) + background_task = mlrun.common.schemas.BackgroundTask(**response.json()) assert ( - background_task.status.state == mlrun.api.schemas.BackgroundTaskState.succeeded + background_task.status.state + == 
mlrun.common.schemas.BackgroundTaskState.succeeded ) assert background_task.metadata.updated is not None assert call_counter == 1 @@ -194,8 +195,10 @@ def test_create_project_background_task_in_chief_failure( f"{ORIGINAL_VERSIONED_API_PREFIX}/projects/{project}/background-tasks/{background_task.metadata.name}" ) assert response.status_code == http.HTTPStatus.OK.value - background_task = mlrun.api.schemas.BackgroundTask(**response.json()) - assert background_task.status.state == mlrun.api.schemas.BackgroundTaskState.failed + background_task = mlrun.common.schemas.BackgroundTask(**response.json()) + assert ( + background_task.status.state == mlrun.common.schemas.BackgroundTaskState.failed + ) assert background_task.metadata.updated is not None @@ -219,7 +222,7 @@ def test_get_background_task_auth_skip( mlrun.mlconf.igz_version = "3.2.0-b26.20210904121245" response = client.post("/test/internal-background-tasks") assert response.status_code == http.HTTPStatus.OK.value - background_task = mlrun.api.schemas.BackgroundTask(**response.json()) + background_task = mlrun.common.schemas.BackgroundTask(**response.json()) response = client.get( f"{ORIGINAL_VERSIONED_API_PREFIX}/background-tasks/{background_task.metadata.name}" ) @@ -257,7 +260,7 @@ def test_get_internal_background_task_redirect_from_worker_to_chief_exists( ) response = client.get(f"{ORIGINAL_VERSIONED_API_PREFIX}/background-tasks/{name}") assert response.status_code == http.HTTPStatus.OK.value - background_task = mlrun.api.schemas.BackgroundTask(**response.json()) + background_task = mlrun.common.schemas.BackgroundTask(**response.json()) assert background_task == expected_background_task @@ -284,7 +287,7 @@ def test_get_internal_background_task_in_chief_exists( ): response = client.post("/test/internal-background-tasks") assert response.status_code == http.HTTPStatus.OK.value - background_task = mlrun.api.schemas.BackgroundTask(**response.json()) + background_task = 
mlrun.common.schemas.BackgroundTask(**response.json()) assert background_task.metadata.project is None response = client.get( @@ -340,26 +343,28 @@ def test_trigger_migrations_from_worker_returns_same_response_as_chief( def _generate_background_task( background_task_name, - state: mlrun.api.schemas.BackgroundTaskState = mlrun.api.schemas.BackgroundTaskState.running, -) -> mlrun.api.schemas.BackgroundTask: + state: mlrun.common.schemas.BackgroundTaskState = mlrun.common.schemas.BackgroundTaskState.running, +) -> mlrun.common.schemas.BackgroundTask: now = datetime.datetime.utcnow() - return mlrun.api.schemas.BackgroundTask( - metadata=mlrun.api.schemas.BackgroundTaskMetadata( + return mlrun.common.schemas.BackgroundTask( + metadata=mlrun.common.schemas.BackgroundTaskMetadata( name=background_task_name, created=now, updated=now, ), - status=mlrun.api.schemas.BackgroundTaskStatus(state=state.value), - spec=mlrun.api.schemas.BackgroundTaskSpec(), + status=mlrun.common.schemas.BackgroundTaskStatus(state=state.value), + spec=mlrun.common.schemas.BackgroundTaskSpec(), ) def _assert_background_task_creation(expected_project, response): assert response.status_code == http.HTTPStatus.OK.value - background_task = mlrun.api.schemas.BackgroundTask(**response.json()) - assert background_task.kind == mlrun.api.schemas.ObjectKind.background_task + background_task = mlrun.common.schemas.BackgroundTask(**response.json()) + assert background_task.kind == mlrun.common.schemas.ObjectKind.background_task assert background_task.metadata.project == expected_project assert background_task.metadata.created is not None assert background_task.metadata.updated is not None - assert background_task.status.state == mlrun.api.schemas.BackgroundTaskState.running + assert ( + background_task.status.state == mlrun.common.schemas.BackgroundTaskState.running + ) return background_task diff --git a/tests/api/api/test_client_spec.py b/tests/api/api/test_client_spec.py index 74eedfe814d6..b39d0edd8d39 
100644 --- a/tests/api/api/test_client_spec.py +++ b/tests/api/api/test_client_spec.py @@ -23,8 +23,8 @@ import mlrun import mlrun.api.crud -import mlrun.api.schemas import mlrun.api.utils.clients.iguazio +import mlrun.common.schemas import mlrun.errors import mlrun.runtimes import mlrun.utils.version @@ -141,8 +141,8 @@ def test_client_spec_response_based_on_client_version( response = client.get( "client-spec", headers={ - mlrun.api.schemas.HeaderNames.client_version: "", - mlrun.api.schemas.HeaderNames.python_version: "", + mlrun.common.schemas.HeaderNames.client_version: "", + mlrun.common.schemas.HeaderNames.python_version: "", }, ) assert response.status_code == http.HTTPStatus.OK.value @@ -157,8 +157,8 @@ def test_client_spec_response_based_on_client_version( response = client.get( "client-spec", headers={ - mlrun.api.schemas.HeaderNames.client_version: "", - mlrun.api.schemas.HeaderNames.python_version: "", + mlrun.common.schemas.HeaderNames.client_version: "", + mlrun.common.schemas.HeaderNames.python_version: "", }, ) assert response.status_code == http.HTTPStatus.OK.value @@ -170,7 +170,7 @@ def test_client_spec_response_based_on_client_version( response = client.get( "client-spec", headers={ - mlrun.api.schemas.HeaderNames.client_version: "1.2.0", + mlrun.common.schemas.HeaderNames.client_version: "1.2.0", }, ) assert response.status_code == http.HTTPStatus.OK.value @@ -182,8 +182,8 @@ def test_client_spec_response_based_on_client_version( response = client.get( "client-spec", headers={ - mlrun.api.schemas.HeaderNames.client_version: "1.3.0-rc20", - mlrun.api.schemas.HeaderNames.python_version: "3.7.13", + mlrun.common.schemas.HeaderNames.client_version: "1.3.0-rc20", + mlrun.common.schemas.HeaderNames.python_version: "3.7.13", }, ) assert response.status_code == http.HTTPStatus.OK.value @@ -194,8 +194,8 @@ def test_client_spec_response_based_on_client_version( response = client.get( "client-spec", headers={ - 
mlrun.api.schemas.HeaderNames.client_version: "1.3.0-rc20", - mlrun.api.schemas.HeaderNames.python_version: "3.9.13", + mlrun.common.schemas.HeaderNames.client_version: "1.3.0-rc20", + mlrun.common.schemas.HeaderNames.python_version: "3.9.13", }, ) assert response.status_code == http.HTTPStatus.OK.value @@ -207,8 +207,8 @@ def test_client_spec_response_based_on_client_version( response = client.get( "client-spec", headers={ - mlrun.api.schemas.HeaderNames.client_version: "test-integration", - mlrun.api.schemas.HeaderNames.python_version: "3.9.13", + mlrun.common.schemas.HeaderNames.client_version: "test-integration", + mlrun.common.schemas.HeaderNames.python_version: "3.9.13", }, ) assert response.status_code == http.HTTPStatus.OK.value diff --git a/tests/api/api/test_frontend_spec.py b/tests/api/api/test_frontend_spec.py index 9aaded047d6b..33e1c6fdeabf 100644 --- a/tests/api/api/test_frontend_spec.py +++ b/tests/api/api/test_frontend_spec.py @@ -20,8 +20,9 @@ import sqlalchemy.orm import mlrun.api.crud -import mlrun.api.schemas +import mlrun.api.utils.builder import mlrun.api.utils.clients.iguazio +import mlrun.common.schemas import mlrun.errors import mlrun.runtimes @@ -43,7 +44,7 @@ def test_get_frontend_spec( response = client.get("frontend-spec") assert response.status_code == http.HTTPStatus.OK.value - frontend_spec = mlrun.api.schemas.FrontendSpec(**response.json()) + frontend_spec = mlrun.common.schemas.FrontendSpec(**response.json()) assert ( deepdiff.DeepDiff( frontend_spec.abortable_function_kinds, @@ -53,19 +54,19 @@ def test_get_frontend_spec( ) assert ( frontend_spec.feature_flags.project_membership - == mlrun.api.schemas.ProjectMembershipFeatureFlag.disabled + == mlrun.common.schemas.ProjectMembershipFeatureFlag.disabled ) assert ( frontend_spec.feature_flags.authentication - == mlrun.api.schemas.AuthenticationFeatureFlag.none + == mlrun.common.schemas.AuthenticationFeatureFlag.none ) assert ( frontend_spec.feature_flags.nuclio_streams - == 
mlrun.api.schemas.NuclioStreamsFeatureFlag.disabled + == mlrun.common.schemas.NuclioStreamsFeatureFlag.disabled ) assert ( frontend_spec.feature_flags.preemption_nodes - == mlrun.api.schemas.PreemptionNodesFeatureFlag.disabled + == mlrun.common.schemas.PreemptionNodesFeatureFlag.disabled ) assert frontend_spec.default_function_image_by_kind is not None assert frontend_spec.function_deployment_mlrun_command is not None @@ -79,7 +80,7 @@ def test_get_frontend_spec( bla = f"{{{expected_template_field}}}" assert bla in frontend_spec.function_deployment_target_image_template - assert frontend_spec.default_function_pod_resources, mlrun.api.schemas.Resources( + assert frontend_spec.default_function_pod_resources, mlrun.common.schemas.Resources( **default_function_pod_resources ) assert ( @@ -93,12 +94,16 @@ def test_get_frontend_spec( assert ( frontend_spec.default_function_preemption_mode - == mlrun.api.schemas.PreemptionModes.prevent.value + == mlrun.common.schemas.PreemptionModes.prevent.value ) assert ( frontend_spec.allowed_artifact_path_prefixes_list == mlrun.api.api.utils.get_allowed_path_prefixes_list() ) + assert ( + frontend_spec.function_deployment_mlrun_command + == f'python -m pip install "{mlrun.api.utils.builder.resolve_mlrun_install_command_version()}"' + ) def test_get_frontend_spec_jobs_dashboard_url_resolution( @@ -110,7 +115,7 @@ def test_get_frontend_spec_jobs_dashboard_url_resolution( # no cookie so no url response = client.get("frontend-spec") assert response.status_code == http.HTTPStatus.OK.value - frontend_spec = mlrun.api.schemas.FrontendSpec(**response.json()) + frontend_spec = mlrun.common.schemas.FrontendSpec(**response.json()) assert frontend_spec.jobs_dashboard_url is None mlrun.api.utils.clients.iguazio.Client().try_get_grafana_service_url.assert_not_called() @@ -119,7 +124,7 @@ def test_get_frontend_spec_jobs_dashboard_url_resolution( mlrun.api.utils.clients.iguazio.AsyncClient().verify_request_session = ( unittest.mock.AsyncMock( 
return_value=( - mlrun.api.schemas.AuthInfo( + mlrun.common.schemas.AuthInfo( username=None, session="946b0749-5c40-4837-a4ac-341d295bfaf7", user_id=None, @@ -134,7 +139,7 @@ def test_get_frontend_spec_jobs_dashboard_url_resolution( ) response = client.get("frontend-spec") assert response.status_code == http.HTTPStatus.OK.value - frontend_spec = mlrun.api.schemas.FrontendSpec(**response.json()) + frontend_spec = mlrun.common.schemas.FrontendSpec(**response.json()) assert frontend_spec.jobs_dashboard_url is None mlrun.api.utils.clients.iguazio.Client().try_get_grafana_service_url.assert_called_once() @@ -146,23 +151,7 @@ def test_get_frontend_spec_jobs_dashboard_url_resolution( response = client.get("frontend-spec") assert response.status_code == http.HTTPStatus.OK.value - frontend_spec = mlrun.api.schemas.FrontendSpec(**response.json()) - assert ( - frontend_spec.jobs_dashboard_url - == f"{grafana_url}/d/mlrun-jobs-monitoring/mlrun-jobs-monitoring?orgId=1" - f"&var-groupBy={{filter_name}}&var-filter={{filter_value}}" - ) - mlrun.api.utils.clients.iguazio.Client().try_get_grafana_service_url.assert_called_once() - - # now one time with the 3.0 iguazio auth way - mlrun.mlconf.httpdb.authentication.mode = "none" - mlrun.api.utils.clients.iguazio.Client().try_get_grafana_service_url.reset_mock() - response = client.get( - "frontend-spec", - cookies={"session": 'j:{"sid":"946b0749-5c40-4837-a4ac-341d295bfaf7"}'}, - ) - assert response.status_code == http.HTTPStatus.OK.value - frontend_spec = mlrun.api.schemas.FrontendSpec(**response.json()) + frontend_spec = mlrun.common.schemas.FrontendSpec(**response.json()) assert ( frontend_spec.jobs_dashboard_url == f"{grafana_url}/d/mlrun-jobs-monitoring/mlrun-jobs-monitoring?orgId=1" @@ -178,22 +167,22 @@ def test_get_frontend_spec_nuclio_streams( { "iguazio_version": "3.2.0", "nuclio_version": "1.6.23", - "expected_feature_flag": mlrun.api.schemas.NuclioStreamsFeatureFlag.disabled, + "expected_feature_flag": 
mlrun.common.schemas.NuclioStreamsFeatureFlag.disabled, }, { "iguazio_version": None, "nuclio_version": "1.6.23", - "expected_feature_flag": mlrun.api.schemas.NuclioStreamsFeatureFlag.disabled, + "expected_feature_flag": mlrun.common.schemas.NuclioStreamsFeatureFlag.disabled, }, { "iguazio_version": None, "nuclio_version": "1.7.8", - "expected_feature_flag": mlrun.api.schemas.NuclioStreamsFeatureFlag.disabled, + "expected_feature_flag": mlrun.common.schemas.NuclioStreamsFeatureFlag.disabled, }, { "iguazio_version": "3.4.0", "nuclio_version": "1.7.8", - "expected_feature_flag": mlrun.api.schemas.NuclioStreamsFeatureFlag.enabled, + "expected_feature_flag": mlrun.common.schemas.NuclioStreamsFeatureFlag.enabled, }, ]: # init cached value to None in the beginning of each test case @@ -202,7 +191,7 @@ def test_get_frontend_spec_nuclio_streams( mlrun.mlconf.nuclio_version = test_case.get("nuclio_version") response = client.get("frontend-spec") - frontend_spec = mlrun.api.schemas.FrontendSpec(**response.json()) + frontend_spec = mlrun.common.schemas.FrontendSpec(**response.json()) assert response.status_code == http.HTTPStatus.OK.value assert frontend_spec.feature_flags.nuclio_streams == test_case.get( "expected_feature_flag" @@ -219,7 +208,7 @@ def test_get_frontend_spec_ce( response = client.get("frontend-spec") assert response.status_code == http.HTTPStatus.OK.value - frontend_spec = mlrun.api.schemas.FrontendSpec(**response.json()) + frontend_spec = mlrun.common.schemas.FrontendSpec(**response.json()) assert frontend_spec.ce["release"] == ce_release assert frontend_spec.ce["mode"] == frontend_spec.ce_mode == ce_mode @@ -238,7 +227,7 @@ def test_get_frontend_spec_feature_store_data_prefixes( ) response = client.get("frontend-spec") assert response.status_code == http.HTTPStatus.OK.value - frontend_spec = mlrun.api.schemas.FrontendSpec(**response.json()) + frontend_spec = mlrun.common.schemas.FrontendSpec(**response.json()) assert ( 
frontend_spec.feature_store_data_prefixes["default"] == feature_store_data_prefix_default diff --git a/tests/api/api/test_functions.py b/tests/api/api/test_functions.py index 1d65ae7c9fab..2da8689b4345 100644 --- a/tests/api/api/test_functions.py +++ b/tests/api/api/test_functions.py @@ -29,12 +29,14 @@ import mlrun.api.api.endpoints.functions import mlrun.api.api.utils import mlrun.api.crud -import mlrun.api.schemas +import mlrun.api.main +import mlrun.api.utils.builder import mlrun.api.utils.clients.chief import mlrun.api.utils.singletons.db import mlrun.api.utils.singletons.k8s import mlrun.artifacts.dataset import mlrun.artifacts.model +import mlrun.common.schemas import mlrun.errors import mlrun.utils.model_monitoring import tests.api.api.utils @@ -63,8 +65,8 @@ def test_build_status_pod_not_found( ) assert response.status_code == HTTPStatus.OK.value - mlrun.api.utils.singletons.k8s.get_k8s().v1api = unittest.mock.Mock() - mlrun.api.utils.singletons.k8s.get_k8s().v1api.read_namespaced_pod = ( + mlrun.api.utils.singletons.k8s.get_k8s_helper().v1api = unittest.mock.Mock() + mlrun.api.utils.singletons.k8s.get_k8s_helper().v1api.read_namespaced_pod = ( unittest.mock.Mock( side_effect=kubernetes.client.rest.ApiException( status=HTTPStatus.NOT_FOUND.value @@ -113,8 +115,7 @@ async def test_list_functions_with_hash_key_versioned( } post_function1_response = await async_client.post( - f"func/{function_project}/" - f"{function_name}?tag={function_tag}&versioned={True}", + f"projects/{function_project}/functions/{function_name}?tag={function_tag}&versioned={True}", json=function, ) @@ -123,14 +124,14 @@ async def test_list_functions_with_hash_key_versioned( # Store another function with the same project and name but different tag and hash key post_function2_response = await async_client.post( - f"func/{function_project}/" + f"projects/{function_project}/functions/" f"{function_name}?tag={another_tag}&versioned={True}", json=function2, ) assert 
post_function2_response.status_code == HTTPStatus.OK.value list_functions_by_hash_key_response = await async_client.get( - f"funcs?project={function_project}&name={function_name}&hash_key={hash_key}" + f"projects/{function_project}/functions?name={function_name}&hash_key={hash_key}" ) list_functions_results = list_functions_by_hash_key_response.json()["funcs"] @@ -138,6 +139,80 @@ async def test_list_functions_with_hash_key_versioned( assert list_functions_results[0]["metadata"]["hash"] == hash_key +@pytest.mark.parametrize("post_schedule", [True, False]) +def test_delete_function_with_schedule( + db: sqlalchemy.orm.Session, + client: fastapi.testclient.TestClient, + post_schedule, +): + # create project and function + tests.api.api.utils.create_project(client, PROJECT) + + function_tag = "function-tag" + function_name = "function-name" + project_name = "project-name" + + function = { + "kind": "job", + "metadata": { + "name": function_name, + "project": project_name, + "tag": function_tag, + }, + "spec": {"image": "mlrun/mlrun"}, + } + + function_endpoint = f"projects/{PROJECT}/functions/{function_name}" + function = client.post(function_endpoint, data=mlrun.utils.dict_to_json(function)) + assert function.status_code == HTTPStatus.OK.value + hash_key = function.json()["hash_key"] + + endpoint = f"projects/{PROJECT}/schedules" + if post_schedule: + # generate schedule object that matches to the function and create it + scheduled_object = { + "task": { + "spec": { + "function": f"{PROJECT}/{function_name}@{hash_key}", + "handler": "handler", + }, + "metadata": {"name": "my-task", "project": f"{PROJECT}"}, + } + } + schedule_cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(minute=1) + + schedule = mlrun.common.schemas.ScheduleInput( + name=function_name, + kind=mlrun.common.schemas.ScheduleKinds.job, + scheduled_object=scheduled_object, + cron_trigger=schedule_cron_trigger, + ) + + endpoint = f"projects/{PROJECT}/schedules" + response = client.post(endpoint, 
data=mlrun.utils.dict_to_json(schedule.dict())) + assert response.status_code == HTTPStatus.CREATED.value + + response = client.get(endpoint) + assert ( + response.status_code == HTTPStatus.OK.value + and response.json()["schedules"][0]["name"] == function_name + ) + + # delete the function and assert that it has been removed, as has its schedule if created + response = client.delete(function_endpoint) + assert response.status_code == HTTPStatus.NO_CONTENT.value + + response = client.get(function_endpoint) + assert response.status_code == HTTPStatus.NOT_FOUND.value + + if post_schedule: + response = client.get(endpoint) + assert ( + response.status_code == HTTPStatus.OK.value + and not response.json()["schedules"] + ) + + @pytest.mark.asyncio async def test_multiple_store_function_race_condition( db: sqlalchemy.orm.Session, async_client: httpx.AsyncClient @@ -296,6 +371,7 @@ def test_tracking_on_serving( ], mlrun.api.crud: ["ModelEndpoints"], nuclio.deploy: ["deploy_config"], + mlrun.utils.model_monitoring: ["get_stream_path"], } for package in functions_to_monkeypatch: @@ -375,6 +451,52 @@ def test_build_function_with_mlrun_bool( mlrun.api.api.endpoints.functions._build_function = original_build_function +@pytest.mark.parametrize( + "source, load_source_on_run", + [ + ("./", False), + (".", False), + ("./", True), + (".", True), + ], +) +def test_build_function_with_project_repo( + db: sqlalchemy.orm.Session, + client: fastapi.testclient.TestClient, + source, + load_source_on_run, +): + git_repo = "git://github.com/mlrun/test.git" + tests.api.api.utils.create_project( + client, PROJECT, source=git_repo, load_source_on_run=load_source_on_run + ) + function_dict = { + "kind": "job", + "metadata": { + "name": "function-name", + "project": "project-name", + "tag": "latest", + }, + "spec": { + "build": { + "source": source, + }, + }, + } + original_build_runtime = mlrun.api.utils.builder.build_image + mlrun.api.utils.builder.build_image = 
unittest.mock.Mock(return_value="success") + response = client.post( + "build/function", + json={"function": function_dict}, + ) + assert response.status_code == HTTPStatus.OK.value + function = mlrun.new_function(runtime=response.json()["data"]) + assert function.spec.build.source == git_repo + assert function.spec.build.load_source_on_run == load_source_on_run + + mlrun.api.utils.builder.build_image = original_build_runtime + + def test_start_function_succeeded( db: sqlalchemy.orm.Session, client: fastapi.testclient.TestClient, monkeypatch ): @@ -399,16 +521,19 @@ def test_start_function_succeeded( ), ) assert response.status_code == http.HTTPStatus.OK.value - background_task = mlrun.api.schemas.BackgroundTask(**response.json()) - assert background_task.status.state == mlrun.api.schemas.BackgroundTaskState.running + background_task = mlrun.common.schemas.BackgroundTask(**response.json()) + assert ( + background_task.status.state == mlrun.common.schemas.BackgroundTaskState.running + ) response = client.get( f"projects/{project}/background-tasks/{background_task.metadata.name}" ) assert response.status_code == http.HTTPStatus.OK.value - background_task = mlrun.api.schemas.BackgroundTask(**response.json()) + background_task = mlrun.common.schemas.BackgroundTask(**response.json()) assert ( - background_task.status.state == mlrun.api.schemas.BackgroundTaskState.succeeded + background_task.status.state + == mlrun.common.schemas.BackgroundTaskState.succeeded ) @@ -440,14 +565,18 @@ def failing_func(): ), ) assert response.status_code == http.HTTPStatus.OK - background_task = mlrun.api.schemas.BackgroundTask(**response.json()) - assert background_task.status.state == mlrun.api.schemas.BackgroundTaskState.running + background_task = mlrun.common.schemas.BackgroundTask(**response.json()) + assert ( + background_task.status.state == mlrun.common.schemas.BackgroundTaskState.running + ) response = client.get( 
f"projects/{project}/background-tasks/{background_task.metadata.name}" ) assert response.status_code == http.HTTPStatus.OK.value - background_task = mlrun.api.schemas.BackgroundTask(**response.json()) - assert background_task.status.state == mlrun.api.schemas.BackgroundTaskState.failed + background_task = mlrun.common.schemas.BackgroundTask(**response.json()) + assert ( + background_task.status.state == mlrun.common.schemas.BackgroundTaskState.failed + ) def test_start_function( @@ -461,26 +590,26 @@ def failing_func(): for test_case in [ { "_start_function_mock": unittest.mock.Mock, - "expected_status_result": mlrun.api.schemas.BackgroundTaskState.succeeded, + "expected_status_result": mlrun.common.schemas.BackgroundTaskState.succeeded, "background_timeout_mode": "enabled", "dask_timeout": 100, }, { "_start_function_mock": failing_func, - "expected_status_result": mlrun.api.schemas.BackgroundTaskState.failed, + "expected_status_result": mlrun.common.schemas.BackgroundTaskState.failed, "background_timeout_mode": "enabled", "dask_timeout": None, }, { "_start_function_mock": unittest.mock.Mock, - "expected_status_result": mlrun.api.schemas.BackgroundTaskState.succeeded, + "expected_status_result": mlrun.common.schemas.BackgroundTaskState.succeeded, "background_timeout_mode": "disabled", "dask_timeout": 0, }, ]: _start_function_mock = test_case.get("_start_function_mock", unittest.mock.Mock) expected_status_result = test_case.get( - "expected_status_result", mlrun.api.schemas.BackgroundTaskState.running + "expected_status_result", mlrun.common.schemas.BackgroundTaskState.running ) background_timeout_mode = test_case.get("background_timeout_mode", "enabled") dask_timeout = test_case.get("dask_timeout", None) @@ -507,16 +636,16 @@ def failing_func(): ), ) assert response.status_code == http.HTTPStatus.OK - background_task = mlrun.api.schemas.BackgroundTask(**response.json()) + background_task = mlrun.common.schemas.BackgroundTask(**response.json()) assert ( 
background_task.status.state - == mlrun.api.schemas.BackgroundTaskState.running + == mlrun.common.schemas.BackgroundTaskState.running ) response = client.get( f"projects/{project}/background-tasks/{background_task.metadata.name}" ) assert response.status_code == http.HTTPStatus.OK.value - background_task = mlrun.api.schemas.BackgroundTask(**response.json()) + background_task = mlrun.common.schemas.BackgroundTask(**response.json()) assert background_task.status.state == expected_status_result diff --git a/tests/api/api/test_grafana_proxy.py b/tests/api/api/test_grafana_proxy.py index 131bcc2d0452..81fa6323cec8 100644 --- a/tests/api/api/test_grafana_proxy.py +++ b/tests/api/api/test_grafana_proxy.py @@ -29,11 +29,13 @@ import mlrun import mlrun.api.crud -import mlrun.api.schemas import mlrun.api.utils.clients.iguazio -from mlrun.api.api.endpoints.grafana_proxy import ( - _parse_query_parameters, - _validate_query_parameters, +import mlrun.common.model_monitoring as model_monitoring_constants +import mlrun.common.schemas +import mlrun.model_monitoring.stores +from mlrun.api.crud.model_monitoring.grafana import ( + parse_query_parameters, + validate_query_parameters, ) from mlrun.config import config from mlrun.errors import MLRunBadRequestError @@ -60,7 +62,7 @@ def test_grafana_proxy_model_endpoints_check_connection( mlrun.api.utils.clients.iguazio.AsyncClient().verify_request_session = ( unittest.mock.AsyncMock( return_value=( - mlrun.api.schemas.AuthInfo( + mlrun.common.schemas.AuthInfo( username=None, session="some-session", data_session="some-session", @@ -82,17 +84,17 @@ def test_grafana_proxy_model_endpoints_check_connection( reason=_build_skip_message(), ) def test_grafana_list_endpoints(db: Session, client: TestClient): + endpoints_in = [_mock_random_endpoint("active") for _ in range(5)] # Initialize endpoint store target object - endpoint_target = ( - mlrun.api.crud.model_monitoring.model_endpoint_store._ModelEndpointKVStore( - project=TEST_PROJECT, 
access_key=_get_access_key() - ) + store_type_object = mlrun.model_monitoring.stores.ModelEndpointStoreType(value="kv") + endpoint_store = store_type_object.to_endpoint_store( + project=TEST_PROJECT, access_key=_get_access_key() ) for endpoint in endpoints_in: - endpoint_target.write_model_endpoint(endpoint) + endpoint_store.write_model_endpoint(endpoint.flat_dict()) response = client.post( url="grafana-proxy/model-endpoints/query", @@ -302,30 +304,30 @@ def test_grafana_overall_feature_analysis(db: Session, client: TestClient): def test_parse_query_parameters_failure(): # No 'targets' in body with pytest.raises(MLRunBadRequestError): - _parse_query_parameters({}) + parse_query_parameters({}) # No 'target' list in 'targets' dictionary with pytest.raises(MLRunBadRequestError): - _parse_query_parameters({"targets": []}) + parse_query_parameters({"targets": []}) # Target query not separated by equals ('=') char with pytest.raises(MLRunBadRequestError): - _parse_query_parameters({"targets": [{"target": "test"}]}) + parse_query_parameters({"targets": [{"target": "test"}]}) def test_parse_query_parameters_success(): # Target query separated by equals ('=') char - params = _parse_query_parameters({"targets": [{"target": "test=some_test"}]}) + params = parse_query_parameters({"targets": [{"target": "test=some_test"}]}) assert params["test"] == "some_test" # Target query separated by equals ('=') char (multiple queries) - params = _parse_query_parameters( + params = parse_query_parameters( {"targets": [{"target": "test=some_test;another_test=some_other_test"}]} ) assert params["test"] == "some_test" assert params["another_test"] == "some_other_test" - params = _parse_query_parameters( + params = parse_query_parameters( {"targets": [{"target": "test=some_test;another_test=some_other_test;"}]} ) assert params["test"] == "some_test" @@ -335,19 +337,17 @@ def test_parse_query_parameters_success(): def test_validate_query_parameters_failure(): # No 'target_endpoint' in query 
parameters with pytest.raises(MLRunBadRequestError): - _validate_query_parameters({}) + validate_query_parameters({}) # target_endpoint unsupported with pytest.raises(MLRunBadRequestError): - _validate_query_parameters( + validate_query_parameters( {"target_endpoint": "unsupported_endpoint"}, {"supported_endpoint"} ) def test_validate_query_parameters_success(): - _validate_query_parameters( - {"target_endpoint": "list_endpoints"}, {"list_endpoints"} - ) + validate_query_parameters({"target_endpoint": "list_endpoints"}, {"list_endpoints"}) def _get_access_key() -> Optional[str]: @@ -359,13 +359,13 @@ def cleanup_endpoints(db: Session, client: TestClient): if not _is_env_params_dont_exist(): kv_path = config.model_endpoint_monitoring.store_prefixes.default.format( project=TEST_PROJECT, - kind=mlrun.api.schemas.ModelMonitoringStoreKinds.ENDPOINTS, + kind=model_monitoring_constants.ModelMonitoringStoreKinds.ENDPOINTS, ) _, kv_container, kv_path = parse_model_endpoint_store_prefix(kv_path) tsdb_path = config.model_endpoint_monitoring.store_prefixes.default.format( project=TEST_PROJECT, - kind=mlrun.api.schemas.ModelMonitoringStoreKinds.EVENTS, + kind=model_monitoring_constants.ModelMonitoringStoreKinds.EVENTS, ) _, tsdb_container, tsdb_path = parse_model_endpoint_store_prefix(tsdb_path) @@ -414,7 +414,8 @@ def cleanup_endpoints(db: Session, client: TestClient): ) def test_grafana_incoming_features(db: Session, client: TestClient): path = config.model_endpoint_monitoring.store_prefixes.default.format( - project=TEST_PROJECT, kind=mlrun.api.schemas.ModelMonitoringStoreKinds.EVENTS + project=TEST_PROJECT, + kind=model_monitoring_constants.ModelMonitoringStoreKinds.EVENTS, ) _, container, path = parse_model_endpoint_store_prefix(path) @@ -432,14 +433,15 @@ def test_grafana_incoming_features(db: Session, client: TestClient): e.spec.feature_names = ["f0", "f1", "f2", "f3"] # Initialize endpoint store target object - endpoint_target = ( - 
mlrun.api.crud.model_monitoring.model_endpoint_store._ModelEndpointKVStore( - project=TEST_PROJECT, access_key=_get_access_key() - ) + store_type_object = mlrun.model_monitoring.stores.ModelEndpointStoreType( + value="v3io-nosql" + ) + endpoint_store = store_type_object.to_endpoint_store( + project=TEST_PROJECT, access_key=_get_access_key() ) for endpoint in endpoints: - endpoint_target.write_model_endpoint(endpoint) + endpoint_store.write_model_endpoint(endpoint.flat_dict()) total = 0 diff --git a/tests/api/api/test_healthz.py b/tests/api/api/test_healthz.py index f8dcc34ed0e2..c20dd11dae87 100644 --- a/tests/api/api/test_healthz.py +++ b/tests/api/api/test_healthz.py @@ -17,25 +17,19 @@ import fastapi.testclient import sqlalchemy.orm -import mlrun -import mlrun.api.crud -import mlrun.api.schemas -import mlrun.api.utils.clients.iguazio -import mlrun.errors -import mlrun.runtimes +import mlrun.common.schemas +import mlrun.config def test_health( db: sqlalchemy.orm.Session, client: fastapi.testclient.TestClient ) -> None: - overridden_ui_projects_prefix = "some-prefix" - mlrun.mlconf.ui.projects_prefix = overridden_ui_projects_prefix - nuclio_version = "x.x.x" - mlrun.mlconf.nuclio_version = nuclio_version + + # sanity response = client.get("healthz") assert response.status_code == http.HTTPStatus.OK.value - response_body = response.json() - for key in ["scrape_metrics", "hub_url"]: - assert response_body[key] is None - assert response_body["ui_projects_prefix"] == overridden_ui_projects_prefix - assert response_body["nuclio_version"] == nuclio_version + + # fail + mlrun.config.config.httpdb.state = mlrun.common.schemas.APIStates.offline + response = client.get("healthz") + assert response.status_code == http.HTTPStatus.SERVICE_UNAVAILABLE.value diff --git a/tests/api/api/test_model_endpoints.py b/tests/api/api/test_model_endpoints.py index 1b6ab99644f2..b771e28d0f39 100644 --- a/tests/api/api/test_model_endpoints.py +++ b/tests/api/api/test_model_endpoints.py @@ 
-14,52 +14,59 @@ # import os import string +import typing from random import choice, randint from typing import Optional +import deepdiff import pytest import mlrun.api.crud -import mlrun.api.schemas -from mlrun.api.schemas import ( - ModelEndpoint, - ModelEndpointMetadata, - ModelEndpointSpec, - ModelEndpointStatus, -) +import mlrun.common.schemas from mlrun.errors import MLRunBadRequestError, MLRunInvalidArgumentError +from mlrun.model_monitoring import ModelMonitoringStoreKinds +from mlrun.model_monitoring.stores import ( # noqa: F401 + ModelEndpointStore, + ModelEndpointStoreType, +) TEST_PROJECT = "test_model_endpoints" - +ENDPOINT_STORE_CONNECTION = "sqlite:///test.db" # Set a default v3io access key env variable V3IO_ACCESS_KEY = "1111-2222-3333-4444" os.environ["V3IO_ACCESS_KEY"] = V3IO_ACCESS_KEY +# Bound a typing variable for ModelEndpointStore +KVmodelType = typing.TypeVar("KVmodelType", bound="ModelEndpointStore") + def test_build_kv_cursor_filter_expression(): """Validate that the filter expression format converter for the KV cursor works as expected.""" # Initialize endpoint store target object - endpoint_target = ( - mlrun.api.crud.model_monitoring.model_endpoint_store._ModelEndpointKVStore( - project=TEST_PROJECT, access_key=V3IO_ACCESS_KEY - ) + store_type_object = mlrun.model_monitoring.stores.ModelEndpointStoreType( + value="v3io-nosql" + ) + + endpoint_store: KVmodelType = store_type_object.to_endpoint_store( + project=TEST_PROJECT, access_key=V3IO_ACCESS_KEY ) + with pytest.raises(MLRunInvalidArgumentError): - endpoint_target.build_kv_cursor_filter_expression("") + endpoint_store._build_kv_cursor_filter_expression("") - filter_expression = endpoint_target.build_kv_cursor_filter_expression( + filter_expression = endpoint_store._build_kv_cursor_filter_expression( project=TEST_PROJECT ) assert filter_expression == f"project=='{TEST_PROJECT}'" - filter_expression = endpoint_target.build_kv_cursor_filter_expression( + filter_expression = 
endpoint_store._build_kv_cursor_filter_expression( project=TEST_PROJECT, function="test_function", model="test_model" ) expected = f"project=='{TEST_PROJECT}' AND function=='test_function' AND model=='test_model'" assert filter_expression == expected - filter_expression = endpoint_target.build_kv_cursor_filter_expression( + filter_expression = endpoint_store._build_kv_cursor_filter_expression( project=TEST_PROJECT, labels=["lbl1", "lbl2"] ) assert ( @@ -67,7 +74,7 @@ def test_build_kv_cursor_filter_expression(): == f"project=='{TEST_PROJECT}' AND exists(_lbl1) AND exists(_lbl2)" ) - filter_expression = endpoint_target.build_kv_cursor_filter_expression( + filter_expression = endpoint_store._build_kv_cursor_filter_expression( project=TEST_PROJECT, labels=["lbl1=1", "lbl2=2"] ) assert ( @@ -77,12 +84,12 @@ def test_build_kv_cursor_filter_expression(): def test_get_access_key(): key = mlrun.api.crud.ModelEndpoints().get_access_key( - mlrun.api.schemas.AuthInfo(data_session="asd") + mlrun.common.schemas.AuthInfo(data_session="asd") ) assert key == "asd" with pytest.raises(MLRunBadRequestError): - mlrun.api.crud.ModelEndpoints().get_access_key(mlrun.api.schemas.AuthInfo()) + mlrun.api.crud.ModelEndpoints().get_access_key(mlrun.common.schemas.AuthInfo()) def test_get_endpoint_features_function(): @@ -222,14 +229,9 @@ def test_get_endpoint_features_function(): } feature_names = list(stats.keys()) - # Initialize endpoint store target object - endpoint_target = ( - mlrun.api.crud.model_monitoring.model_endpoint_store._ModelEndpointKVStore( - project=TEST_PROJECT, access_key=V3IO_ACCESS_KEY - ) + features = mlrun.api.crud.ModelEndpoints.get_endpoint_features( + feature_names, stats, stats ) - - features = endpoint_target.get_endpoint_features(feature_names, stats, stats) assert len(features) == 4 # Commented out asserts should be re-enabled once buckets/counts length mismatch bug is fixed for feature in features: @@ -242,7 +244,9 @@ def test_get_endpoint_features_function(): 
assert feature.actual.histogram is not None # assert len(feature.actual.histogram.buckets) == len(feature.actual.histogram.counts) - features = endpoint_target.get_endpoint_features(feature_names, stats, None) + features = mlrun.api.crud.ModelEndpoints.get_endpoint_features( + feature_names, stats, None + ) assert len(features) == 4 for feature in features: assert feature.expected is not None @@ -251,7 +255,9 @@ def test_get_endpoint_features_function(): assert feature.expected.histogram is not None # assert len(feature.expected.histogram.buckets) == len(feature.expected.histogram.counts) - features = endpoint_target.get_endpoint_features(feature_names, None, stats) + features = mlrun.api.crud.ModelEndpoints.get_endpoint_features( + feature_names, None, stats + ) assert len(features) == 4 for feature in features: assert feature.expected is None @@ -260,28 +266,31 @@ def test_get_endpoint_features_function(): assert feature.actual.histogram is not None # assert len(feature.actual.histogram.buckets) == len(feature.actual.histogram.counts) - features = endpoint_target.get_endpoint_features(feature_names[1:], None, stats) + features = mlrun.api.crud.ModelEndpoints.get_endpoint_features( + feature_names[1:], None, stats + ) assert len(features) == 3 def test_generating_tsdb_paths(): - """Validate that the TSDB paths for the _ModelEndpointKVStore object are created as expected. These paths are + """Validate that the TSDB paths for the KVModelEndpointStore object are created as expected. 
These paths are usually important when the user call the delete project API and as a result the TSDB resources should be deleted""" # Initialize endpoint store target object - endpoint_target = ( - mlrun.api.crud.model_monitoring.model_endpoint_store._ModelEndpointKVStore( - project=TEST_PROJECT, access_key=V3IO_ACCESS_KEY - ) + store_type_object = mlrun.model_monitoring.stores.ModelEndpointStoreType( + value="v3io-nosql" + ) + endpoint_store: KVmodelType = store_type_object.to_endpoint_store( + project=TEST_PROJECT, access_key=V3IO_ACCESS_KEY ) # Generating the required tsdb paths - tsdb_path, filtered_path = endpoint_target._generate_tsdb_paths() + tsdb_path, filtered_path = endpoint_store._generate_tsdb_paths() # Validate the expected results based on the full path to the TSDB events directory full_path = mlrun.mlconf.model_endpoint_monitoring.store_prefixes.default.format( project=TEST_PROJECT, - kind=mlrun.api.schemas.ModelMonitoringStoreKinds.EVENTS, + kind=ModelMonitoringStoreKinds.EVENTS, ) # TSDB short path that should point to the main directory @@ -291,20 +300,153 @@ def test_generating_tsdb_paths(): assert filtered_path == full_path[-len(filtered_path) + 1 :] + "/" -def _get_auth_info() -> mlrun.api.schemas.AuthInfo: - return mlrun.api.schemas.AuthInfo(data_session=os.environ.get("V3IO_ACCESS_KEY")) +def _get_auth_info() -> mlrun.common.schemas.AuthInfo: + return mlrun.common.schemas.AuthInfo(data_session=os.environ.get("V3IO_ACCESS_KEY")) -def _mock_random_endpoint(state: Optional[str] = None) -> ModelEndpoint: +def _mock_random_endpoint( + state: Optional[str] = None, +) -> mlrun.common.schemas.ModelEndpoint: def random_labels(): return {f"{choice(string.ascii_letters)}": randint(0, 100) for _ in range(1, 5)} - return ModelEndpoint( - metadata=ModelEndpointMetadata(project=TEST_PROJECT, labels=random_labels()), - spec=ModelEndpointSpec( + return mlrun.common.schemas.ModelEndpoint( + metadata=mlrun.common.schemas.ModelEndpointMetadata( + 
project=TEST_PROJECT, labels=random_labels(), uid=str(randint(1000, 5000)) + ), + spec=mlrun.common.schemas.ModelEndpointSpec( function_uri=f"test/function_{randint(0, 100)}:v{randint(0, 100)}", model=f"model_{randint(0, 100)}:v{randint(0, 100)}", model_class="classifier", ), - status=ModelEndpointStatus(state=state), + status=mlrun.common.schemas.ModelEndpointStatus(state=state), + ) + + +def test_sql_target_list_model_endpoints(): + """Testing list model endpoint using SQLModelEndpointStore object. In the following test + we create two model endpoints and list these endpoints. In addition, this test validates the + filter optional operation within the list model endpoints API. At the end of this test, we validate + that the model endpoints are deleted from the DB. + """ + + # Generate model endpoint target + store_type_object = mlrun.model_monitoring.stores.ModelEndpointStoreType( + value="sql" + ) + endpoint_store = store_type_object.to_endpoint_store( + project=TEST_PROJECT, endpoint_store_connection=ENDPOINT_STORE_CONNECTION + ) + + # First, validate that there are no model endpoints records at the moment + list_of_endpoints = endpoint_store.list_model_endpoints() + endpoint_store.delete_model_endpoints_resources(endpoints=list_of_endpoints) + + list_of_endpoints = endpoint_store.list_model_endpoints() + assert len(list_of_endpoints) == 0 + + # Generate and write the 1st model endpoint into the DB table + mock_endpoint_1 = _mock_random_endpoint() + endpoint_store.write_model_endpoint(endpoint=mock_endpoint_1.flat_dict()) + + # Validate that there is a single model endpoint + list_of_endpoints = endpoint_store.list_model_endpoints() + assert len(list_of_endpoints) == 1 + + # Generate and write the 2nd model endpoint into the DB table + mock_endpoint_2 = _mock_random_endpoint() + mock_endpoint_2.spec.model = "test_model" + mock_endpoint_2.metadata.uid = "12345" + endpoint_store.write_model_endpoint(endpoint=mock_endpoint_2.flat_dict()) + + # Validate that there 
are exactly two model endpoints within the DB + list_of_endpoints = endpoint_store.list_model_endpoints() + assert len(list_of_endpoints) == 2 + + # List only the model endpoint that has the model test_model + filtered_list_of_endpoints = endpoint_store.list_model_endpoints(model="test_model") + assert len(filtered_list_of_endpoints) == 1 + + # Clean model endpoints from DB + endpoint_store.delete_model_endpoints_resources(endpoints=list_of_endpoints) + list_of_endpoints = endpoint_store.list_model_endpoints() + assert (len(list_of_endpoints)) == 0 + + +def test_sql_target_patch_endpoint(): + """Testing the update of a model endpoint using SQLModelEndpointStore object. In the following + test we update attributes within the model endpoint spec and status and then validate that there + attributes were actually updated. + """ + + # Generate model endpoint target + store_type_object = mlrun.model_monitoring.stores.ModelEndpointStoreType( + value="sql" + ) + endpoint_store = store_type_object.to_endpoint_store( + project=TEST_PROJECT, endpoint_store_connection=ENDPOINT_STORE_CONNECTION ) + + # First, validate that there are no model endpoints records at the moment + list_of_endpoints = endpoint_store.list_model_endpoints() + if len(list_of_endpoints) > 0: + # Delete old model endpoints records + endpoint_store.delete_model_endpoints_resources(endpoints=list_of_endpoints) + list_of_endpoints = endpoint_store.list_model_endpoints() + assert len(list_of_endpoints) == 0 + + # Generate and write the model endpoint into the DB table + mock_endpoint = _mock_random_endpoint() + mock_endpoint.metadata.uid = "1234" + endpoint_store.write_model_endpoint(mock_endpoint.flat_dict()) + + # Generate dictionary of attributes and update the model endpoint + updated_attributes = {"model": "test_model", "error_count": 2} + endpoint_store.update_model_endpoint( + endpoint_id=mock_endpoint.metadata.uid, attributes=updated_attributes + ) + + # Validate that these attributes were actually 
updated + endpoint = endpoint_store.get_model_endpoint(endpoint_id=mock_endpoint.metadata.uid) + + # Convert to model endpoint object + endpoint = mlrun.api.crud.ModelEndpoints()._convert_into_model_endpoint_object( + endpoint=endpoint + ) + assert endpoint.spec.model == "test_model" + assert endpoint.status.error_count == 2 + + # Clear model endpoint from DB + endpoint_store.delete_model_endpoint(endpoint_id=mock_endpoint.metadata.uid) + + # Drop model endpoints test table from DB + list_of_endpoints = endpoint_store.list_model_endpoints() + endpoint_store.delete_model_endpoints_resources(endpoints=list_of_endpoints) + + +def test_validate_model_endpoints_schema(): + # Validate that both model endpoint basemodel schema and model endpoint ModelObj schema have similar keys + model_endpoint_basemodel = mlrun.common.schemas.ModelEndpoint() + model_endpoint_modelobj = mlrun.model_monitoring.ModelEndpoint() + + # Compare status + base_model_status = model_endpoint_basemodel.status.__dict__ + model_object_status = model_endpoint_modelobj.status.__dict__ + assert ( + deepdiff.DeepDiff( + base_model_status, + model_object_status, + ignore_order=True, + ) + ) == {} + + # Compare spec + base_model_status = model_endpoint_basemodel.status.__dict__ + model_object_status = model_endpoint_modelobj.status.__dict__ + assert ( + deepdiff.DeepDiff( + base_model_status, + model_object_status, + ignore_order=True, + ) + ) == {} diff --git a/tests/api/api/test_operations.py b/tests/api/api/test_operations.py index b6db68e8ff53..16b6c01b1bb2 100644 --- a/tests/api/api/test_operations.py +++ b/tests/api/api/test_operations.py @@ -24,10 +24,10 @@ import mlrun.api.api.endpoints.operations import mlrun.api.crud import mlrun.api.initial_data -import mlrun.api.schemas import mlrun.api.utils.background_tasks import mlrun.api.utils.clients.iguazio import mlrun.api.utils.singletons.scheduler +import mlrun.common.schemas import mlrun.errors import mlrun.runtimes from mlrun.utils import logger @@ 
-49,10 +49,10 @@ def test_migrations_already_in_progress( "InternalBackgroundTasksHandler", lambda *args, **kwargs: handler_mock, ) - mlrun.mlconf.httpdb.state = mlrun.api.schemas.APIStates.migrations_in_progress + mlrun.mlconf.httpdb.state = mlrun.common.schemas.APIStates.migrations_in_progress response = client.post("operations/migrations") assert response.status_code == http.HTTPStatus.ACCEPTED.value - background_task = mlrun.api.schemas.BackgroundTask(**response.json()) + background_task = mlrun.common.schemas.BackgroundTask(**response.json()) assert background_task_name == background_task.metadata.name mlrun.api.api.endpoints.operations.current_migration_background_task_name = None @@ -60,7 +60,7 @@ def test_migrations_already_in_progress( def test_migrations_failed( db: sqlalchemy.orm.Session, client: fastapi.testclient.TestClient ) -> None: - mlrun.mlconf.httpdb.state = mlrun.api.schemas.APIStates.migrations_failed + mlrun.mlconf.httpdb.state = mlrun.common.schemas.APIStates.migrations_failed response = client.post("operations/migrations") assert response.status_code == http.HTTPStatus.PRECONDITION_FAILED.value assert "Migrations were already triggered and failed" in response.text @@ -69,19 +69,19 @@ def test_migrations_failed( def test_migrations_not_needed( db: sqlalchemy.orm.Session, client: fastapi.testclient.TestClient ) -> None: - mlrun.mlconf.httpdb.state = mlrun.api.schemas.APIStates.online + mlrun.mlconf.httpdb.state = mlrun.common.schemas.APIStates.online response = client.post("operations/migrations") assert response.status_code == http.HTTPStatus.OK.value def _mock_migration_process(*args, **kwargs): logger.info("Mocking migration process") - mlrun.mlconf.httpdb.state = mlrun.api.schemas.APIStates.migrations_completed + mlrun.mlconf.httpdb.state = mlrun.common.schemas.APIStates.migrations_completed @pytest.fixture def _mock_waiting_for_migration(): - mlrun.mlconf.httpdb.state = mlrun.api.schemas.APIStates.waiting_for_migrations + 
mlrun.mlconf.httpdb.state = mlrun.common.schemas.APIStates.waiting_for_migrations def test_migrations_success( @@ -103,15 +103,18 @@ def test_migrations_success( # trigger migrations response = client.post("operations/migrations") assert response.status_code == http.HTTPStatus.ACCEPTED.value - background_task = mlrun.api.schemas.BackgroundTask(**response.json()) - assert background_task.status.state == mlrun.api.schemas.BackgroundTaskState.running + background_task = mlrun.common.schemas.BackgroundTask(**response.json()) + assert ( + background_task.status.state == mlrun.common.schemas.BackgroundTaskState.running + ) response = client.get(f"background-tasks/{background_task.metadata.name}") assert response.status_code == http.HTTPStatus.OK.value - background_task = mlrun.api.schemas.BackgroundTask(**response.json()) + background_task = mlrun.common.schemas.BackgroundTask(**response.json()) assert ( - background_task.status.state == mlrun.api.schemas.BackgroundTaskState.succeeded + background_task.status.state + == mlrun.common.schemas.BackgroundTaskState.succeeded ) - assert mlrun.mlconf.httpdb.state == mlrun.api.schemas.APIStates.online + assert mlrun.mlconf.httpdb.state == mlrun.common.schemas.APIStates.online # now we should be able to get projects response = client.get("projects") assert response.status_code == http.HTTPStatus.OK.value @@ -124,15 +127,15 @@ def test_migrations_success( def _generate_background_task_schema( background_task_name, -) -> mlrun.api.schemas.BackgroundTask: - return mlrun.api.schemas.BackgroundTask( - metadata=mlrun.api.schemas.BackgroundTaskMetadata( +) -> mlrun.common.schemas.BackgroundTask: + return mlrun.common.schemas.BackgroundTask( + metadata=mlrun.common.schemas.BackgroundTaskMetadata( name=background_task_name, created=datetime.utcnow(), updated=datetime.utcnow(), ), - status=mlrun.api.schemas.BackgroundTaskStatus( - state=mlrun.api.schemas.BackgroundTaskState.running + status=mlrun.common.schemas.BackgroundTaskStatus( + 
state=mlrun.common.schemas.BackgroundTaskState.running ), - spec=mlrun.api.schemas.BackgroundTaskSpec(), + spec=mlrun.common.schemas.BackgroundTaskSpec(), ) diff --git a/tests/api/api/test_pipelines.py b/tests/api/api/test_pipelines.py index 7263f8ecd1ad..1b634ffe49a2 100644 --- a/tests/api/api/test_pipelines.py +++ b/tests/api/api/test_pipelines.py @@ -24,8 +24,8 @@ import sqlalchemy.orm import mlrun.api.crud -import mlrun.api.schemas import mlrun.api.utils.singletons.k8s +import mlrun.common.schemas import tests.conftest @@ -33,7 +33,7 @@ def test_list_pipelines_not_exploding_on_no_k8s( db: sqlalchemy.orm.Session, client: fastapi.testclient.TestClient ) -> None: response = client.get("projects/*/pipelines") - expected_response = mlrun.api.schemas.PipelinesOutput( + expected_response = mlrun.common.schemas.PipelinesOutput( runs=[], total_size=0, next_page_token=None ) _assert_list_pipelines_response(expected_response, response) @@ -47,7 +47,7 @@ def test_list_pipelines_empty_list( runs = [] _mock_list_runs(kfp_client_mock, runs) response = client.get("projects/*/pipelines") - expected_response = mlrun.api.schemas.PipelinesOutput( + expected_response = mlrun.common.schemas.PipelinesOutput( runs=runs, total_size=len(runs), next_page_token=None ) _assert_list_pipelines_response(expected_response, response) @@ -59,9 +59,9 @@ def test_list_pipelines_formats( kfp_client_mock: kfp.Client, ) -> None: for format_ in [ - mlrun.api.schemas.PipelinesFormat.full, - mlrun.api.schemas.PipelinesFormat.metadata_only, - mlrun.api.schemas.PipelinesFormat.name_only, + mlrun.common.schemas.PipelinesFormat.full, + mlrun.common.schemas.PipelinesFormat.metadata_only, + mlrun.common.schemas.PipelinesFormat.name_only, ]: runs = _generate_list_runs_mocks() expected_runs = [run.to_dict() for run in runs] @@ -73,7 +73,7 @@ def test_list_pipelines_formats( "projects/*/pipelines", params={"format": format_}, ) - expected_response = mlrun.api.schemas.PipelinesOutput( + expected_response = 
mlrun.common.schemas.PipelinesOutput( runs=expected_runs, total_size=len(runs), next_page_token=None ) _assert_list_pipelines_response(expected_response, response) @@ -85,10 +85,10 @@ def test_get_pipeline_formats( kfp_client_mock: kfp.Client, ) -> None: for format_ in [ - mlrun.api.schemas.PipelinesFormat.full, - mlrun.api.schemas.PipelinesFormat.metadata_only, - mlrun.api.schemas.PipelinesFormat.summary, - mlrun.api.schemas.PipelinesFormat.name_only, + mlrun.common.schemas.PipelinesFormat.full, + mlrun.common.schemas.PipelinesFormat.metadata_only, + mlrun.common.schemas.PipelinesFormat.summary, + mlrun.common.schemas.PipelinesFormat.name_only, ]: api_run_detail = _generate_get_run_mock() _mock_get_run(kfp_client_mock, api_run_detail) @@ -107,7 +107,7 @@ def test_get_pipeline_no_project_opa_validation( client: fastapi.testclient.TestClient, kfp_client_mock: kfp.Client, ) -> None: - format_ = (mlrun.api.schemas.PipelinesFormat.summary,) + format_ = (mlrun.common.schemas.PipelinesFormat.summary,) project = "project-name" mlrun.api.crud.Pipelines().resolve_project_from_pipeline = unittest.mock.Mock( return_value=project @@ -138,10 +138,10 @@ def test_get_pipeline_specific_project( kfp_client_mock: kfp.Client, ) -> None: for format_ in [ - mlrun.api.schemas.PipelinesFormat.full, - mlrun.api.schemas.PipelinesFormat.metadata_only, - mlrun.api.schemas.PipelinesFormat.summary, - mlrun.api.schemas.PipelinesFormat.name_only, + mlrun.common.schemas.PipelinesFormat.full, + mlrun.common.schemas.PipelinesFormat.metadata_only, + mlrun.common.schemas.PipelinesFormat.summary, + mlrun.common.schemas.PipelinesFormat.name_only, ]: project = "project-name" api_run_detail = _generate_get_run_mock() @@ -176,9 +176,9 @@ def test_list_pipelines_specific_project( ) response = client.get( f"projects/{project}/pipelines", - params={"format": mlrun.api.schemas.PipelinesFormat.name_only}, + params={"format": mlrun.common.schemas.PipelinesFormat.name_only}, ) - expected_response = 
mlrun.api.schemas.PipelinesOutput( + expected_response = mlrun.common.schemas.PipelinesOutput( runs=expected_runs, total_size=len(expected_runs), next_page_token=None ) _assert_list_pipelines_response(expected_response, response) @@ -231,6 +231,22 @@ def _generate_get_run_mock() -> kfp_server_api.models.api_run_detail.ApiRunDetai ) +def test_get_pipeline_nonexistent_project( + db: sqlalchemy.orm.Session, + client: fastapi.testclient.TestClient, + kfp_client_mock: kfp.Client, +) -> None: + format_ = (mlrun.common.schemas.PipelinesFormat.summary,) + project = "n0_pr0ject" + api_run_detail = _generate_get_run_mock() + _mock_get_run(kfp_client_mock, api_run_detail) + response = client.get( + f"projects/{project}/pipelines/{api_run_detail.run.id}", + params={"format": format_}, + ) + assert response.status_code == http.HTTPStatus.NOT_FOUND.value + + def _generate_list_runs_mocks(): workflow_manifest = _generate_workflow_manifest() return [ @@ -419,7 +435,7 @@ def _mock_list_runs_with_one_run_per_page(kfp_client_mock: kfp.Client, runs): def list_runs_mock(*args, page_token=None, page_size=None, **kwargs): assert expected_page_tokens.pop(0) == page_token - assert mlrun.api.schemas.PipelinesPagination.max_page_size == page_size + assert mlrun.common.schemas.PipelinesPagination.max_page_size == page_size return kfp_server_api.models.api_list_runs_response.ApiListRunsResponse( [runs.pop(0)], 1, next_page_token=expected_page_tokens[0] ) @@ -431,7 +447,7 @@ def _mock_list_runs( kfp_client_mock: kfp.Client, runs, expected_page_token="", - expected_page_size=mlrun.api.schemas.PipelinesPagination.default_page_size, + expected_page_size=mlrun.common.schemas.PipelinesPagination.default_page_size, expected_sort_by="", expected_filter="", ): @@ -460,7 +476,7 @@ def get_run_mock(*args, **kwargs): def _assert_list_pipelines_response( - expected_response: mlrun.api.schemas.PipelinesOutput, response + expected_response: mlrun.common.schemas.PipelinesOutput, response ): assert 
response.status_code == http.HTTPStatus.OK.value assert ( diff --git a/tests/api/api/test_projects.py b/tests/api/api/test_projects.py index f3960ce218fd..b4df4a9f104a 100644 --- a/tests/api/api/test_projects.py +++ b/tests/api/api/test_projects.py @@ -33,7 +33,6 @@ import mlrun.api.api.utils import mlrun.api.crud import mlrun.api.main -import mlrun.api.schemas import mlrun.api.utils.background_tasks import mlrun.api.utils.clients.log_collector import mlrun.api.utils.singletons.db @@ -43,6 +42,7 @@ import mlrun.api.utils.singletons.scheduler import mlrun.artifacts.dataset import mlrun.artifacts.model +import mlrun.common.schemas import mlrun.errors import tests.api.conftest import tests.api.utils.clients.test_log_collector @@ -86,7 +86,7 @@ def test_redirection_from_worker_to_chief_delete_project( mlrun.mlconf.httpdb.clusterization.role = "worker" project = "test-project" endpoint = f"{ORIGINAL_VERSIONED_API_PREFIX}/projects/{project}" - for strategy in mlrun.api.schemas.DeletionStrategy: + for strategy in mlrun.common.schemas.DeletionStrategy: headers = {"x-mlrun-deletion-strategy": strategy.value} for test_case in [ # deleting schedule failed for unknown reason @@ -134,8 +134,8 @@ def test_create_project_failure_already_exists( db: Session, client: TestClient, project_member_mode: str ) -> None: name1 = f"prj-{uuid4().hex}" - project_1 = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=name1), + project_1 = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=name1), ) # create @@ -192,7 +192,7 @@ def test_delete_project_with_resources( response = client.delete( f"projects/{project_to_remove}", headers={ - mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.check.value + mlrun.common.schemas.HeaderNames.deletion_strategy: mlrun.common.schemas.DeletionStrategy.check.value }, ) assert response.status_code == HTTPStatus.PRECONDITION_FAILED.value @@ -201,7 +201,7 @@ def 
test_delete_project_with_resources( response = client.delete( f"projects/{project_to_remove}", headers={ - mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.restricted.value + mlrun.common.schemas.HeaderNames.deletion_strategy: mlrun.common.schemas.DeletionStrategy.restricted.value }, ) assert response.status_code == HTTPStatus.PRECONDITION_FAILED.value @@ -210,7 +210,7 @@ def test_delete_project_with_resources( response = client.delete( f"projects/{project_to_remove}", headers={ - mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.cascading.value + mlrun.common.schemas.HeaderNames.deletion_strategy: mlrun.common.schemas.DeletionStrategy.cascading.value }, ) assert response.status_code == HTTPStatus.NO_CONTENT.value @@ -249,7 +249,7 @@ def test_delete_project_with_resources( response = client.delete( f"projects/{project_to_remove}", headers={ - mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.check.value + mlrun.common.schemas.HeaderNames.deletion_strategy: mlrun.common.schemas.DeletionStrategy.check.value }, ) assert response.status_code == HTTPStatus.NO_CONTENT.value @@ -258,7 +258,7 @@ def test_delete_project_with_resources( response = client.delete( f"projects/{project_to_remove}", headers={ - mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.restricted.value + mlrun.common.schemas.HeaderNames.deletion_strategy: mlrun.common.schemas.DeletionStrategy.restricted.value }, ) assert response.status_code == HTTPStatus.NO_CONTENT.value @@ -269,16 +269,16 @@ def test_list_and_get_project_summaries( ) -> None: # create empty project empty_project_name = "empty-project" - empty_project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=empty_project_name), + empty_project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=empty_project_name), ) response = 
client.post("projects", json=empty_project.dict()) assert response.status_code == HTTPStatus.CREATED.value # create project with resources project_name = "project-with-resources" - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), ) response = client.post("projects", json=project.dict()) assert response.status_code == HTTPStatus.CREATED.value @@ -359,7 +359,7 @@ def test_list_and_get_project_summaries( # list project summaries response = client.get("project-summaries") - project_summaries_output = mlrun.api.schemas.ProjectSummariesOutput( + project_summaries_output = mlrun.common.schemas.ProjectSummariesOutput( **response.json() ) for index, project_summary in enumerate(project_summaries_output.project_summaries): @@ -381,7 +381,7 @@ def test_list_and_get_project_summaries( # get project summary response = client.get(f"project-summaries/{project_name}") - project_summary = mlrun.api.schemas.ProjectSummary(**response.json()) + project_summary = mlrun.common.schemas.ProjectSummary(**response.json()) _assert_project_summary( project_summary, files_count, @@ -402,8 +402,8 @@ def test_list_project_summaries_different_installation_modes( """ # create empty project empty_project_name = "empty-project" - empty_project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=empty_project_name), + empty_project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=empty_project_name), ) response = client.post("projects", json=empty_project.dict()) assert response.status_code == HTTPStatus.CREATED.value @@ -418,7 +418,7 @@ def test_list_project_summaries_different_installation_modes( response = client.get("project-summaries") assert response.status_code == HTTPStatus.OK.value - project_summaries_output = mlrun.api.schemas.ProjectSummariesOutput( + 
project_summaries_output = mlrun.common.schemas.ProjectSummariesOutput( **response.json() ) _assert_project_summary( @@ -440,7 +440,7 @@ def test_list_project_summaries_different_installation_modes( response = client.get("project-summaries") assert response.status_code == HTTPStatus.OK.value - project_summaries_output = mlrun.api.schemas.ProjectSummariesOutput( + project_summaries_output = mlrun.common.schemas.ProjectSummariesOutput( **response.json() ) _assert_project_summary( @@ -462,7 +462,7 @@ def test_list_project_summaries_different_installation_modes( response = client.get("project-summaries") assert response.status_code == HTTPStatus.OK.value - project_summaries_output = mlrun.api.schemas.ProjectSummariesOutput( + project_summaries_output = mlrun.common.schemas.ProjectSummariesOutput( **response.json() ) _assert_project_summary( @@ -484,7 +484,7 @@ def test_list_project_summaries_different_installation_modes( response = client.get("project-summaries") assert response.status_code == HTTPStatus.OK.value - project_summaries_output = mlrun.api.schemas.ProjectSummariesOutput( + project_summaries_output = mlrun.common.schemas.ProjectSummariesOutput( **response.json() ) _assert_project_summary( @@ -506,9 +506,9 @@ def test_delete_project_deletion_strategy_check( project_member_mode: str, k8s_secrets_mock: tests.api.conftest.K8sSecretsMock, ) -> None: - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name="project-name"), - spec=mlrun.api.schemas.ProjectSpec(), + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name="project-name"), + spec=mlrun.common.schemas.ProjectSpec(), ) # create @@ -520,7 +520,7 @@ def test_delete_project_deletion_strategy_check( response = client.delete( f"projects/{project.metadata.name}", headers={ - mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.check.value + mlrun.common.schemas.HeaderNames.deletion_strategy: 
mlrun.common.schemas.DeletionStrategy.check.value }, ) assert response.status_code == HTTPStatus.NO_CONTENT.value @@ -542,7 +542,7 @@ def test_delete_project_deletion_strategy_check( response = client.delete( f"projects/{project.metadata.name}", headers={ - mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.check.value + mlrun.common.schemas.HeaderNames.deletion_strategy: mlrun.common.schemas.DeletionStrategy.check.value }, ) assert response.status_code == HTTPStatus.PRECONDITION_FAILED.value @@ -608,7 +608,7 @@ def test_delete_project_not_deleting_versioned_objects_multiple_times( response = client.delete( f"projects/{project_name}", headers={ - mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.cascading.value + mlrun.common.schemas.HeaderNames.deletion_strategy: mlrun.common.schemas.DeletionStrategy.cascading.value }, ) assert response.status_code == HTTPStatus.NO_CONTENT.value @@ -635,9 +635,9 @@ def test_delete_project_deletion_strategy_check_external_resource( k8s_secrets_mock: tests.api.conftest.K8sSecretsMock, ) -> None: mlrun.mlconf.namespace = "test-namespace" - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name="project-name"), - spec=mlrun.api.schemas.ProjectSpec(), + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name="project-name"), + spec=mlrun.common.schemas.ProjectSpec(), ) # create @@ -652,7 +652,7 @@ def test_delete_project_deletion_strategy_check_external_resource( response = client.delete( f"projects/{project.metadata.name}", headers={ - mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.restricted.value + mlrun.common.schemas.HeaderNames.deletion_strategy: mlrun.common.schemas.DeletionStrategy.restricted.value }, ) assert response.status_code == HTTPStatus.PRECONDITION_FAILED.value @@ -662,7 +662,7 @@ def test_delete_project_deletion_strategy_check_external_resource( 
response = client.delete( f"projects/{project.metadata.name}", headers={ - mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.restricted.value + mlrun.common.schemas.HeaderNames.deletion_strategy: mlrun.common.schemas.DeletionStrategy.restricted.value }, ) assert response @@ -674,14 +674,16 @@ def test_delete_project_with_stop_logs( project_member_mode: str, k8s_secrets_mock: tests.api.conftest.K8sSecretsMock, ): - mlrun.config.config.log_collector.mode = mlrun.api.schemas.LogsCollectorMode.sidecar + mlrun.config.config.log_collector.mode = ( + mlrun.common.schemas.LogsCollectorMode.sidecar + ) project_name = "project-name" mlrun.mlconf.namespace = "test-namespace" - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), - spec=mlrun.api.schemas.ProjectSpec(), + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), + spec=mlrun.common.schemas.ProjectSpec(), ) # create @@ -720,8 +722,8 @@ def test_list_projects_leader_format( project_names = [] for _ in range(5): project_name = f"prj-{uuid4().hex}" - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), ) mlrun.api.utils.singletons.db.get_db().create_project(db, project) project_names.append(project_name) @@ -729,9 +731,9 @@ def test_list_projects_leader_format( # list in leader format response = client.get( "projects", - params={"format": mlrun.api.schemas.ProjectsFormat.leader}, + params={"format": mlrun.common.schemas.ProjectsFormat.leader}, headers={ - mlrun.api.schemas.HeaderNames.projects_role: mlrun.mlconf.httpdb.projects.leader + mlrun.common.schemas.HeaderNames.projects_role: mlrun.mlconf.httpdb.projects.leader }, ) returned_project_names = [ @@ -758,9 +760,9 @@ def test_projects_crud( 
k8s_secrets_mock.set_is_running_in_k8s_cluster(False) name1 = f"prj-{uuid4().hex}" - project_1 = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=name1), - spec=mlrun.api.schemas.ProjectSpec( + project_1 = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=name1), + spec=mlrun.common.schemas.ProjectSpec( description="banana", source="source", goals="some goals" ), ) @@ -778,7 +780,7 @@ def test_projects_crud( project_patch = { "spec": { "description": "lemon", - "desired_state": mlrun.api.schemas.ProjectState.archived, + "desired_state": mlrun.common.schemas.ProjectState.archived, } } response = client.patch(f"projects/{name1}", json=project_patch) @@ -797,9 +799,9 @@ def test_projects_crud( name2 = f"prj-{uuid4().hex}" labels_2 = {"key": "value"} - project_2 = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=name2, labels=labels_2), - spec=mlrun.api.schemas.ProjectSpec(description="banana2", source="source2"), + project_2 = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=name2, labels=labels_2), + spec=mlrun.common.schemas.ProjectSpec(description="banana2", source="source2"), ) # store @@ -824,9 +826,9 @@ def test_projects_crud( # list - full response = client.get( - "projects", params={"format": mlrun.api.schemas.ProjectsFormat.full} + "projects", params={"format": mlrun.common.schemas.ProjectsFormat.full} ) - projects_output = mlrun.api.schemas.ProjectsOutput(**response.json()) + projects_output = mlrun.common.schemas.ProjectsOutput(**response.json()) expected = [project_1, project_2] for project in projects_output.projects: for _project in expected: @@ -874,7 +876,7 @@ def test_projects_crud( # list - names only - filter by state _list_project_names_and_assert( - client, [name1], params={"state": mlrun.api.schemas.ProjectState.archived} + client, [name1], params={"state": mlrun.common.schemas.ProjectState.archived} ) # add function to 
project 1 @@ -887,7 +889,7 @@ def test_projects_crud( response = client.delete( f"projects/{name1}", headers={ - mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.restricted.value + mlrun.common.schemas.HeaderNames.deletion_strategy: mlrun.common.schemas.DeletionStrategy.restricted.value }, ) assert response.status_code == HTTPStatus.PRECONDITION_FAILED.value @@ -896,7 +898,7 @@ def test_projects_crud( response = client.delete( f"projects/{name1}", headers={ - mlrun.api.schemas.HeaderNames.deletion_strategy: mlrun.api.schemas.DeletionStrategy.cascading.value + mlrun.common.schemas.HeaderNames.deletion_strategy: mlrun.common.schemas.DeletionStrategy.cascading.value }, ) assert response.status_code == HTTPStatus.NO_CONTENT.value @@ -916,11 +918,11 @@ def _create_resources_of_all_kinds( ): db = mlrun.api.utils.singletons.db.get_db() # add labels to project - project_schema = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata( + project_schema = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata( name=project, labels={"key": "value"} ), - spec=mlrun.api.schemas.ProjectSpec(description="some desc"), + spec=mlrun.common.schemas.ProjectSpec(description="some desc"), ) mlrun.api.utils.singletons.project_member.get_project_member().store_project( db_session, project, project_schema @@ -1014,15 +1016,15 @@ def _create_resources_of_all_kinds( "bla": "blabla", "status": {"bla": "blabla"}, } - schedule_cron_trigger = mlrun.api.schemas.ScheduleCronTrigger(year=1999) + schedule_cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(year=1999) schedule_names = ["schedule_name_1", "schedule_name_2", "schedule_name_3"] for schedule_name in schedule_names: mlrun.api.utils.singletons.scheduler.get_scheduler().create_schedule( db_session, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project, schedule_name, - mlrun.api.schemas.ScheduleKinds.job, + 
mlrun.common.schemas.ScheduleKinds.job, schedule, schedule_cron_trigger, labels, @@ -1032,18 +1034,18 @@ def _create_resources_of_all_kinds( labels = { "owner": "nobody", } - feature_set = mlrun.api.schemas.FeatureSet( - metadata=mlrun.api.schemas.ObjectMetadata( + feature_set = mlrun.common.schemas.FeatureSet( + metadata=mlrun.common.schemas.ObjectMetadata( name="dummy", tag="latest", labels=labels ), - spec=mlrun.api.schemas.FeatureSetSpec( + spec=mlrun.common.schemas.FeatureSetSpec( entities=[ - mlrun.api.schemas.Entity( + mlrun.common.schemas.Entity( name="ent1", value_type="str", labels={"label": "1"} ) ], features=[ - mlrun.api.schemas.Feature( + mlrun.common.schemas.Feature( name="feat1", value_type="str", labels={"label": "1"} ) ], @@ -1061,12 +1063,12 @@ def _create_resources_of_all_kinds( feature_set.spec.index = index db.store_feature_set(db_session, project, feature_set_name, feature_set) - feature_vector = mlrun.api.schemas.FeatureVector( - metadata=mlrun.api.schemas.ObjectMetadata( + feature_vector = mlrun.common.schemas.FeatureVector( + metadata=mlrun.common.schemas.ObjectMetadata( name="dummy", tag="latest", labels=labels ), - spec=mlrun.api.schemas.ObjectSpec(), - status=mlrun.api.schemas.ObjectStatus(state="created"), + spec=mlrun.common.schemas.ObjectSpec(), + status=mlrun.common.schemas.ObjectStatus(state="created"), ) feature_vector_names = ["feature_vector_1", "feature_vector_2", "feature_vector_3"] feature_vector_tags = ["some_tag", "some_tag2", "some_tag3"] @@ -1087,7 +1089,7 @@ def _create_resources_of_all_kinds( db_session, name="task", project=project, - state=mlrun.api.schemas.BackgroundTaskState.running, + state=mlrun.common.schemas.BackgroundTaskState.running, ) @@ -1163,7 +1165,7 @@ def _assert_db_resources_in_project( for cls in _classes: # User support is not really implemented or in use # Run tags support is not really implemented or in use - # Marketplace sources is not a project-level table, and hence is not relevant here. 
+ # Hub sources is not a project-level table, and hence is not relevant here. # Version is not a project-level table, and hence is not relevant here. # Features and Entities are not directly linked to project since they are sub-entity of feature-sets # Logs are saved as files, the DB table is not really in use @@ -1171,7 +1173,7 @@ def _assert_db_resources_in_project( if ( cls.__name__ == "User" or cls.__tablename__ == "runs_tags" - or cls.__tablename__ == "marketplace_sources" + or cls.__tablename__ == "hub_sources" or cls.__tablename__ == "data_versions" or cls.__name__ == "Feature" or cls.__name__ == "Entity" @@ -1280,7 +1282,7 @@ def _list_project_names_and_assert( client: TestClient, expected_names: typing.List[str], params: typing.Dict = None ): params = params or {} - params["format"] = mlrun.api.schemas.ProjectsFormat.name_only + params["format"] = mlrun.common.schemas.ProjectsFormat.name_only # list - names only - filter by state response = client.get( "projects", @@ -1297,14 +1299,14 @@ def _list_project_names_and_assert( def _assert_project_response( - expected_project: mlrun.api.schemas.Project, response, extra_exclude: dict = None + expected_project: mlrun.common.schemas.Project, response, extra_exclude: dict = None ): - project = mlrun.api.schemas.Project(**response.json()) + project = mlrun.common.schemas.Project(**response.json()) _assert_project(expected_project, project, extra_exclude) def _assert_project_summary( - project_summary: mlrun.api.schemas.ProjectSummary, + project_summary: mlrun.common.schemas.ProjectSummary, files_count: int, feature_sets_count: int, models_count: int, @@ -1323,8 +1325,8 @@ def _assert_project_summary( def _assert_project( - expected_project: mlrun.api.schemas.Project, - project: mlrun.api.schemas.Project, + expected_project: mlrun.common.schemas.Project, + project: mlrun.common.schemas.Project, extra_exclude: dict = None, ): exclude = {"id": ..., "metadata": {"created"}, "status": {"state"}} @@ -1422,11 +1424,11 @@ 
def _create_runs( def _create_schedules(client: TestClient, project_name, schedules_count): for index in range(schedules_count): schedule_name = f"schedule-name-{str(uuid4())}" - schedule = mlrun.api.schemas.ScheduleInput( + schedule = mlrun.common.schemas.ScheduleInput( name=schedule_name, - kind=mlrun.api.schemas.ScheduleKinds.job, + kind=mlrun.common.schemas.ScheduleKinds.job, scheduled_object={"metadata": {"name": "something"}}, - cron_trigger=mlrun.api.schemas.ScheduleCronTrigger(year=1999), + cron_trigger=mlrun.common.schemas.ScheduleCronTrigger(year=1999), ) response = client.post( f"projects/{project_name}/schedules", json=schedule.dict() diff --git a/tests/api/api/test_runs.py b/tests/api/api/test_runs.py index 36ea17e6d8d9..368662061afe 100644 --- a/tests/api/api/test_runs.py +++ b/tests/api/api/test_runs.py @@ -22,8 +22,8 @@ from sqlalchemy.orm import Session import mlrun.api.crud -import mlrun.api.schemas import mlrun.api.utils.auth.verifier +import mlrun.common.schemas import mlrun.errors import mlrun.runtimes.constants from mlrun.api.db.sqldb.models import Run @@ -262,9 +262,9 @@ def test_list_runs_partition_by(db: Session, client: TestClient) -> None: client, { "project": projects[0], - "partition-by": mlrun.api.schemas.RunPartitionByField.name, - "partition-sort-by": mlrun.api.schemas.SortField.created, - "partition-order": mlrun.api.schemas.OrderType.asc, + "partition-by": mlrun.common.schemas.RunPartitionByField.name, + "partition-sort-by": mlrun.common.schemas.SortField.created, + "partition-order": mlrun.common.schemas.OrderType.asc, }, 3, ) @@ -277,9 +277,9 @@ def test_list_runs_partition_by(db: Session, client: TestClient) -> None: client, { "project": projects[0], - "partition-by": mlrun.api.schemas.RunPartitionByField.name, - "partition-sort-by": mlrun.api.schemas.SortField.updated, - "partition-order": mlrun.api.schemas.OrderType.desc, + "partition-by": mlrun.common.schemas.RunPartitionByField.name, + "partition-sort-by": 
mlrun.common.schemas.SortField.updated, + "partition-order": mlrun.common.schemas.OrderType.desc, }, 3, ) @@ -292,9 +292,9 @@ def test_list_runs_partition_by(db: Session, client: TestClient) -> None: client, { "project": projects[0], - "partition-by": mlrun.api.schemas.RunPartitionByField.name, - "partition-sort-by": mlrun.api.schemas.SortField.updated, - "partition-order": mlrun.api.schemas.OrderType.desc, + "partition-by": mlrun.common.schemas.RunPartitionByField.name, + "partition-sort-by": mlrun.common.schemas.SortField.updated, + "partition-order": mlrun.common.schemas.OrderType.desc, "rows-per-partition": 5, }, 15, @@ -305,9 +305,9 @@ def test_list_runs_partition_by(db: Session, client: TestClient) -> None: client, { "project": projects[0], - "partition-by": mlrun.api.schemas.RunPartitionByField.name, - "partition-sort-by": mlrun.api.schemas.SortField.updated, - "partition-order": mlrun.api.schemas.OrderType.desc, + "partition-by": mlrun.common.schemas.RunPartitionByField.name, + "partition-sort-by": mlrun.common.schemas.SortField.updated, + "partition-order": mlrun.common.schemas.OrderType.desc, "rows-per-partition": 5, "max-partitions": 2, }, @@ -323,9 +323,9 @@ def test_list_runs_partition_by(db: Session, client: TestClient) -> None: { "project": projects[0], "iter": False, - "partition-by": mlrun.api.schemas.RunPartitionByField.name, - "partition-sort-by": mlrun.api.schemas.SortField.updated, - "partition-order": mlrun.api.schemas.OrderType.desc, + "partition-by": mlrun.common.schemas.RunPartitionByField.name, + "partition-sort-by": mlrun.common.schemas.SortField.updated, + "partition-order": mlrun.common.schemas.OrderType.desc, "rows-per-partition": 2, "max-partitions": 1, }, diff --git a/tests/api/api/test_runtime_resources.py b/tests/api/api/test_runtime_resources.py index 8e0fabc255df..ed88e9827fee 100644 --- a/tests/api/api/test_runtime_resources.py +++ b/tests/api/api/test_runtime_resources.py @@ -22,8 +22,8 @@ import 
mlrun.api.api.endpoints.runtime_resources import mlrun.api.crud -import mlrun.api.schemas import mlrun.api.utils.singletons.k8s +import mlrun.common.schemas def test_list_runtimes_resources_opa_filtering( @@ -83,7 +83,7 @@ def test_list_runtimes_resources_group_by_job( ) response = client.get( "projects/*/runtime-resources", - params={"group-by": mlrun.api.schemas.ListRuntimeResourcesGroupByField.job}, + params={"group-by": mlrun.common.schemas.ListRuntimeResourcesGroupByField.job}, ) body = response.json() expected_body = { @@ -140,9 +140,9 @@ def test_list_runtimes_resources_no_group_by( ) body = response.json() expected_body = [ - mlrun.api.schemas.KindRuntimeResources( + mlrun.common.schemas.KindRuntimeResources( kind=mlrun.runtimes.RuntimeKinds.job, - resources=mlrun.api.schemas.RuntimeResources( + resources=mlrun.common.schemas.RuntimeResources( crd_resources=[], pod_resources=grouped_by_project_runtime_resources_output[project_1][ mlrun.runtimes.RuntimeKinds.job @@ -152,9 +152,9 @@ def test_list_runtimes_resources_no_group_by( ].pod_resources, ), ).dict(), - mlrun.api.schemas.KindRuntimeResources( + mlrun.common.schemas.KindRuntimeResources( kind=mlrun.runtimes.RuntimeKinds.dask, - resources=mlrun.api.schemas.RuntimeResources( + resources=mlrun.common.schemas.RuntimeResources( crd_resources=[], pod_resources=grouped_by_project_runtime_resources_output[project_2][ mlrun.runtimes.RuntimeKinds.dask @@ -164,9 +164,9 @@ def test_list_runtimes_resources_no_group_by( ][mlrun.runtimes.RuntimeKinds.dask].service_resources, ), ).dict(), - mlrun.api.schemas.KindRuntimeResources( + mlrun.common.schemas.KindRuntimeResources( kind=mlrun.runtimes.RuntimeKinds.mpijob, - resources=mlrun.api.schemas.RuntimeResources( + resources=mlrun.common.schemas.RuntimeResources( crd_resources=grouped_by_project_runtime_resources_output[project_3][ mlrun.runtimes.RuntimeKinds.mpijob ].crd_resources, @@ -201,13 +201,15 @@ def test_list_runtime_resources_no_resources( assert body == [] 
response = client.get( "projects/*/runtime-resources", - params={"group-by": mlrun.api.schemas.ListRuntimeResourcesGroupByField.job}, + params={"group-by": mlrun.common.schemas.ListRuntimeResourcesGroupByField.job}, ) body = response.json() assert body == {} response = client.get( "projects/*/runtime-resources", - params={"group-by": mlrun.api.schemas.ListRuntimeResourcesGroupByField.project}, + params={ + "group-by": mlrun.common.schemas.ListRuntimeResourcesGroupByField.project + }, ) body = response.json() assert body == {} @@ -251,9 +253,9 @@ def test_list_runtime_resources_filter_by_kind( params={"kind": mlrun.runtimes.RuntimeKinds.job}, ) body = response.json() - expected_runtime_resources = mlrun.api.schemas.KindRuntimeResources( + expected_runtime_resources = mlrun.common.schemas.KindRuntimeResources( kind=mlrun.runtimes.RuntimeKinds.job, - resources=mlrun.api.schemas.RuntimeResources( + resources=mlrun.common.schemas.RuntimeResources( crd_resources=[], pod_resources=grouped_by_project_runtime_resources_output[project_1][ mlrun.runtimes.RuntimeKinds.job @@ -523,9 +525,9 @@ def _generate_grouped_by_project_runtime_resources_with_legacy_builder_output(): no_project_builder_name = "builder-name" grouped_by_project_runtime_resources_output = { project_1: { - mlrun.runtimes.RuntimeKinds.job: mlrun.api.schemas.RuntimeResources( + mlrun.runtimes.RuntimeKinds.job: mlrun.common.schemas.RuntimeResources( pod_resources=[ - mlrun.api.schemas.RuntimeResource( + mlrun.common.schemas.RuntimeResource( name=project_1_job_name, labels={ "mlrun/project": project_1, @@ -539,9 +541,9 @@ def _generate_grouped_by_project_runtime_resources_with_legacy_builder_output(): ) }, no_project: { - mlrun.runtimes.RuntimeKinds.job: mlrun.api.schemas.RuntimeResources( + mlrun.runtimes.RuntimeKinds.job: mlrun.common.schemas.RuntimeResources( pod_resources=[ - mlrun.api.schemas.RuntimeResource( + mlrun.common.schemas.RuntimeResource( name=no_project_builder_name, labels={ "mlrun/class": 
"build", @@ -571,9 +573,9 @@ def _generate_grouped_by_project_runtime_resources_output(): project_3_mpijob_name = "project-3-mpijob-name" grouped_by_project_runtime_resources_output = { project_1: { - mlrun.runtimes.RuntimeKinds.job: mlrun.api.schemas.RuntimeResources( + mlrun.runtimes.RuntimeKinds.job: mlrun.common.schemas.RuntimeResources( pod_resources=[ - mlrun.api.schemas.RuntimeResource( + mlrun.common.schemas.RuntimeResource( name=project_1_job_name, labels={ "mlrun/project": project_1, @@ -587,9 +589,9 @@ def _generate_grouped_by_project_runtime_resources_output(): ) }, project_2: { - mlrun.runtimes.RuntimeKinds.dask: mlrun.api.schemas.RuntimeResources( + mlrun.runtimes.RuntimeKinds.dask: mlrun.common.schemas.RuntimeResources( pod_resources=[ - mlrun.api.schemas.RuntimeResource( + mlrun.common.schemas.RuntimeResource( name=project_2_dask_name, labels={ "mlrun/project": project_2, @@ -601,7 +603,7 @@ def _generate_grouped_by_project_runtime_resources_output(): ], crd_resources=[], service_resources=[ - mlrun.api.schemas.RuntimeResource( + mlrun.common.schemas.RuntimeResource( name=project_2_dask_name, labels={ "mlrun/project": project_2, @@ -612,9 +614,9 @@ def _generate_grouped_by_project_runtime_resources_output(): ) ], ), - mlrun.runtimes.RuntimeKinds.job: mlrun.api.schemas.RuntimeResources( + mlrun.runtimes.RuntimeKinds.job: mlrun.common.schemas.RuntimeResources( pod_resources=[ - mlrun.api.schemas.RuntimeResource( + mlrun.common.schemas.RuntimeResource( name=project_2_job_name, labels={ "mlrun/project": project_2, @@ -628,10 +630,10 @@ def _generate_grouped_by_project_runtime_resources_output(): ), }, project_3: { - mlrun.runtimes.RuntimeKinds.mpijob: mlrun.api.schemas.RuntimeResources( + mlrun.runtimes.RuntimeKinds.mpijob: mlrun.common.schemas.RuntimeResources( pod_resources=[], crd_resources=[ - mlrun.api.schemas.RuntimeResource( + mlrun.common.schemas.RuntimeResource( name=project_3_mpijob_name, labels={ "mlrun/project": project_3, @@ -658,7 +660,7 
@@ def _generate_grouped_by_project_runtime_resources_output(): def _mock_opa_filter_and_assert_list_response( client: fastapi.testclient.TestClient, - grouped_by_project_runtime_resources_output: mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, + grouped_by_project_runtime_resources_output: mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, opa_filter_response, ): mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions = unittest.mock.AsyncMock( @@ -666,7 +668,9 @@ def _mock_opa_filter_and_assert_list_response( ) response = client.get( "projects/*/runtime-resources", - params={"group-by": mlrun.api.schemas.ListRuntimeResourcesGroupByField.project}, + params={ + "group-by": mlrun.common.schemas.ListRuntimeResourcesGroupByField.project + }, ) body = response.json() expected_body = ( @@ -687,7 +691,7 @@ def _mock_opa_filter_and_assert_list_response( def _filter_allowed_projects_and_kind_from_grouped_by_project_runtime_resources_output( allowed_projects: typing.List[str], filter_kind: str, - grouped_by_project_runtime_resources_output: mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, + grouped_by_project_runtime_resources_output: mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, structured: bool = False, ): filtered_output = ( @@ -702,7 +706,7 @@ def _filter_allowed_projects_and_kind_from_grouped_by_project_runtime_resources_ def _filter_kind_from_grouped_by_project_runtime_resources_output( filter_kind: str, - grouped_by_project_runtime_resources_output: mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, + grouped_by_project_runtime_resources_output: mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, ): filtered_output = {} for ( @@ -719,7 +723,7 @@ def _filter_kind_from_grouped_by_project_runtime_resources_output( def _filter_allowed_projects_from_grouped_by_project_runtime_resources_output( allowed_projects: typing.List[str], - grouped_by_project_runtime_resources_output: 
mlrun.api.schemas.GroupedByProjectRuntimeResourcesOutput, + grouped_by_project_runtime_resources_output: mlrun.common.schemas.GroupedByProjectRuntimeResourcesOutput, structured: bool = False, ): filtered_output = {} diff --git a/tests/api/api/test_schedules.py b/tests/api/api/test_schedules.py index 4b0461cb5a80..e402880fd887 100644 --- a/tests/api/api/test_schedules.py +++ b/tests/api/api/test_schedules.py @@ -24,8 +24,8 @@ import mlrun.api.utils.auth.verifier import mlrun.api.utils.singletons.project_member import mlrun.api.utils.singletons.scheduler +import mlrun.common.schemas import tests.api.api.utils -from mlrun.api import schemas from mlrun.api.utils.singletons.db import get_db from tests.common_fixtures import aioresponses_mock @@ -46,14 +46,14 @@ def test_list_schedules( labels_1 = { "label1": "value1", } - cron_trigger = schemas.ScheduleCronTrigger(year="1999") + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(year="1999") schedule_name = "schedule-name" project = mlrun.mlconf.default_project get_db().create_schedule( db, project, schedule_name, - schemas.ScheduleKinds.local_function, + mlrun.common.schemas.ScheduleKinds.local_function, do_nothing, cron_trigger, mlrun.mlconf.httpdb.scheduling.default_concurrency_limit, @@ -68,7 +68,7 @@ def test_list_schedules( db, project, schedule_name_2, - schemas.ScheduleKinds.local_function, + mlrun.common.schemas.ScheduleKinds.local_function, do_nothing, cron_trigger, mlrun.mlconf.httpdb.scheduling.default_concurrency_limit, diff --git a/tests/api/api/test_secrets.py b/tests/api/api/test_secrets.py index 2f1fc7e3c71a..eb332ab50e71 100644 --- a/tests/api/api/test_secrets.py +++ b/tests/api/api/test_secrets.py @@ -19,8 +19,8 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session +import mlrun.common.schemas from mlrun import mlconf -from mlrun.api import schemas # Set a valid Vault token to run this test. 
# For this test, you must also have a k8s cluster available (minikube is good enough). @@ -45,8 +45,11 @@ def test_vault_create_project_secrets(db: Session, client: TestClient): response = client.post(f"projects/{project_name}/secrets", json=data) assert response.status_code == HTTPStatus.CREATED.value - params = {"provider": schemas.SecretProviderName.vault.value, "secrets": None} - headers = {schemas.HeaderNames.secret_store_token: user_token} + params = { + "provider": mlrun.common.schemas.SecretProviderName.vault.value, + "secrets": None, + } + headers = {mlrun.common.schemas.HeaderNames.secret_store_token: user_token} response = client.get( f"projects/{project_name}/secrets", headers=headers, params=params diff --git a/tests/api/api/test_submit.py b/tests/api/api/test_submit.py index d50bf04df208..9579d9fe7393 100644 --- a/tests/api/api/test_submit.py +++ b/tests/api/api/test_submit.py @@ -31,9 +31,9 @@ import mlrun.api.utils.auth.verifier import mlrun.api.utils.clients.chief import mlrun.api.utils.clients.iguazio +import mlrun.api.utils.singletons.k8s import tests.api.api.utils -from mlrun.api.schemas import AuthInfo -from mlrun.api.utils.singletons.k8s import get_k8s +from mlrun.common.schemas import AuthInfo from mlrun.config import config as mlconf from tests.api.conftest import K8sSecretsMock @@ -65,27 +65,28 @@ def test_submit_job_failure_function_not_found(db: Session, client: TestClient) @pytest.fixture() def pod_create_mock(): - create_pod_orig_function = get_k8s().create_pod + create_pod_orig_function = ( + mlrun.api.utils.singletons.k8s.get_k8s_helper().create_pod + ) _get_project_secrets_raw_data_orig_function = ( - get_k8s()._get_project_secrets_raw_data + mlrun.api.utils.singletons.k8s.get_k8s_helper()._get_project_secrets_raw_data + ) + mlrun.api.utils.singletons.k8s.get_k8s_helper().create_pod = unittest.mock.Mock( + return_value=("pod-name", "namespace") + ) + mlrun.api.utils.singletons.k8s.get_k8s_helper()._get_project_secrets_raw_data = ( + 
unittest.mock.Mock(return_value={}) ) - get_k8s().create_pod = unittest.mock.Mock(return_value=("pod-name", "namespace")) - get_k8s()._get_project_secrets_raw_data = unittest.mock.Mock(return_value={}) update_run_state_orig_function = ( mlrun.runtimes.kubejob.KubejobRuntime._update_run_state ) - mlrun.runtimes.kubejob.KubejobRuntime._update_run_state = unittest.mock.Mock() + mlrun.runtimes.kubejob.KubejobRuntime._update_run_state = unittest.mock.MagicMock() mock_run_object = mlrun.RunObject() mock_run_object.metadata.uid = "1234567890" mock_run_object.metadata.project = "project-name" - wrap_run_result_orig_function = mlrun.runtimes.base.BaseRuntime._wrap_run_result - mlrun.runtimes.base.BaseRuntime._wrap_run_result = unittest.mock.Mock( - return_value=mock_run_object - ) - auth_info_mock = AuthInfo( username=username, session="some-session", data_session=access_key ) @@ -97,17 +98,18 @@ def pod_create_mock(): unittest.mock.AsyncMock(return_value=auth_info_mock) ) - yield get_k8s().create_pod + yield mlrun.api.utils.singletons.k8s.get_k8s_helper().create_pod # Have to revert the mocks, otherwise other tests are failing - get_k8s().create_pod = create_pod_orig_function - get_k8s()._get_project_secrets_raw_data = ( + mlrun.api.utils.singletons.k8s.get_k8s_helper().create_pod = ( + create_pod_orig_function + ) + mlrun.api.utils.singletons.k8s.get_k8s_helper()._get_project_secrets_raw_data = ( _get_project_secrets_raw_data_orig_function ) mlrun.runtimes.kubejob.KubejobRuntime._update_run_state = ( update_run_state_orig_function ) - mlrun.runtimes.base.BaseRuntime._wrap_run_result = wrap_run_result_orig_function mlrun.api.utils.auth.verifier.AuthVerifier().authenticate_request = ( authenticate_request_orig_function ) @@ -144,7 +146,7 @@ def test_submit_job_auto_mount( "V3IO_USERNAME": username, "V3IO_ACCESS_KEY": ( secret_name, - mlrun.api.schemas.AuthSecretData.get_field_secret_key("access_key"), + mlrun.common.schemas.AuthSecretData.get_field_secret_key("access_key"), 
), } _assert_pod_env_vars(pod_create_mock, expected_env_vars) @@ -174,12 +176,40 @@ def test_submit_job_ensure_function_has_auth_set( expected_env_vars = { mlrun.runtimes.constants.FunctionEnvironmentVariables.auth_session: ( secret_name, - mlrun.api.schemas.AuthSecretData.get_field_secret_key("access_key"), + mlrun.common.schemas.AuthSecretData.get_field_secret_key("access_key"), ), } _assert_pod_env_vars(pod_create_mock, expected_env_vars) +def test_submit_schedule_job_from_hub_from_ui( + db: Session, client: TestClient, pod_create_mock, k8s_secrets_mock +) -> None: + project = "my-proj1" + hub_function_uri = "hub://aggregate" + + tests.api.api.utils.create_project(client, project) + + function = mlrun.import_function(hub_function_uri) + submit_job_body = _create_submit_job_body(function, project) + + # replicate UI behavior + submit_job_body["task"]["spec"]["function"] = hub_function_uri + submit_job_body["schedule"] = "*/15 * * * *" + + resp = client.post("submit_job", json=submit_job_body) + assert resp.status_code == http.HTTPStatus.OK.value + + resp = client.get(f"projects/{project}/schedules") + assert resp.status_code == http.HTTPStatus.OK.value + + schedules = resp.json().get("schedules", []) + assert len(schedules) == 1 + + schedule = schedules[0] + assert schedule["scheduled_object"]["task"]["spec"]["function"] != hub_function_uri + + def test_submit_job_with_output_path_enrichment( db: Session, client: TestClient, pod_create_mock, k8s_secrets_mock ) -> None: @@ -337,7 +367,7 @@ def test_submit_job_with_hyper_params_file( ) async def auth_info_mock(*args, **kwargs): - return mlrun.api.schemas.AuthInfo(username="user", data_session=access_key) + return mlrun.common.schemas.AuthInfo(username="user", data_session=access_key) # Create test-specific mocks monkeypatch.setattr( @@ -526,7 +556,7 @@ def _create_submit_job_body(function, project, with_output_path=True): def _create_submit_job_body_with_schedule(function, project): job_body = 
_create_submit_job_body(function, project) - job_body["schedule"] = mlrun.api.schemas.ScheduleCronTrigger(year=1999).dict() + job_body["schedule"] = mlrun.common.schemas.ScheduleCronTrigger(year=1999).dict() return job_body diff --git a/tests/api/api/test_tags.py b/tests/api/api/test_tags.py index b9bcff56eef6..e89a93b415b3 100644 --- a/tests/api/api/test_tags.py +++ b/tests/api/api/test_tags.py @@ -21,7 +21,7 @@ import fastapi.testclient import sqlalchemy.orm -import mlrun.api.schemas +import mlrun.common.schemas API_PROJECTS_PATH = "projects" API_ARTIFACTS_PATH = "projects/{project}/artifacts" @@ -52,7 +52,7 @@ def test_overwrite_artifact_tags_by_key_identifier( client=client, tag=overwrite_tag, identifiers=[ - mlrun.api.schemas.ArtifactIdentifier(key=artifact1_key), + mlrun.common.schemas.ArtifactIdentifier(key=artifact1_key), ], ) assert response.status_code == http.HTTPStatus.OK.value @@ -86,7 +86,7 @@ def test_overwrite_artifact_tags_by_uid_identifier( client=client, tag=overwrite_tag, identifiers=[ - mlrun.api.schemas.ArtifactIdentifier(uid=artifact1_uid), + mlrun.common.schemas.ArtifactIdentifier(uid=artifact1_uid), ], ) assert response.status_code == http.HTTPStatus.OK.value @@ -120,8 +120,8 @@ def test_overwrite_artifact_tags_by_multiple_uid_identifiers( client=client, tag=overwrite_tag, identifiers=[ - mlrun.api.schemas.ArtifactIdentifier(uid=artifact1_uid), - mlrun.api.schemas.ArtifactIdentifier(uid=artifact2_uid), + mlrun.common.schemas.ArtifactIdentifier(uid=artifact1_uid), + mlrun.common.schemas.ArtifactIdentifier(uid=artifact2_uid), ], ) assert response.status_code == http.HTTPStatus.OK.value @@ -153,8 +153,8 @@ def test_overwrite_artifact_tags_by_multiple_key_identifiers( client=client, tag=overwrite_tag, identifiers=[ - mlrun.api.schemas.ArtifactIdentifier(key=artifact1_key), - mlrun.api.schemas.ArtifactIdentifier(key=artifact2_key), + mlrun.common.schemas.ArtifactIdentifier(key=artifact1_key), + 
mlrun.common.schemas.ArtifactIdentifier(key=artifact2_key), ], ) assert response.status_code == http.HTTPStatus.OK.value @@ -186,7 +186,7 @@ def test_append_artifact_tags_by_key_identifier( client=client, tag=new_tag, identifiers=[ - mlrun.api.schemas.ArtifactIdentifier(key=artifact1_key), + mlrun.common.schemas.ArtifactIdentifier(key=artifact1_key), ], ) assert response.status_code == http.HTTPStatus.OK.value @@ -223,7 +223,7 @@ def test_append_artifact_tags_by_uid_identifier_latest( client=client, tag=new_tag, identifiers=[ - mlrun.api.schemas.ArtifactIdentifier( + mlrun.common.schemas.ArtifactIdentifier( key=artifact1_key, uid=artifact1_uid ), ], @@ -269,7 +269,7 @@ def test_create_and_append_artifact_tags_with_invalid_characters( client=client, tag=invalid_tag_name, identifiers=[ - mlrun.api.schemas.ArtifactIdentifier( + mlrun.common.schemas.ArtifactIdentifier( key=artifact1_key, uid=artifact1_uid ), ], @@ -299,7 +299,7 @@ def test_overwrite_artifact_tags_with_invalid_characters( client=client, tag=invalid_tag_name, identifiers=[ - mlrun.api.schemas.ArtifactIdentifier( + mlrun.common.schemas.ArtifactIdentifier( key=artifact_key, uid=artifact_uid ), ], @@ -339,7 +339,7 @@ def test_delete_artifact_tags_with_invalid_characters( client=client, tag=invalid_tag_name, identifiers=[ - mlrun.api.schemas.ArtifactIdentifier( + mlrun.common.schemas.ArtifactIdentifier( key=artifact_key, uid=artifact_uid ), ], @@ -369,7 +369,7 @@ def test_append_artifact_tags_by_uid_identifier( client=client, tag=new_tag, identifiers=[ - mlrun.api.schemas.ArtifactIdentifier(uid=artifact1_uid), + mlrun.common.schemas.ArtifactIdentifier(uid=artifact1_uid), ], ) assert response.status_code == http.HTTPStatus.OK.value @@ -404,8 +404,8 @@ def test_append_artifact_tags_by_multiple_key_identifiers( client=client, tag=new_tag, identifiers=[ - mlrun.api.schemas.ArtifactIdentifier(key=artifact1_key), - mlrun.api.schemas.ArtifactIdentifier(key=artifact2_key), + 
mlrun.common.schemas.ArtifactIdentifier(key=artifact1_key), + mlrun.common.schemas.ArtifactIdentifier(key=artifact2_key), ], ) assert response.status_code == http.HTTPStatus.OK.value @@ -440,7 +440,7 @@ def test_append_artifact_existing_tag( client=client, tag=tag, identifiers=[ - mlrun.api.schemas.ArtifactIdentifier(key=artifact1_key), + mlrun.common.schemas.ArtifactIdentifier(key=artifact1_key), ], ) assert response.status_code == http.HTTPStatus.OK.value @@ -470,7 +470,7 @@ def test_delete_artifact_tag_by_key_identifier( client=client, tag=tag, identifiers=[ - mlrun.api.schemas.ArtifactIdentifier(key=artifact1_key), + mlrun.common.schemas.ArtifactIdentifier(key=artifact1_key), ], ) assert response.status_code == http.HTTPStatus.NO_CONTENT.value @@ -497,7 +497,7 @@ def test_delete_artifact_tag_by_uid_identifier( client=client, tag=tag, identifiers=[ - mlrun.api.schemas.ArtifactIdentifier(uid=artifact1_uid), + mlrun.common.schemas.ArtifactIdentifier(uid=artifact1_uid), ], ) assert response.status_code == http.HTTPStatus.NO_CONTENT.value @@ -525,8 +525,8 @@ def test_delete_artifact_tag_by_multiple_key_identifiers( client=client, tag=tag, identifiers=[ - mlrun.api.schemas.ArtifactIdentifier(key=artifact1_key), - mlrun.api.schemas.ArtifactIdentifier(key=artifact2_key), + mlrun.common.schemas.ArtifactIdentifier(key=artifact1_key), + mlrun.common.schemas.ArtifactIdentifier(key=artifact2_key), ], ) assert response.status_code == http.HTTPStatus.NO_CONTENT.value @@ -553,8 +553,8 @@ def test_delete_artifact_tag_but_artifact_has_no_tag( client=client, tag=tag, identifiers=[ - mlrun.api.schemas.ArtifactIdentifier(key=artifact1_key), - mlrun.api.schemas.ArtifactIdentifier(key=artifact2_key), + mlrun.common.schemas.ArtifactIdentifier(key=artifact1_key), + mlrun.common.schemas.ArtifactIdentifier(key=artifact2_key), ], ) assert response.status_code == http.HTTPStatus.NO_CONTENT.value @@ -566,7 +566,7 @@ def _delete_artifact_tag( client, tag: str, identifiers: typing.List[ - 
typing.Union[typing.Dict, mlrun.api.schemas.ArtifactIdentifier] + typing.Union[typing.Dict, mlrun.common.schemas.ArtifactIdentifier] ], project: str = None, ): @@ -583,7 +583,7 @@ def _append_artifact_tag( client, tag: str, identifiers: typing.List[ - typing.Union[typing.Dict, mlrun.api.schemas.ArtifactIdentifier] + typing.Union[typing.Dict, mlrun.common.schemas.ArtifactIdentifier] ], project: str = None, ): @@ -597,7 +597,7 @@ def _overwrite_artifact_tags( client, tag: str, identifiers: typing.List[ - typing.Union[typing.Dict, mlrun.api.schemas.ArtifactIdentifier] + typing.Union[typing.Dict, mlrun.common.schemas.ArtifactIdentifier] ], project: str = None, ): @@ -609,7 +609,7 @@ def _overwrite_artifact_tags( @staticmethod def _generate_tag_identifiers_json( identifiers: typing.List[ - typing.Union[typing.Dict, mlrun.api.schemas.ArtifactIdentifier] + typing.Union[typing.Dict, mlrun.common.schemas.ArtifactIdentifier] ], ): return { @@ -617,7 +617,7 @@ def _generate_tag_identifiers_json( "identifiers": [ ( identifier.dict() - if isinstance(identifier, mlrun.api.schemas.ArtifactIdentifier) + if isinstance(identifier, mlrun.common.schemas.ArtifactIdentifier) else identifier ) for identifier in identifiers @@ -649,11 +649,11 @@ def _assert_tag(artifacts, expected_tag): def _create_project( self, client: fastapi.testclient.TestClient, project_name: str = None ): - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata( + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata( name=project_name or self.project ), - spec=mlrun.api.schemas.ProjectSpec( + spec=mlrun.common.schemas.ProjectSpec( description="banana", source="source", goals="some goals" ), ) diff --git a/tests/api/api/test_utils.py b/tests/api/api/test_utils.py index 00aa22c98c29..847df6b2636c 100644 --- a/tests/api/api/test_utils.py +++ b/tests/api/api/test_utils.py @@ -25,9 +25,9 @@ import mlrun import mlrun.api.crud -import mlrun.api.schemas import 
mlrun.api.utils.auth.verifier import mlrun.api.utils.clients.iguazio +import mlrun.common.schemas import mlrun.k8s_utils import mlrun.runtimes.pod import tests.api.api.utils @@ -40,7 +40,7 @@ ensure_function_security_context, get_scheduler, ) -from mlrun.api.schemas import SecurityContextEnrichmentModes +from mlrun.common.schemas import SecurityContextEnrichmentModes from mlrun.utils import logger # Want to use k8s_secrets_mock for all tests in this module. It is needed since @@ -50,7 +50,7 @@ def test_submit_run_sync(db: Session, client: TestClient): - auth_info = mlrun.api.schemas.AuthInfo() + auth_info = mlrun.common.schemas.AuthInfo() tests.api.api.utils.create_project(client, PROJECT) project, function_name, function_tag, original_function = _mock_original_function( client @@ -242,7 +242,7 @@ def test_generate_function_and_task_from_submit_run_body_body_override_values( }, } parsed_function_object, task = _generate_function_and_task_from_submit_run_body( - db, mlrun.api.schemas.AuthInfo(), submit_job_body + db, mlrun.common.schemas.AuthInfo(), submit_job_body ) assert parsed_function_object.metadata.name == function_name assert parsed_function_object.metadata.project == project @@ -343,7 +343,7 @@ def test_generate_function_and_task_from_submit_run_with_preemptible_nodes_and_t ), ) parsed_function_object, task = _generate_function_and_task_from_submit_run_body( - db, mlrun.api.schemas.AuthInfo(), submit_job_body + db, mlrun.common.schemas.AuthInfo(), submit_job_body ) assert ( parsed_function_object.spec.preemption_mode @@ -372,7 +372,7 @@ def test_generate_function_and_task_from_submit_run_with_preemptible_nodes_and_t "function": {"spec": {"preemption_mode": "constrain"}}, } parsed_function_object, task = _generate_function_and_task_from_submit_run_body( - db, mlrun.api.schemas.AuthInfo(), submit_job_body + db, mlrun.common.schemas.AuthInfo(), submit_job_body ) expected_affinity = kubernetes.client.V1Affinity( node_affinity=kubernetes.client.V1NodeAffinity( 
@@ -407,7 +407,7 @@ def test_generate_function_and_task_from_submit_run_body_keep_resources( "function": {"spec": {"resources": {"limits": {}, "requests": {}}}}, } parsed_function_object, task = _generate_function_and_task_from_submit_run_body( - db, mlrun.api.schemas.AuthInfo(), submit_job_body + db, mlrun.common.schemas.AuthInfo(), submit_job_body ) assert parsed_function_object.metadata.name == function_name assert parsed_function_object.metadata.project == PROJECT @@ -448,7 +448,7 @@ def test_generate_function_and_task_from_submit_run_body_keep_credentials( "function": {"metadata": {"credentials": None}}, } parsed_function_object, task = _generate_function_and_task_from_submit_run_body( - db, mlrun.api.schemas.AuthInfo(), submit_job_body + db, mlrun.common.schemas.AuthInfo(), submit_job_body ) assert parsed_function_object.metadata.name == function_name assert parsed_function_object.metadata.project == project @@ -471,7 +471,7 @@ def test_ensure_function_has_auth_set( ) original_function = mlrun.new_function(runtime=original_function_dict) function = mlrun.new_function(runtime=original_function_dict) - ensure_function_has_auth_set(function, mlrun.api.schemas.AuthInfo()) + ensure_function_has_auth_set(function, mlrun.common.schemas.AuthInfo()) assert ( DeepDiff( original_function.to_dict(), @@ -494,7 +494,7 @@ def test_ensure_function_has_auth_set( unittest.mock.Mock(return_value=access_key) ) ensure_function_has_auth_set( - function, mlrun.api.schemas.AuthInfo(username=username) + function, mlrun.common.schemas.AuthInfo(username=username) ) assert ( DeepDiff( @@ -519,7 +519,7 @@ def test_ensure_function_has_auth_set( function, mlrun.runtimes.constants.FunctionEnvironmentVariables.auth_session, secret_name, - mlrun.api.schemas.AuthSecretData.get_field_secret_key("access_key"), + mlrun.common.schemas.AuthSecretData.get_field_secret_key("access_key"), ) logger.info("No access key - explode") @@ -531,7 +531,7 @@ def test_ensure_function_has_auth_set( 
mlrun.errors.MLRunInvalidArgumentError, match=r"(.*)Function access key must be set(.*)", ): - ensure_function_has_auth_set(function, mlrun.api.schemas.AuthInfo()) + ensure_function_has_auth_set(function, mlrun.common.schemas.AuthInfo()) logger.info("Access key without username - explode") _, _, _, original_function_dict = _generate_original_function( @@ -541,7 +541,7 @@ def test_ensure_function_has_auth_set( with pytest.raises( mlrun.errors.MLRunInvalidArgumentError, match=r"(.*)Username is missing(.*)" ): - ensure_function_has_auth_set(function, mlrun.api.schemas.AuthInfo()) + ensure_function_has_auth_set(function, mlrun.common.schemas.AuthInfo()) logger.info("Access key ref provided - env should be set") secret_name = "some-access-key-secret-name" @@ -552,7 +552,7 @@ def test_ensure_function_has_auth_set( ) original_function = mlrun.new_function(runtime=original_function_dict) function = mlrun.new_function(runtime=original_function_dict) - ensure_function_has_auth_set(function, mlrun.api.schemas.AuthInfo()) + ensure_function_has_auth_set(function, mlrun.common.schemas.AuthInfo()) assert ( DeepDiff( original_function.to_dict(), @@ -567,7 +567,7 @@ def test_ensure_function_has_auth_set( function, mlrun.runtimes.constants.FunctionEnvironmentVariables.auth_session, secret_name, - mlrun.api.schemas.AuthSecretData.get_field_secret_key("access_key"), + mlrun.common.schemas.AuthSecretData.get_field_secret_key("access_key"), ) logger.info( @@ -582,7 +582,7 @@ def test_ensure_function_has_auth_set( original_function = mlrun.new_function(runtime=original_function_dict) function = mlrun.new_function(runtime=original_function_dict) ensure_function_has_auth_set( - function, mlrun.api.schemas.AuthInfo(username=username) + function, mlrun.common.schemas.AuthInfo(username=username) ) secret_name = k8s_secrets_mock.get_auth_secret_name(username, access_key) k8s_secrets_mock.assert_auth_secret(secret_name, username, access_key) @@ -608,7 +608,7 @@ def 
test_ensure_function_has_auth_set( function, mlrun.runtimes.constants.FunctionEnvironmentVariables.auth_session, secret_name, - mlrun.api.schemas.AuthSecretData.get_field_secret_key("access_key"), + mlrun.common.schemas.AuthSecretData.get_field_secret_key("access_key"), ) @@ -621,7 +621,7 @@ def test_mask_v3io_access_key_env_var( _, _, _, original_function_dict = _generate_original_function() original_function = mlrun.new_function(runtime=original_function_dict) function = mlrun.new_function(runtime=original_function_dict) - _mask_v3io_access_key_env_var(function, mlrun.api.schemas.AuthInfo()) + _mask_v3io_access_key_env_var(function, mlrun.common.schemas.AuthInfo()) assert ( DeepDiff( original_function.to_dict(), @@ -646,7 +646,7 @@ def test_mask_v3io_access_key_env_var( mlrun.errors.MLRunInvalidArgumentError, match=r"(.*)Username is missing(.*)", ): - _mask_v3io_access_key_env_var(function, mlrun.api.schemas.AuthInfo()) + _mask_v3io_access_key_env_var(function, mlrun.common.schemas.AuthInfo()) logger.info( "Mask function with access key without username when iguazio auth off - skip" @@ -659,7 +659,7 @@ def test_mask_v3io_access_key_env_var( ) original_function = mlrun.new_function(runtime=original_function_dict) function = mlrun.new_function(runtime=original_function_dict) - _mask_v3io_access_key_env_var(function, mlrun.api.schemas.AuthInfo()) + _mask_v3io_access_key_env_var(function, mlrun.common.schemas.AuthInfo()) assert ( DeepDiff( original_function.to_dict(), @@ -681,7 +681,7 @@ def test_mask_v3io_access_key_env_var( function: mlrun.runtimes.pod.KubeResource = mlrun.new_function( runtime=original_function_dict ) - _mask_v3io_access_key_env_var(function, mlrun.api.schemas.AuthInfo()) + _mask_v3io_access_key_env_var(function, mlrun.common.schemas.AuthInfo()) assert ( DeepDiff( original_function.to_dict(), @@ -698,14 +698,14 @@ def test_mask_v3io_access_key_env_var( function, "V3IO_ACCESS_KEY", secret_name, - 
mlrun.api.schemas.AuthSecretData.get_field_secret_key("access_key"), + mlrun.common.schemas.AuthSecretData.get_field_secret_key("access_key"), ) logger.info( "mask same function again, access key is already a reference - nothing should change" ) original_function = mlrun.new_function(runtime=function) - _mask_v3io_access_key_env_var(function, mlrun.api.schemas.AuthInfo()) + _mask_v3io_access_key_env_var(function, mlrun.common.schemas.AuthInfo()) mlrun.api.crud.Secrets().store_auth_secret = unittest.mock.Mock() assert ( DeepDiff( @@ -724,7 +724,7 @@ def test_mask_v3io_access_key_env_var( function.spec.env.append(function.spec.env.pop().to_dict()) original_function = mlrun.new_function(runtime=function) _mask_v3io_access_key_env_var( - function, mlrun.api.schemas.AuthInfo(username=username) + function, mlrun.common.schemas.AuthInfo(username=username) ) mlrun.api.crud.Secrets().store_auth_secret = unittest.mock.Mock() assert ( @@ -904,7 +904,7 @@ def test_ensure_function_security_context_no_enrichment( db: Session, client: TestClient ): tests.api.api.utils.create_project(client, PROJECT) - auth_info = mlrun.api.schemas.AuthInfo(user_unix_id=1000) + auth_info = mlrun.common.schemas.AuthInfo(user_unix_id=1000) mlrun.mlconf.igz_version = "3.6" logger.info("Enrichment mode is disabled, nothing should be changed") @@ -955,7 +955,7 @@ def test_ensure_function_security_context_no_enrichment( ) original_function = mlrun.new_function(runtime=original_function_dict_job_kind) function = mlrun.new_function(runtime=original_function_dict_job_kind) - ensure_function_security_context(function, mlrun.api.schemas.AuthInfo()) + ensure_function_security_context(function, mlrun.common.schemas.AuthInfo()) assert ( DeepDiff( original_function.to_dict(), @@ -977,7 +977,7 @@ def test_ensure_function_security_context_override_enrichment_mode( logger.info("Enrichment mode is override, security context should be enriched") mlrun.api.utils.clients.iguazio.Client.get_user_unix_id = 
unittest.mock.Mock() - auth_info = mlrun.api.schemas.AuthInfo(user_unix_id=1000) + auth_info = mlrun.common.schemas.AuthInfo(user_unix_id=1000) _, _, _, original_function_dict = _generate_original_function( kind=mlrun.runtimes.RuntimeKinds.job ) @@ -1024,7 +1024,7 @@ def test_ensure_function_security_context_enrichment_group_id( mlrun.mlconf.function.spec.security_context.enrichment_mode = ( SecurityContextEnrichmentModes.override.value ) - auth_info = mlrun.api.schemas.AuthInfo(user_unix_id=1000) + auth_info = mlrun.common.schemas.AuthInfo(user_unix_id=1000) _, _, _, original_function_dict = _generate_original_function( kind=mlrun.runtimes.RuntimeKinds.job ) @@ -1075,7 +1075,7 @@ def test_ensure_function_security_context_unknown_enrichment_mode( tests.api.api.utils.create_project(client, PROJECT) mlrun.mlconf.igz_version = "3.6" mlrun.mlconf.function.spec.security_context.enrichment_mode = "not a real mode" - auth_info = mlrun.api.schemas.AuthInfo(user_unix_id=1000) + auth_info = mlrun.common.schemas.AuthInfo(user_unix_id=1000) _, _, _, original_function_dict = _generate_original_function( kind=mlrun.runtimes.RuntimeKinds.job ) @@ -1098,7 +1098,7 @@ def test_ensure_function_security_context_missing_control_plane_session_tag( mlrun.mlconf.function.spec.security_context.enrichment_mode = ( SecurityContextEnrichmentModes.override ) - auth_info = mlrun.api.schemas.AuthInfo( + auth_info = mlrun.common.schemas.AuthInfo( planes=[mlrun.api.utils.clients.iguazio.SessionPlanes.data] ) _, _, _, original_function_dict = _generate_original_function( @@ -1121,7 +1121,7 @@ def test_ensure_function_security_context_missing_control_plane_session_tag( mlrun.api.utils.clients.iguazio.Client.get_user_unix_id = unittest.mock.Mock( return_value=user_unix_id ) - auth_info = mlrun.api.schemas.AuthInfo(planes=[]) + auth_info = mlrun.common.schemas.AuthInfo(planes=[]) logger.info( "Session missing control plane, but actually just because it wasn't enriched, expected to succeed" ) @@ 
-1142,7 +1142,7 @@ def test_ensure_function_security_context_get_user_unix_id( ) # set auth info with control plane and without user unix id so that it will be fetched - auth_info = mlrun.api.schemas.AuthInfo( + auth_info = mlrun.common.schemas.AuthInfo( planes=[mlrun.api.utils.clients.iguazio.SessionPlanes.control] ) mlrun.api.utils.clients.iguazio.Client.get_user_unix_id = unittest.mock.Mock( @@ -1181,13 +1181,13 @@ def test_generate_function_and_task_from_submit_run_body_imported_function_proje _mock_import_function(monkeypatch) submit_job_body = { "task": { - "spec": {"function": "hub://gen_class_data"}, + "spec": {"function": "hub://gen-class-data"}, "metadata": {"name": task_name, "project": PROJECT}, }, "function": {"spec": {"resources": {"limits": {}, "requests": {}}}}, } parsed_function_object, task = _generate_function_and_task_from_submit_run_body( - db, mlrun.api.schemas.AuthInfo(), submit_job_body + db, mlrun.common.schemas.AuthInfo(), submit_job_body ) assert parsed_function_object.metadata.project == PROJECT diff --git a/tests/api/api/utils.py b/tests/api/api/utils.py index 1d2414c69a43..f4ffa38c404f 100644 --- a/tests/api/api/utils.py +++ b/tests/api/api/utils.py @@ -20,19 +20,27 @@ import mlrun.api.api.endpoints.functions import mlrun.api.crud -import mlrun.api.schemas import mlrun.api.utils import mlrun.api.utils.singletons.db import mlrun.api.utils.singletons.k8s import mlrun.artifacts.dataset import mlrun.artifacts.model +import mlrun.common.schemas import mlrun.errors PROJECT = "project-name" -def create_project(client: TestClient, project_name: str = PROJECT, artifact_path=None): - project = _create_project_obj(project_name, artifact_path) +def create_project( + client: TestClient, + project_name: str = PROJECT, + artifact_path=None, + source="source", + load_source_on_run=False, +): + project = _create_project_obj( + project_name, artifact_path, source, load_source_on_run + ) resp = client.post("projects", json=project.dict()) assert 
resp.status_code == HTTPStatus.CREATED.value return resp @@ -41,11 +49,11 @@ def create_project(client: TestClient, project_name: str = PROJECT, artifact_pat def compile_schedule(schedule_name: str = None, to_json: bool = True): if not schedule_name: schedule_name = f"schedule-name-{str(uuid.uuid4())}" - schedule = mlrun.api.schemas.ScheduleInput( + schedule = mlrun.common.schemas.ScheduleInput( name=schedule_name, - kind=mlrun.api.schemas.ScheduleKinds.job, + kind=mlrun.common.schemas.ScheduleKinds.job, scheduled_object={"metadata": {"name": "something"}}, - cron_trigger=mlrun.api.schemas.ScheduleCronTrigger(year=1999), + cron_trigger=mlrun.common.schemas.ScheduleCronTrigger(year=1999), ) if not to_json: return schedule @@ -55,9 +63,9 @@ def compile_schedule(schedule_name: str = None, to_json: bool = True): async def create_project_async( async_client: httpx.AsyncClient, project_name: str = PROJECT ): - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), - spec=mlrun.api.schemas.ProjectSpec( + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), + spec=mlrun.common.schemas.ProjectSpec( description="banana", source="source", goals="some goals" ), ) @@ -69,12 +77,15 @@ async def create_project_async( return resp -def _create_project_obj(project_name, artifact_path) -> mlrun.api.schemas.Project: - return mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), - spec=mlrun.api.schemas.ProjectSpec( +def _create_project_obj( + project_name, artifact_path, source, load_source_on_run=False +) -> mlrun.common.schemas.Project: + return mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), + spec=mlrun.common.schemas.ProjectSpec( description="banana", - source="source", + source=source, + load_source_on_run=load_source_on_run, goals="some goals", artifact_path=artifact_path, ), diff --git 
a/tests/api/conftest.py b/tests/api/conftest.py index fcf0e3b32786..c5fec1074f41 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -23,8 +23,9 @@ import pytest from fastapi.testclient import TestClient -import mlrun.api.schemas +import mlrun.api.utils.clients.iguazio import mlrun.api.utils.singletons.k8s +import mlrun.common.schemas from mlrun import mlconf from mlrun.api.db.sqldb.session import _init_engine, create_session from mlrun.api.initial_data import init_data @@ -51,7 +52,7 @@ def db() -> Generator: # TODO: make it simpler - doesn't make sense to call 3 different functions to initialize the db # we need to force re-init the engine cause otherwise it is cached between tests - _init_engine(config.httpdb.dsn) + _init_engine(dsn=config.httpdb.dsn) # forcing from scratch because we created an empty file for the db init_data(from_scratch=True) @@ -124,18 +125,22 @@ def set_is_running_in_k8s_cluster(self, value: bool): def get_auth_secret_name(username: str, access_key: str) -> str: return f"secret-ref-{username}-{access_key}" - def store_auth_secret(self, username: str, access_key: str, namespace="") -> str: + def store_auth_secret( + self, username: str, access_key: str, namespace="" + ) -> (str, bool): secret_ref = self.get_auth_secret_name(username, access_key) self.auth_secrets_map.setdefault(secret_ref, {}).update( self._generate_auth_secret_data(username, access_key) ) - return secret_ref + return secret_ref, True @staticmethod def _generate_auth_secret_data(username: str, access_key: str): return { - mlrun.api.schemas.AuthSecretData.get_field_secret_key("username"): username, - mlrun.api.schemas.AuthSecretData.get_field_secret_key( + mlrun.common.schemas.AuthSecretData.get_field_secret_key( + "username" + ): username, + mlrun.common.schemas.AuthSecretData.get_field_secret_key( "access_key" ): access_key, } @@ -153,15 +158,17 @@ def read_auth_secret(self, secret_name, namespace="", raise_on_not_found=False): return None, None username = 
secret[ - mlrun.api.schemas.AuthSecretData.get_field_secret_key("username") + mlrun.common.schemas.AuthSecretData.get_field_secret_key("username") ] access_key = secret[ - mlrun.api.schemas.AuthSecretData.get_field_secret_key("access_key") + mlrun.common.schemas.AuthSecretData.get_field_secret_key("access_key") ] return username, access_key - def store_project_secrets(self, project, secrets, namespace=""): + def store_project_secrets(self, project, secrets, namespace="") -> (str, bool): self.project_secrets_map.setdefault(project, {}).update(secrets) + secret_name = project + return secret_name, True def delete_project_secrets(self, project, secrets, namespace=""): if not secrets: @@ -169,6 +176,7 @@ def delete_project_secrets(self, project, secrets, namespace=""): else: for key in secrets: self.project_secrets_map.get(project, {}).pop(key, None) + return "", True def get_project_secret_keys(self, project, namespace="", filter_internal=False): secret_keys = list(self.project_secrets_map.get(project, {}).keys()) @@ -206,8 +214,10 @@ def get_expected_env_variables_from_secrets( ) expected_env_from_secrets[env_variable_name] = {global_secret: key} - secret_name = mlrun.api.utils.singletons.k8s.get_k8s().get_project_secret_name( - project + secret_name = ( + mlrun.api.utils.singletons.k8s.get_k8s_helper().get_project_secret_name( + project + ) ) for key in self.project_secrets_map.get(project, {}): if key.startswith("mlrun.") and not include_internal: @@ -281,7 +291,7 @@ def k8s_secrets_mock(monkeypatch, client: TestClient) -> K8sSecretsMock: for mocked_function_name in mocked_function_names: monkeypatch.setattr( - mlrun.api.utils.singletons.k8s.get_k8s(), + mlrun.api.utils.singletons.k8s.get_k8s_helper(), mocked_function_name, getattr(k8s_secrets_mock, mocked_function_name), ) @@ -291,10 +301,37 @@ def k8s_secrets_mock(monkeypatch, client: TestClient) -> K8sSecretsMock: @pytest.fixture def kfp_client_mock(monkeypatch) -> kfp.Client: - 
mlrun.api.utils.singletons.k8s.get_k8s().is_running_inside_kubernetes_cluster = ( - unittest.mock.Mock(return_value=True) + mlrun.api.utils.singletons.k8s.get_k8s_helper().is_running_inside_kubernetes_cluster = unittest.mock.Mock( + return_value=True ) kfp_client_mock = unittest.mock.Mock() monkeypatch.setattr(kfp, "Client", lambda *args, **kwargs: kfp_client_mock) mlrun.mlconf.kfp_url = "http://ml-pipeline.custom_namespace.svc.cluster.local:8888" return kfp_client_mock + + +@pytest.fixture() +async def api_url() -> str: + api_url = "http://iguazio-api-url:8080" + mlrun.config.config._iguazio_api_url = api_url + return api_url + + +@pytest.fixture() +async def iguazio_client( + api_url: str, + request: pytest.FixtureRequest, +) -> mlrun.api.utils.clients.iguazio.Client: + if request.param == "async": + client = mlrun.api.utils.clients.iguazio.AsyncClient() + else: + client = mlrun.api.utils.clients.iguazio.Client() + + # force running init again so the configured api url will be used + client.__init__() + client._wait_for_job_completion_retry_interval = 0 + client._wait_for_project_terminal_state_retry_interval = 0 + + # inject the request param into client, so we can use it in tests + setattr(client, "mode", request.param) + return client diff --git a/tests/api/crud/runtimes/__init__.py b/tests/api/crud/runtimes/__init__.py new file mode 100644 index 000000000000..b3085be1eb56 --- /dev/null +++ b/tests/api/crud/runtimes/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/api/crud/runtimes/nuclio/__init__.py b/tests/api/crud/runtimes/nuclio/__init__.py new file mode 100644 index 000000000000..b3085be1eb56 --- /dev/null +++ b/tests/api/crud/runtimes/nuclio/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/api/crud/runtimes/nuclio/test_helpers.py b/tests/api/crud/runtimes/nuclio/test_helpers.py new file mode 100644 index 000000000000..dcd5d805bea0 --- /dev/null +++ b/tests/api/crud/runtimes/nuclio/test_helpers.py @@ -0,0 +1,104 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import pytest + +import mlrun +import mlrun.api.crud.runtimes.nuclio.function +import mlrun.api.crud.runtimes.nuclio.helpers +from tests.conftest import examples_path + + +def test_compiled_function_config_nuclio_golang(): + name = f"{examples_path}/training.py" + fn = mlrun.code_to_function( + "nuclio", filename=name, kind="nuclio", handler="my_hand" + ) + ( + name, + project, + config, + ) = mlrun.api.crud.runtimes.nuclio.function._compile_function_config(fn) + assert fn.kind == "remote", "kind not set, test failed" + assert mlrun.utils.get_in(config, "spec.build.functionSourceCode"), "no source code" + assert mlrun.utils.get_in(config, "spec.runtime").startswith( + "py" + ), "runtime not set" + assert ( + mlrun.utils.get_in(config, "spec.handler") == "training:my_hand" + ), "wrong handler" + + +def test_compiled_function_config_nuclio_python(): + name = f"{examples_path}/training.py" + fn = mlrun.code_to_function( + "nuclio", filename=name, kind="nuclio", handler="my_hand" + ) + ( + name, + project, + config, + ) = mlrun.api.crud.runtimes.nuclio.function._compile_function_config(fn) + assert fn.kind == "remote", "kind not set, test failed" + assert mlrun.utils.get_in(config, "spec.build.functionSourceCode"), "no source code" + assert mlrun.utils.get_in(config, "spec.runtime").startswith( + "py" + ), "runtime not set" + assert ( + mlrun.utils.get_in(config, "spec.handler") == "training:my_hand" + ), "wrong handler" + + +@pytest.mark.parametrize( + "handler, expected", + [ + (None, ("", "main:handler")), + ("x", ("", "x:handler")), + ("x:y", ("", "x:y")), + ("dir#", ("dir", "main:handler")), + ("dir#x", ("dir", "x:handler")), + ("dir#x:y", ("dir", "x:y")), + ], +) +def test_resolve_work_dir_and_handler(handler, expected): + assert ( + expected + == mlrun.api.crud.runtimes.nuclio.helpers.resolve_work_dir_and_handler(handler) + ) + + +@pytest.mark.parametrize( + "mlrun_client_version,python_version,expected_runtime", + [ + ("1.3.0", "3.9.16", "python:3.9"), + 
("1.3.0", "3.7.16", "python:3.7"), + (None, None, "python:3.7"), + (None, "3.9.16", "python:3.7"), + ("1.3.0", None, "python:3.7"), + ("0.0.0-unstable", "3.9.16", "python:3.9"), + ("0.0.0-unstable", "3.7.16", "python:3.7"), + ("1.2.0", "3.9.16", "python:3.7"), + ("1.2.0", "3.7.16", "python:3.7"), + ], +) +def test_resolve_nuclio_runtime_python_image( + mlrun_client_version, python_version, expected_runtime +): + assert ( + expected_runtime + == mlrun.api.crud.runtimes.nuclio.helpers.resolve_nuclio_runtime_python_image( + mlrun_client_version, python_version + ) + ) diff --git a/tests/api/crud/test_secrets.py b/tests/api/crud/test_secrets.py index 438452319735..9d9d02d0b843 100644 --- a/tests/api/crud/test_secrets.py +++ b/tests/api/crud/test_secrets.py @@ -22,7 +22,7 @@ import sqlalchemy.orm import mlrun.api.crud -import mlrun.api.schemas +import mlrun.common.schemas import mlrun.errors import tests.api.conftest @@ -31,11 +31,11 @@ def test_store_project_secrets_verifications( db: sqlalchemy.orm.Session, client: fastapi.testclient.TestClient ): project = "project-name" - provider = mlrun.api.schemas.SecretProviderName.kubernetes + provider = mlrun.common.schemas.SecretProviderName.kubernetes with pytest.raises(mlrun.errors.MLRunInvalidArgumentError): mlrun.api.crud.Secrets().store_project_secrets( project, - mlrun.api.schemas.SecretsData( + mlrun.common.schemas.SecretsData( provider=provider, secrets={"invalid/key": "value"} ), ) @@ -43,7 +43,7 @@ def test_store_project_secrets_verifications( with pytest.raises(mlrun.errors.MLRunAccessDeniedError): mlrun.api.crud.Secrets().store_project_secrets( project, - mlrun.api.schemas.SecretsData( + mlrun.common.schemas.SecretsData( provider=provider, secrets={"mlrun.internal.key": "value"} ), ) @@ -55,7 +55,7 @@ def test_store_project_secrets_with_key_map_verifications( k8s_secrets_mock: tests.api.conftest.K8sSecretsMock, ): project = "project-name" - provider = mlrun.api.schemas.SecretProviderName.kubernetes + provider = 
mlrun.common.schemas.SecretProviderName.kubernetes key_map_secret_key = ( mlrun.api.crud.Secrets().generate_client_key_map_project_secret_key( mlrun.api.crud.SecretsClientType.schedules @@ -65,7 +65,7 @@ def test_store_project_secrets_with_key_map_verifications( with pytest.raises(mlrun.errors.MLRunAccessDeniedError): mlrun.api.crud.Secrets().store_project_secrets( project, - mlrun.api.schemas.SecretsData( + mlrun.common.schemas.SecretsData( provider=provider, secrets={key_map_secret_key: "value"} ), ) @@ -74,8 +74,8 @@ def test_store_project_secrets_with_key_map_verifications( with pytest.raises(mlrun.errors.MLRunInvalidArgumentError): mlrun.api.crud.Secrets().store_project_secrets( project, - mlrun.api.schemas.SecretsData( - provider=mlrun.api.schemas.SecretProviderName.vault, + mlrun.common.schemas.SecretsData( + provider=mlrun.common.schemas.SecretProviderName.vault, secrets={"invalid/key": "value"}, ), ) @@ -84,7 +84,7 @@ def test_store_project_secrets_with_key_map_verifications( with pytest.raises(mlrun.errors.MLRunInvalidArgumentError): mlrun.api.crud.Secrets().store_project_secrets( project, - mlrun.api.schemas.SecretsData( + mlrun.common.schemas.SecretsData( provider=provider, secrets={"invalid/key": "value"} ), key_map_secret_key="invalid-key-map-secret-key", @@ -94,7 +94,7 @@ def test_store_project_secrets_with_key_map_verifications( with pytest.raises(mlrun.errors.MLRunInvalidArgumentError): mlrun.api.crud.Secrets().store_project_secrets( project, - mlrun.api.schemas.SecretsData( + mlrun.common.schemas.SecretsData( provider=provider, secrets={"invalid/key": "value"} ), allow_internal_secrets=True, @@ -105,7 +105,7 @@ def test_store_project_secrets_with_key_map_verifications( with pytest.raises(mlrun.errors.MLRunAccessDeniedError): mlrun.api.crud.Secrets().store_project_secrets( project, - mlrun.api.schemas.SecretsData( + mlrun.common.schemas.SecretsData( provider=provider, secrets={"valid-key": "value"} ), key_map_secret_key=key_map_secret_key, @@ 
-118,7 +118,7 @@ def test_get_project_secret_verifications( k8s_secrets_mock: tests.api.conftest.K8sSecretsMock, ): project = "project-name" - provider = mlrun.api.schemas.SecretProviderName.kubernetes + provider = mlrun.common.schemas.SecretProviderName.kubernetes key_map_secret_key = ( mlrun.api.crud.Secrets().generate_client_key_map_project_secret_key( mlrun.api.crud.SecretsClientType.schedules @@ -136,7 +136,7 @@ def test_get_project_secret_verifications( with pytest.raises(mlrun.errors.MLRunInvalidArgumentError): mlrun.api.crud.Secrets().get_project_secret( project, - mlrun.api.schemas.SecretProviderName.vault, + mlrun.common.schemas.SecretProviderName.vault, "does-not-exist-key", key_map_secret_key=key_map_secret_key, ) @@ -149,7 +149,7 @@ def test_get_project_secret( ): _mock_secrets_crud_uuid_generation() project = "project-name" - provider = mlrun.api.schemas.SecretProviderName.kubernetes + provider = mlrun.common.schemas.SecretProviderName.kubernetes key_map_secret_key = ( mlrun.api.crud.Secrets().generate_client_key_map_project_secret_key( mlrun.api.crud.SecretsClientType.schedules @@ -183,7 +183,7 @@ def test_get_project_secret( mlrun.api.crud.Secrets().store_project_secrets( project, - mlrun.api.schemas.SecretsData( + mlrun.common.schemas.SecretsData( provider=provider, secrets={ valid_secret_key: valid_secret_value, @@ -231,7 +231,7 @@ def test_delete_project_secret_verifications( k8s_secrets_mock: tests.api.conftest.K8sSecretsMock, ): project = "project-name" - provider = mlrun.api.schemas.SecretProviderName.kubernetes + provider = mlrun.common.schemas.SecretProviderName.kubernetes key_map_secret_key = ( mlrun.api.crud.Secrets().generate_client_key_map_project_secret_key( mlrun.api.crud.SecretsClientType.schedules @@ -249,14 +249,14 @@ def test_delete_project_secret_verifications( # vault provider with pytest.raises(mlrun.errors.MLRunInvalidArgumentError): mlrun.api.crud.Secrets().delete_project_secret( - project, 
mlrun.api.schemas.SecretProviderName.vault, "valid-key" + project, mlrun.common.schemas.SecretProviderName.vault, "valid-key" ) # key map with provider other than k8s with pytest.raises(mlrun.errors.MLRunInvalidArgumentError): mlrun.api.crud.Secrets().delete_project_secret( project, - mlrun.api.schemas.SecretProviderName.vault, + mlrun.common.schemas.SecretProviderName.vault, "invalid/key", key_map_secret_key=key_map_secret_key, ) @@ -275,7 +275,7 @@ def test_delete_project_secret( ): _mock_secrets_crud_uuid_generation() project = "project-name" - provider = mlrun.api.schemas.SecretProviderName.kubernetes + provider = mlrun.common.schemas.SecretProviderName.kubernetes key_map_secret_key = ( mlrun.api.crud.Secrets().generate_client_key_map_project_secret_key( mlrun.api.crud.SecretsClientType.schedules @@ -295,7 +295,7 @@ def test_delete_project_secret( mlrun.api.crud.Secrets().store_project_secrets( project, - mlrun.api.schemas.SecretsData( + mlrun.common.schemas.SecretsData( provider=provider, secrets=collections.OrderedDict( { @@ -370,7 +370,7 @@ def test_store_project_secrets_with_key_map_success( ): _mock_secrets_crud_uuid_generation() project = "project-name" - provider = mlrun.api.schemas.SecretProviderName.kubernetes + provider = mlrun.common.schemas.SecretProviderName.kubernetes key_map_secret_key = ( mlrun.api.crud.Secrets().generate_client_key_map_project_secret_key( mlrun.api.crud.SecretsClientType.schedules @@ -389,7 +389,7 @@ def test_store_project_secrets_with_key_map_success( # store secret with valid key - map shouldn't be used mlrun.api.crud.Secrets().store_project_secrets( project, - mlrun.api.schemas.SecretsData( + mlrun.common.schemas.SecretsData( provider=provider, secrets={valid_secret_key: valid_secret_value} ), allow_internal_secrets=True, @@ -402,7 +402,7 @@ def test_store_project_secrets_with_key_map_success( # store secret with invalid key - map should be used mlrun.api.crud.Secrets().store_project_secrets( project, - 
mlrun.api.schemas.SecretsData( + mlrun.common.schemas.SecretsData( provider=provider, secrets={invalid_secret_key: invalid_secret_value} ), allow_internal_secrets=True, @@ -420,7 +420,7 @@ def test_store_project_secrets_with_key_map_success( # store secret with the same invalid key and different value mlrun.api.crud.Secrets().store_project_secrets( project, - mlrun.api.schemas.SecretsData( + mlrun.common.schemas.SecretsData( provider=provider, secrets={invalid_secret_key: invalid_secret_value_2} ), allow_internal_secrets=True, @@ -439,7 +439,7 @@ def test_store_project_secrets_with_key_map_success( for _ in range(2): mlrun.api.crud.Secrets().store_project_secrets( project, - mlrun.api.schemas.SecretsData( + mlrun.common.schemas.SecretsData( provider=provider, secrets={invalid_secret_2_key: invalid_secret_2_value}, ), @@ -461,7 +461,7 @@ def test_store_project_secrets_with_key_map_success( # change values to all secrets mlrun.api.crud.Secrets().store_project_secrets( project, - mlrun.api.schemas.SecretsData( + mlrun.common.schemas.SecretsData( provider=provider, secrets={ valid_secret_key: valid_secret_value_2, @@ -502,7 +502,7 @@ def test_secrets_crud_internal_project_secrets( k8s_secrets_mock: tests.api.conftest.K8sSecretsMock, ): project = "project-name" - provider = mlrun.api.schemas.SecretProviderName.kubernetes + provider = mlrun.common.schemas.SecretProviderName.kubernetes regular_secret_key = "key" regular_secret_value = "value" internal_secret_key = ( @@ -513,7 +513,7 @@ def test_secrets_crud_internal_project_secrets( # store regular secret - pass mlrun.api.crud.Secrets().store_project_secrets( project, - mlrun.api.schemas.SecretsData( + mlrun.common.schemas.SecretsData( provider=provider, secrets={regular_secret_key: regular_secret_value} ), ) @@ -522,7 +522,7 @@ def test_secrets_crud_internal_project_secrets( with pytest.raises(mlrun.errors.MLRunAccessDeniedError): mlrun.api.crud.Secrets().store_project_secrets( project, - mlrun.api.schemas.SecretsData( + 
mlrun.common.schemas.SecretsData( provider=provider, secrets={internal_secret_key: internal_secret_value} ), ) @@ -530,7 +530,7 @@ def test_secrets_crud_internal_project_secrets( # store internal secret with allow - pass mlrun.api.crud.Secrets().store_project_secrets( project, - mlrun.api.schemas.SecretsData( + mlrun.common.schemas.SecretsData( provider=provider, secrets={internal_secret_key: internal_secret_value} ), allow_internal_secrets=True, @@ -634,7 +634,7 @@ def test_secrets_crud_internal_project_secrets( # store internal secret again to verify deletion with empty list with allow - pass mlrun.api.crud.Secrets().store_project_secrets( project, - mlrun.api.schemas.SecretsData( + mlrun.common.schemas.SecretsData( provider=provider, secrets={internal_secret_key: internal_secret_value} ), allow_internal_secrets=True, @@ -666,8 +666,8 @@ def test_store_auth_secret_verifications( # not allowed with provider other than k8s with pytest.raises(mlrun.errors.MLRunInvalidArgumentError): mlrun.api.crud.Secrets().store_auth_secret( - mlrun.api.schemas.AuthSecretData( - provider=mlrun.api.schemas.SecretProviderName.vault, + mlrun.common.schemas.AuthSecretData( + provider=mlrun.common.schemas.SecretProviderName.vault, username="some-username", access_key="some-access-key", ), @@ -682,8 +682,8 @@ def test_store_auth_secret( username = "some-username" access_key = "some-access-key" secret_name = mlrun.api.crud.Secrets().store_auth_secret( - mlrun.api.schemas.AuthSecretData( - provider=mlrun.api.schemas.SecretProviderName.kubernetes, + mlrun.common.schemas.AuthSecretData( + provider=mlrun.common.schemas.SecretProviderName.kubernetes, username=username, access_key=access_key, ), diff --git a/tests/api/db/conftest.py b/tests/api/db/conftest.py index 9acc393a1c71..88d0dfa636a6 100644 --- a/tests/api/db/conftest.py +++ b/tests/api/db/conftest.py @@ -12,12 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import shutil from typing import Generator import pytest -from mlrun.api.db.filedb.db import FileDB from mlrun.api.db.session import close_session, create_session from mlrun.api.db.sqldb.db import SQLDB from mlrun.api.db.sqldb.session import _init_engine @@ -26,83 +24,50 @@ from mlrun.api.utils.singletons.project_member import initialize_project_member from mlrun.config import config -dbs = [ - "sqldb", - "filedb", -] - -@pytest.fixture(params=dbs) -def db(request) -> Generator: - if request.param == "sqldb": - dsn = "sqlite:///:memory:?check_same_thread=false" - config.httpdb.dsn = dsn - _init_engine() - - # memory sqldb remove it self when all session closed, this session will keep it up during all test - db_session = create_session() - try: - init_data() - db = SQLDB(dsn) - db.initialize(db_session) - initialize_db(db) - initialize_project_member() - yield db - finally: - close_session(db_session) - elif request.param == "filedb": - db = FileDB(config.httpdb.dirpath) - db_session = create_session(request.param) - try: - db.initialize(db_session) - - yield db - finally: - shutil.rmtree(config.httpdb.dirpath, ignore_errors=True, onerror=None) - close_session(db_session) - else: - raise Exception("Unknown db type") +@pytest.fixture() +def db() -> Generator: + dsn = "sqlite:///:memory:?check_same_thread=false" + config.httpdb.dsn = dsn + _init_engine() + # memory sqldb remove it self when all session closed, this session will keep it up during all test + db_session = create_session() + try: + init_data() + db = SQLDB(dsn) + db.initialize(db_session) + initialize_db(db) + initialize_project_member() + yield db + finally: + close_session(db_session) @pytest.fixture() -def data_migration_db(request) -> Generator: +def data_migration_db() -> Generator: # Data migrations performed before the API goes up, therefore there's no project member yet # that's the only difference between this fixture and the db fixture. 
because of the parameterization it was hard to # share code between them, we anyway going to remove filedb soon, then there won't be params, and we could re-use # code # TODO: fix duplication - if request.param == "sqldb": - dsn = "sqlite:///:memory:?check_same_thread=false" - config.httpdb.dsn = dsn - _init_engine() - - # memory sqldb remove it self when all session closed, this session will keep it up during all test - db_session = create_session() - try: - init_data() - db = SQLDB(dsn) - db.initialize(db_session) - initialize_db(db) - yield db - finally: - close_session(db_session) - elif request.param == "filedb": - db = FileDB(config.httpdb.dirpath) - db_session = create_session(request.param) - try: - db.initialize(db_session) - - yield db - finally: - shutil.rmtree(config.httpdb.dirpath, ignore_errors=True, onerror=None) - close_session(db_session) - else: - raise Exception("Unknown db type") + dsn = "sqlite:///:memory:?check_same_thread=false" + config.httpdb.dsn = dsn + _init_engine(dsn=dsn) + # memory sqldb remove it self when all session closed, this session will keep it up during all test + db_session = create_session() + try: + init_data() + db = SQLDB(dsn) + db.initialize(db_session) + initialize_db(db) + yield db + finally: + close_session(db_session) -@pytest.fixture(params=dbs) -def db_session(request) -> Generator: - db_session = create_session(request.param) +@pytest.fixture() +def db_session() -> Generator: + db_session = create_session() try: yield db_session finally: diff --git a/tests/api/db/test_artifacts.py b/tests/api/db/test_artifacts.py index 133c64eaffea..ebc04d25a57f 100644 --- a/tests/api/db/test_artifacts.py +++ b/tests/api/db/test_artifacts.py @@ -19,20 +19,15 @@ from sqlalchemy.orm import Session import mlrun.api.initial_data +import mlrun.common.schemas import mlrun.errors -from mlrun.api import schemas from mlrun.api.db.base import DBInterface -from mlrun.api.schemas.artifact import ArtifactCategories from mlrun.artifacts.dataset 
import DatasetArtifact from mlrun.artifacts.model import ModelArtifact from mlrun.artifacts.plots import ChartArtifact, PlotArtifact -from tests.api.db.conftest import dbs +from mlrun.common.schemas.artifact import ArtifactCategories -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_list_artifact_name_filter(db: DBInterface, db_session: Session): artifact_name_1 = "artifact_name_1" artifact_name_2 = "artifact_name_2" @@ -67,10 +62,6 @@ def test_list_artifact_name_filter(db: DBInterface, db_session: Session): assert len(artifacts) == 2 -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_list_artifact_iter_parameter(db: DBInterface, db_session: Session): artifact_name_1 = "artifact_name_1" artifact_name_2 = "artifact_name_2" @@ -105,10 +96,6 @@ def test_list_artifact_iter_parameter(db: DBInterface, db_session: Session): assert len(artifacts) == 1 -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_list_artifact_kind_filter(db: DBInterface, db_session: Session): artifact_name_1 = "artifact_name_1" artifact_kind_1 = ChartArtifact.kind @@ -142,10 +129,6 @@ def test_list_artifact_kind_filter(db: DBInterface, db_session: Session): assert artifacts[0]["metadata"]["name"] == artifact_name_2 -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_list_artifact_category_filter(db: DBInterface, db_session: Session): artifact_name_1 = "artifact_name_1" artifact_kind_1 = ChartArtifact.kind @@ -188,26 +171,26 
@@ def test_list_artifact_category_filter(db: DBInterface, db_session: Session): artifacts = db.list_artifacts(db_session) assert len(artifacts) == 4 - artifacts = db.list_artifacts(db_session, category=schemas.ArtifactCategories.model) + artifacts = db.list_artifacts( + db_session, category=mlrun.common.schemas.ArtifactCategories.model + ) assert len(artifacts) == 1 assert artifacts[0]["metadata"]["name"] == artifact_name_3 artifacts = db.list_artifacts( - db_session, category=schemas.ArtifactCategories.dataset + db_session, category=mlrun.common.schemas.ArtifactCategories.dataset ) assert len(artifacts) == 1 assert artifacts[0]["metadata"]["name"] == artifact_name_4 - artifacts = db.list_artifacts(db_session, category=schemas.ArtifactCategories.other) + artifacts = db.list_artifacts( + db_session, category=mlrun.common.schemas.ArtifactCategories.other + ) assert len(artifacts) == 2 assert artifacts[0]["metadata"]["name"] == artifact_name_1 assert artifacts[1]["metadata"]["name"] == artifact_name_2 -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_store_artifact_tagging(db: DBInterface, db_session: Session): artifact_1_key = "artifact_key_1" artifact_1_body = _generate_artifact(artifact_1_key) @@ -238,10 +221,6 @@ def test_store_artifact_tagging(db: DBInterface, db_session: Session): assert len(artifacts) == 1 -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_store_artifact_restoring_multiple_tags(db: DBInterface, db_session: Session): artifact_key = "artifact_key_1" artifact_1_uid = "artifact_uid_1" @@ -295,10 +274,6 @@ def test_store_artifact_restoring_multiple_tags(db: DBInterface, db_session: Ses assert artifact["metadata"]["tag"] == artifact_2_tag -# running only on 
sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_read_artifact_tag_resolution(db: DBInterface, db_session: Session): """ We had a bug in which when we got a tag filter for read/list artifact, we were transforming this tag to list of @@ -341,10 +316,6 @@ def test_read_artifact_tag_resolution(db: DBInterface, db_session: Session): assert len(artifacts) == 1 -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_delete_artifacts_tag_filter(db: DBInterface, db_session: Session): artifact_1_key = "artifact_key_1" artifact_2_key = "artifact_key_2" @@ -379,10 +350,6 @@ def test_delete_artifacts_tag_filter(db: DBInterface, db_session: Session): assert len(artifacts) == 0 -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_delete_artifact_tag_filter(db: DBInterface, db_session: Session): project = "artifact_project" artifact_1_key = "artifact_key_1" @@ -460,10 +427,6 @@ def test_delete_artifact_tag_filter(db: DBInterface, db_session: Session): assert len(tags) == 0 -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_list_artifacts_exact_name_match(db: DBInterface, db_session: Session): artifact_1_key = "pre_artifact_key_suffix" artifact_2_key = "pre-artifact-key-suffix" @@ -547,9 +510,6 @@ def _generate_artifact_with_iterations( ) -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_list_artifacts_best_iter_with_tagged_iteration( db: DBInterface, db_session: Session 
): @@ -584,13 +544,13 @@ def test_list_artifacts_best_iter_with_tagged_iteration( project=project, ) - identifier_1 = schemas.ArtifactIdentifier( + identifier_1 = mlrun.common.schemas.ArtifactIdentifier( kind=ArtifactCategories.model, key=artifact_key_1, uid=artifact_uid_1, iter=best_iter, ) - identifier_2 = schemas.ArtifactIdentifier( + identifier_2 = mlrun.common.schemas.ArtifactIdentifier( kind=ArtifactCategories.model, key=artifact_key_2, uid=artifact_uid_2, @@ -611,10 +571,6 @@ def test_list_artifacts_best_iter_with_tagged_iteration( ) -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_list_artifacts_best_iter(db: DBInterface, db_session: Session): artifact_1_key = "artifact-1" artifact_1_uid = "uid-1" @@ -690,9 +646,6 @@ def test_list_artifacts_best_iter(db: DBInterface, db_session: Session): ) -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_list_artifacts_best_iteration(db: DBInterface, db_session: Session): artifact_key = "artifact-1" artifact_1_uid = "uid-1" @@ -748,12 +701,6 @@ def test_list_artifacts_best_iteration(db: DBInterface, db_session: Session): assert set(expected_uids) == set(uids) -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "data_migration_db,db_session", - [(dbs[0], dbs[0])], - indirect=["data_migration_db", "db_session"], -) def test_data_migration_fix_legacy_datasets_large_previews( data_migration_db: DBInterface, db_session: Session, @@ -845,12 +792,6 @@ def test_data_migration_fix_legacy_datasets_large_previews( ) -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "data_migration_db,db_session", - [(dbs[0], dbs[0])], - indirect=["data_migration_db", "db_session"], -) def 
test_data_migration_fix_datasets_large_previews( data_migration_db: DBInterface, db_session: Session, diff --git a/tests/api/db/test_background_tasks.py b/tests/api/db/test_background_tasks.py index 5ae1dbc909b7..86ed054d79fe 100644 --- a/tests/api/db/test_background_tasks.py +++ b/tests/api/db/test_background_tasks.py @@ -18,15 +18,11 @@ from sqlalchemy.orm import Session import mlrun.api.initial_data +import mlrun.common.schemas import mlrun.errors -from mlrun.api import schemas from mlrun.api.db.base import DBInterface -from tests.api.db.conftest import dbs -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_store_project_background_task(db: DBInterface, db_session: Session): project = "test-project" db.store_background_task(db_session, "test", timeout=600, project=project) @@ -35,9 +31,6 @@ def test_store_project_background_task(db: DBInterface, db_session: Session): assert background_task.status.state == "running" -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_get_project_background_task_with_timeout_exceeded( db: DBInterface, db_session: Session ): @@ -50,9 +43,6 @@ def test_get_project_background_task_with_timeout_exceeded( assert background_task.status.state == "failed" -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_get_project_background_task_doesnt_exists( db: DBInterface, db_session: Session ): @@ -61,29 +51,33 @@ def test_get_project_background_task_doesnt_exists( db.get_background_task(db_session, "test", project=project) -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_store_project_background_task_after_status_updated( db: DBInterface, db_session: Session ): project = "test-project" db.store_background_task(db_session, "test", project=project) background_task = db.get_background_task(db_session, "test", 
project=project) - assert background_task.status.state == schemas.BackgroundTaskState.running + assert ( + background_task.status.state == mlrun.common.schemas.BackgroundTaskState.running + ) db.store_background_task( - db_session, "test", state=schemas.BackgroundTaskState.failed, project=project + db_session, + "test", + state=mlrun.common.schemas.BackgroundTaskState.failed, + project=project, ) background_task = db.get_background_task(db_session, "test", project=project) - assert background_task.status.state == schemas.BackgroundTaskState.failed + assert ( + background_task.status.state == mlrun.common.schemas.BackgroundTaskState.failed + ) # Expecting to fail with pytest.raises(mlrun.errors.MLRunRuntimeError): db.store_background_task( db_session, "test", - state=schemas.BackgroundTaskState.running, + state=mlrun.common.schemas.BackgroundTaskState.running, project=project, ) # expecting to fail, because terminal state is terminal which means it is not supposed to change @@ -91,18 +85,18 @@ def test_store_project_background_task_after_status_updated( db.store_background_task( db_session, "test", - state=schemas.BackgroundTaskState.succeeded, + state=mlrun.common.schemas.BackgroundTaskState.succeeded, project=project, ) db.store_background_task( - db_session, "test", state=schemas.BackgroundTaskState.failed, project=project + db_session, + "test", + state=mlrun.common.schemas.BackgroundTaskState.failed, + project=project, ) -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_get_project_background_task_with_disabled_timeout( db: DBInterface, db_session: Session ): @@ -118,25 +112,29 @@ def test_get_project_background_task_with_disabled_timeout( assert background_task.metadata.timeout is None # expecting created and updated time to be equal because mode disabled even if timeout exceeded assert background_task.metadata.created == background_task.metadata.updated - assert background_task.status.state == 
mlrun.api.schemas.BackgroundTaskState.running + assert ( + background_task.status.state == mlrun.common.schemas.BackgroundTaskState.running + ) task_name = "test1" db.store_background_task(db_session, name=task_name, project=project) # because timeout default mode is disabled, expecting not to enrich the background task timeout background_task = db.get_background_task(db_session, task_name, project) assert background_task.metadata.timeout is None assert background_task.metadata.created == background_task.metadata.updated - assert background_task.status.state == mlrun.api.schemas.BackgroundTaskState.running + assert ( + background_task.status.state == mlrun.common.schemas.BackgroundTaskState.running + ) db.store_background_task( db_session, name=task_name, project=project, - state=mlrun.api.schemas.BackgroundTaskState.succeeded, + state=mlrun.common.schemas.BackgroundTaskState.succeeded, ) background_task_new = db.get_background_task(db_session, task_name, project) assert ( background_task_new.status.state - == mlrun.api.schemas.BackgroundTaskState.succeeded + == mlrun.common.schemas.BackgroundTaskState.succeeded ) assert background_task_new.metadata.updated > background_task.metadata.updated assert background_task_new.metadata.created == background_task.metadata.created diff --git a/tests/api/db/test_feature_sets.py b/tests/api/db/test_feature_sets.py index f36d75776489..3c62bbef209b 100644 --- a/tests/api/db/test_feature_sets.py +++ b/tests/api/db/test_feature_sets.py @@ -16,12 +16,11 @@ import pytest from sqlalchemy.orm import Session +import mlrun.common.schemas import mlrun.feature_store as fstore import mlrun.utils.helpers from mlrun import errors -from mlrun.api import schemas from mlrun.api.db.base import DBInterface -from tests.api.db.conftest import dbs def _create_feature_set(name): @@ -59,17 +58,13 @@ def _create_feature_set(name): } -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - 
"db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_create_feature_set(db: DBInterface, db_session: Session): name = "dummy" feature_set = _create_feature_set(name) project = "proj-test" - feature_set = schemas.FeatureSet(**feature_set) + feature_set = mlrun.common.schemas.FeatureSet(**feature_set) db.store_feature_set( db_session, project, name, feature_set, tag="latest", versioned=True ) @@ -82,10 +77,6 @@ def test_create_feature_set(db: DBInterface, db_session: Session): assert len(features_res.features) == 1 -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_handle_feature_set_with_datetime_fields(db: DBInterface, db_session: Session): # Simulate a situation where a feature-set client-side object is created with datetime fields, and then stored to # DB. This may happen in API calls which utilize client-side objects (such as ingest). See ML-3552. 
@@ -95,20 +86,17 @@ def test_handle_feature_set_with_datetime_fields(db: DBInterface, db_session: Se # This object will have datetime in the spec.source object fields fs_object = fstore.FeatureSet.from_dict(feature_set) # Convert it to DB schema object (will still have datetime fields) - fs_server_object = schemas.FeatureSet(**fs_object.to_dict()) + fs_server_object = mlrun.common.schemas.FeatureSet(**fs_object.to_dict()) mlrun.utils.helpers.fill_object_hash(fs_server_object.dict(), "uid") -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_update_feature_set_labels(db: DBInterface, db_session: Session): name = "dummy" feature_set = _create_feature_set(name) project = "proj-test" - feature_set = schemas.FeatureSet(**feature_set) + feature_set = mlrun.common.schemas.FeatureSet(**feature_set) db.store_feature_set( db_session, project, name, feature_set, tag="latest", versioned=True ) @@ -159,16 +147,13 @@ def test_update_feature_set_labels(db: DBInterface, db_session: Session): assert updated_feature_set.metadata.labels == feature_set.metadata.labels -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_update_feature_set_by_uid(db: DBInterface, db_session: Session): name = "mock_feature_set" feature_set = _create_feature_set(name) project = "proj-test" - feature_set = schemas.FeatureSet(**feature_set) + feature_set = mlrun.common.schemas.FeatureSet(**feature_set) db.store_feature_set( db_session, project, name, feature_set, tag="latest", versioned=True ) diff --git a/tests/api/db/test_functions.py b/tests/api/db/test_functions.py index 93877664c00b..1f1e8bb15466 100644 --- a/tests/api/db/test_functions.py +++ b/tests/api/db/test_functions.py @@ -18,12 +18,8 @@ import mlrun.errors from mlrun.api.db.base import DBInterface from mlrun.api.db.sqldb.models import Function -from tests.api.db.conftest import dbs -@pytest.mark.parametrize( - "db,db_session", [(db, 
db) for db in dbs], indirect=["db", "db_session"] -) def test_store_function_default_to_latest(db: DBInterface, db_session: Session): function_1 = _generate_function() function_hash_key = db.store_function( @@ -43,9 +39,6 @@ def test_store_function_default_to_latest(db: DBInterface, db_session: Session): assert function_queried_without_tag_hash == function_queried_without_tag_hash -@pytest.mark.parametrize( - "db,db_session", [(db, db) for db in dbs], indirect=["db", "db_session"] -) def test_store_function_versioned(db: DBInterface, db_session: Session): function_1 = _generate_function() function_hash_key = db.store_function( @@ -84,9 +77,6 @@ def test_store_function_versioned(db: DBInterface, db_session: Session): assert tagged_count == 1 -@pytest.mark.parametrize( - "db,db_session", [(db, db) for db in dbs], indirect=["db", "db_session"] -) def test_store_function_not_versioned(db: DBInterface, db_session: Session): function_1 = _generate_function() function_hash_key = db.store_function( @@ -110,9 +100,6 @@ def test_store_function_not_versioned(db: DBInterface, db_session: Session): assert len(functions) == 1 -@pytest.mark.parametrize( - "db,db_session", [(db, db) for db in dbs], indirect=["db", "db_session"] -) def test_get_function_by_hash_key(db: DBInterface, db_session: Session): function_1 = _generate_function() function_hash_key = db.store_function( @@ -134,9 +121,36 @@ def test_get_function_by_hash_key(db: DBInterface, db_session: Session): assert function_queried_with_hash_key["metadata"]["tag"] == "" -@pytest.mark.parametrize( - "db,db_session", [(db, db) for db in dbs], indirect=["db", "db_session"] -) +def test_get_function_when_using_not_normalize_name( + db: DBInterface, db_session: Session +): + # add a function with a non-normalized name to the database + function_name = "function_name" + project_name = "project" + _generate_and_insert_function_record(db_session, function_name, project_name) + + # getting the function using the non-normalized 
name, and ensure that it works + response = db.get_function(db_session, function_name, project_name) + assert response["metadata"]["name"] == function_name + + +def _generate_and_insert_function_record( + db_session: Session, function_name: str, project_name: str +): + function = { + "metadata": {"name": function_name, "project": project_name}, + "spec": {"asd": "test"}, + } + fn = Function( + name=function_name, project=project_name, struct=function, uid="1", id="1" + ) + tag = Function.Tag(project=project_name, name="latest", obj_name=fn.name) + tag.obj_id, tag.uid = fn.id, fn.uid + db_session.add(fn) + db_session.add(tag) + db_session.commit() + + def test_get_function_by_tag(db: DBInterface, db_session: Session): function_1 = _generate_function() function_hash_key = db.store_function( @@ -149,9 +163,6 @@ def test_get_function_by_tag(db: DBInterface, db_session: Session): assert function_hash_key == function_not_queried_by_tag_hash -@pytest.mark.parametrize( - "db,db_session", [(db, db) for db in dbs], indirect=["db", "db_session"] -) def test_get_function_not_found(db: DBInterface, db_session: Session): function_1 = _generate_function() db.store_function( @@ -167,9 +178,6 @@ def test_get_function_not_found(db: DBInterface, db_session: Session): ) -@pytest.mark.parametrize( - "db,db_session", [(db, db) for db in dbs], indirect=["db", "db_session"] -) def test_list_functions_no_tags(db: DBInterface, db_session: Session): function_1 = {"bla": "blabla", "status": {"bla": "blabla"}} function_2 = {"bla2": "blabla", "status": {"bla": "blabla"}} @@ -195,9 +203,6 @@ def test_list_functions_no_tags(db: DBInterface, db_session: Session): assert function["status"] is None -@pytest.mark.parametrize( - "db,db_session", [(db, db) for db in dbs], indirect=["db", "db_session"] -) def test_list_functions_by_tag(db: DBInterface, db_session: Session): tag = "function_name_1" @@ -213,10 +218,6 @@ def test_list_functions_by_tag(db: DBInterface, db_session: Session): assert 
len(names) == 0 -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_list_functions_with_non_existent_tag(db: DBInterface, db_session: Session): names = ["some_name", "some_name2", "some_name3"] for name in names: @@ -226,9 +227,6 @@ def test_list_functions_with_non_existent_tag(db: DBInterface, db_session: Sessi assert len(functions) == 0 -@pytest.mark.parametrize( - "db,db_session", [(db, db) for db in dbs], indirect=["db", "db_session"] -) def test_list_functions_filtering_unversioned_untagged( db: DBInterface, db_session: Session ): @@ -257,10 +255,6 @@ def test_list_functions_filtering_unversioned_untagged( assert functions[0]["metadata"]["hash"] == tagged_function_hash_key -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_delete_function(db: DBInterface, db_session: Session): labels = { "name": "value", @@ -316,10 +310,6 @@ def test_delete_function(db: DBInterface, db_session: Session): assert number_of_labels == 0 -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) @pytest.mark.parametrize("use_hash_key", [True, False]) def test_list_functions_multiple_tags( db: DBInterface, db_session: Session, use_hash_key: bool @@ -349,10 +339,6 @@ def test_list_functions_multiple_tags( assert len(tags) == 0 -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_list_function_with_tag_and_uid(db: DBInterface, db_session: Session): tag_name = "some_tag" function_1 = _generate_function(tag=tag_name) diff --git 
a/tests/api/db/test_hub.py b/tests/api/db/test_hub.py new file mode 100644 index 000000000000..f03d2c426a74 --- /dev/null +++ b/tests/api/db/test_hub.py @@ -0,0 +1,55 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from sqlalchemy.orm import Session + +import mlrun.api.db.sqldb.models +import mlrun.api.initial_data +from mlrun.api.db.base import DBInterface + + +def test_data_migration_rename_marketplace_kind_to_hub( + db: DBInterface, db_session: Session +): + # create hub sources + for i in range(3): + source_name = f"source-{i}" + source_dict = { + "metadata": { + "name": source_name, + }, + "spec": { + "path": "/local/path/to/source", + }, + "kind": "MarketplaceSource", + } + # id and index are multiplied by 2 to avoid sqlalchemy unique constraint error + source = mlrun.api.db.sqldb.models.HubSource( + id=i * 2, + name=source_name, + index=i * 2, + ) + source.full_object = source_dict + db_session.add(source) + db_session.commit() + + # run migration + mlrun.api.initial_data._rename_marketplace_kind_to_hub(db, db_session) + + # check that all hub sources are now of kind 'HubSource' + hubs = db._list_hub_sources_without_transform(db_session) + for hub in hubs: + hub_dict = hub.full_object + assert "kind" in hub_dict + assert hub_dict["kind"] == "HubSource" diff --git a/tests/api/db/test_projects.py b/tests/api/db/test_projects.py index be4b977ecf03..27a443d5fced 100644 --- a/tests/api/db/test_projects.py +++ 
b/tests/api/db/test_projects.py @@ -20,19 +20,14 @@ import sqlalchemy.orm import mlrun.api.initial_data -import mlrun.api.schemas import mlrun.api.utils.singletons.db +import mlrun.common.schemas import mlrun.config import mlrun.errors from mlrun.api.db.base import DBInterface from mlrun.api.db.sqldb.models import Project -from tests.api.db.conftest import dbs -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_get_project( db: DBInterface, db_session: sqlalchemy.orm.Session, @@ -44,11 +39,11 @@ def test_get_project( } db.create_project( db_session, - mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata( + mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata( name=project_name, labels=project_labels ), - spec=mlrun.api.schemas.ProjectSpec(description=project_description), + spec=mlrun.common.schemas.ProjectSpec(description=project_description), ), ) @@ -65,10 +60,6 @@ def test_get_project( ) -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_get_project_with_pre_060_record( db: DBInterface, db_session: sqlalchemy.orm.Session, @@ -91,10 +82,6 @@ def test_get_project_with_pre_060_record( assert updated_record.full_object is not None -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_data_migration_enrich_project_state( db: DBInterface, db_session: sqlalchemy.orm.Session, @@ -105,12 +92,12 @@ def test_data_migration_enrich_project_state( projects = db.list_projects(db_session) for project in projects.projects: # getting default value from the schema - assert project.spec.desired_state == 
mlrun.api.schemas.ProjectState.online + assert project.spec.desired_state == mlrun.common.schemas.ProjectState.online assert project.status.state is None mlrun.api.initial_data._enrich_project_state(db, db_session) projects = db.list_projects(db_session) for project in projects.projects: - assert project.spec.desired_state == mlrun.api.schemas.ProjectState.online + assert project.spec.desired_state == mlrun.common.schemas.ProjectState.online assert project.status.state == project.spec.desired_state # verify not storing for no reason db.store_project = unittest.mock.Mock() @@ -126,10 +113,6 @@ def _generate_and_insert_pre_060_record( db_session.commit() -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_list_project( db: DBInterface, db_session: sqlalchemy.orm.Session, @@ -147,11 +130,11 @@ def test_list_project( for project in expected_projects: db.create_project( db_session, - mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata( + mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata( name=project["name"], labels=project.get("labels") ), - spec=mlrun.api.schemas.ProjectSpec( + spec=mlrun.common.schemas.ProjectSpec( description=project.get("description") ), ), @@ -170,10 +153,43 @@ def test_list_project( ) -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) +def test_list_project_minimal( + db: DBInterface, + db_session: sqlalchemy.orm.Session, +): + expected_projects = ["project-name-1", "project-name-2", "project-name-3"] + for project in expected_projects: + db.create_project( + db_session, + mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata( + name=project, + ), + spec=mlrun.common.schemas.ProjectSpec( + 
description="some-proj", + artifacts=[{"key": "value"}], + workflows=[{"key": "value"}], + functions=[{"key": "value"}], + ), + ), + ) + projects_output = db.list_projects( + db_session, format_=mlrun.common.schemas.ProjectsFormat.minimal + ) + for index, project in enumerate(projects_output.projects): + assert project.metadata.name == expected_projects[index] + assert project.spec.artifacts is None + assert project.spec.workflows is None + assert project.spec.functions is None + + projects_output = db.list_projects(db_session) + for index, project in enumerate(projects_output.projects): + assert project.metadata.name == expected_projects[index] + assert project.spec.artifacts == [{"key": "value"}] + assert project.spec.workflows == [{"key": "value"}] + assert project.spec.functions == [{"key": "value"}] + + def test_list_project_names_filter( db: DBInterface, db_session: sqlalchemy.orm.Session, @@ -183,14 +199,14 @@ def test_list_project_names_filter( for project in project_names: db.create_project( db_session, - mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project), + mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project), ), ) filter_names = [project_names[0], project_names[3], project_names[4]] projects_output = db.list_projects( db_session, - format_=mlrun.api.schemas.ProjectsFormat.name_only, + format_=mlrun.common.schemas.ProjectsFormat.name_only, names=filter_names, ) @@ -205,17 +221,13 @@ def test_list_project_names_filter( projects_output = db.list_projects( db_session, - format_=mlrun.api.schemas.ProjectsFormat.name_only, + format_=mlrun.common.schemas.ProjectsFormat.name_only, names=[], ) assert projects_output.projects == [] -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_create_project( db: DBInterface, db_session: 
sqlalchemy.orm.Session, @@ -228,10 +240,6 @@ def test_create_project( _assert_project(db, db_session, project) -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_store_project_creation( db: DBInterface, db_session: sqlalchemy.orm.Session, @@ -245,10 +253,6 @@ def test_store_project_creation( _assert_project(db, db_session, project) -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_store_project_update( db: DBInterface, db_session: sqlalchemy.orm.Session, @@ -262,8 +266,8 @@ def test_store_project_update( db.store_project( db_session, project.metadata.name, - mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project.metadata.name), + mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project.metadata.name), ), ) project_output = db.get_project(db_session, project.metadata.name) @@ -274,10 +278,6 @@ def test_store_project_update( assert project_output.metadata.created != project.metadata.created -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_patch_project( db: DBInterface, db_session: sqlalchemy.orm.Session, @@ -318,10 +318,6 @@ def test_patch_project( ) -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_delete_project( db: DBInterface, db_session: sqlalchemy.orm.Session, @@ -330,9 +326,9 @@ def test_delete_project( project_description = "some description" db.create_project( db_session, - mlrun.api.schemas.Project( - 
metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), - spec=mlrun.api.schemas.ProjectSpec(description=project_description), + mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), + spec=mlrun.common.schemas.ProjectSpec(description=project_description), ), ) db.delete_project(db_session, project_name) @@ -342,15 +338,15 @@ def test_delete_project( def _generate_project(): - return mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata( + return mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata( name="project-name", created=datetime.datetime.utcnow() - datetime.timedelta(seconds=1), labels={ "some-label": "some-label-value", }, ), - spec=mlrun.api.schemas.ProjectSpec( + spec=mlrun.common.schemas.ProjectSpec( description="some description", owner="owner-name" ), ) @@ -359,7 +355,7 @@ def _generate_project(): def _assert_project( db: DBInterface, db_session: sqlalchemy.orm.Session, - expected_project: mlrun.api.schemas.Project, + expected_project: mlrun.common.schemas.Project, ): project_output = db.get_project(db_session, expected_project.metadata.name) assert project_output.metadata.name == expected_project.metadata.name diff --git a/tests/api/db/test_runs.py b/tests/api/db/test_runs.py index fe19706cf29f..3685482fdff3 100644 --- a/tests/api/db/test_runs.py +++ b/tests/api/db/test_runs.py @@ -20,13 +20,8 @@ import mlrun.api.db.sqldb.helpers import mlrun.api.initial_data from mlrun.api.db.base import DBInterface -from tests.api.db.conftest import dbs -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_list_runs_name_filter(db: DBInterface, db_session: Session): project = "project" run_name_1 = "run_name_1" @@ -57,10 +52,48 @@ def test_list_runs_name_filter(db: DBInterface, db_session: Session): assert len(runs) == 2 -# running 
only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) +def test_runs_with_notifications(db: DBInterface, db_session: Session): + project_name = "project" + run_uids = ["uid1", "uid2", "uid3"] + num_runs = len(run_uids) + # create several runs with different uids, each with a notification + for run_uid in run_uids: + _create_new_run(db, db_session, project=project_name, uid=run_uid) + notification = mlrun.model.Notification( + kind="slack", + when=["completed", "error"], + name=f"test-notification-{run_uid}", + message="test-message", + condition="blabla", + severity="info", + params={"some-param": "some-value"}, + ) + db.store_run_notifications(db_session, [notification], run_uid, project_name) + + runs = db.list_runs(db_session, project=project_name, with_notifications=True) + assert len(runs) == num_runs + for run in runs: + run_notifications = run["spec"]["notifications"] + assert len(run_notifications) == 1 + assert ( + run_notifications[0]["name"] + == f"test-notification-{run['metadata']['uid']}" + ) + + db.delete_run_notifications(db_session, run_uid=run_uids[0], project=project_name) + runs = db.list_runs(db_session, project=project_name, with_notifications=True) + assert len(runs) == num_runs - 1 + + db.delete_run_notifications(db_session, project=project_name) + runs = db.list_runs(db_session, project=project_name, with_notifications=False) + assert len(runs) == num_runs + runs = db.list_runs(db_session, project=project_name, with_notifications=True) + assert len(runs) == 0 + + db.del_runs(db_session, project=project_name) + db.verify_project_has_no_related_resources(db_session, project_name) + + def test_list_distinct_runs_uids(db: DBInterface, db_session: Session): project_name = "project" uid = "run-uid" @@ -108,10 +141,6 @@ def test_list_distinct_runs_uids(db: DBInterface, db_session: Session): assert type(distinct_runs[0]) == 
dict -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_list_runs_state_filter(db: DBInterface, db_session: Session): project = "project" run_uid_running = "run-running" @@ -148,10 +177,6 @@ def test_list_runs_state_filter(db: DBInterface, db_session: Session): assert runs[0]["metadata"]["uid"] == run_uid_completed -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_store_run_overriding_start_time(db: DBInterface, db_session: Session): # First store - fills the start_time project, name, uid, iteration, run = _create_new_run(db, db_session) @@ -178,10 +203,6 @@ def test_store_run_overriding_start_time(db: DBInterface, db_session: Session): assert runs[0].struct["status"]["start_time"] == run["status"]["start_time"] -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_data_migration_align_runs_table(db: DBInterface, db_session: Session): time_before_creation = datetime.now(tz=timezone.utc) # Create runs @@ -214,10 +235,6 @@ def test_data_migration_align_runs_table(db: DBInterface, db_session: Session): _ensure_run_after_align_runs_migration(db, run, time_before_creation) -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_data_migration_align_runs_table_with_empty_run_body( db: DBInterface, db_session: Session ): @@ -244,10 +261,6 @@ def test_data_migration_align_runs_table_with_empty_run_body( _ensure_run_after_align_runs_migration(db, run) -# running only on sqldb cause filedb is not really a 
thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_store_run_success(db: DBInterface, db_session: Session): project, name, uid, iteration, run = _create_new_run(db, db_session) @@ -272,10 +285,6 @@ def test_store_run_success(db: DBInterface, db_session: Session): ) -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_update_runs_requested_logs(db: DBInterface, db_session: Session): project, name, uid, iteration, run = _create_new_run(db, db_session) @@ -294,10 +303,6 @@ def test_update_runs_requested_logs(db: DBInterface, db_session: Session): assert runs_after[0].updated > run_updated_time -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_update_run_success(db: DBInterface, db_session: Session): project, name, uid, iteration, run = _create_new_run(db, db_session) @@ -315,10 +320,6 @@ def test_update_run_success(db: DBInterface, db_session: Session): assert run["spec"]["another-new-field"] == "value" -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_store_and_update_run_update_name_failure(db: DBInterface, db_session: Session): project, name, uid, iteration, run = _create_new_run(db, db_session) @@ -348,10 +349,6 @@ def test_store_and_update_run_update_name_failure(db: DBInterface, db_session: S ) -# running only on sqldb cause filedb is not really a thing anymore, will be removed soon -@pytest.mark.parametrize( - "db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"] -) def test_list_runs_limited_unsorted_failure(db: 
DBInterface, db_session: Session): with pytest.raises( mlrun.errors.MLRunInvalidArgumentError, diff --git a/tests/api/runtime_handlers/base.py b/tests/api/runtime_handlers/base.py index 231d3c42e26e..a3584c3d8b1d 100644 --- a/tests/api/runtime_handlers/base.py +++ b/tests/api/runtime_handlers/base.py @@ -25,11 +25,11 @@ import mlrun import mlrun.api.crud as crud -import mlrun.api.schemas +import mlrun.common.schemas import mlrun.runtimes.constants from mlrun.api.constants import LogSources from mlrun.api.utils.singletons.db import get_db -from mlrun.api.utils.singletons.k8s import get_k8s +from mlrun.api.utils.singletons.k8s import get_k8s_helper from mlrun.runtimes import get_runtime_handler from mlrun.runtimes.constants import PodPhases, RunStates from mlrun.utils import create_logger, now_date @@ -80,9 +80,9 @@ def setup_method_fixture(self, db: Session, client: fastapi.testclient.TestClien # We want this mock for every test, ideally we would have simply put it in the setup_method # but it is happening before the fixtures initialization. 
We need the client fixture (which needs the db one) # in order to be able to mock k8s stuff - get_k8s().v1api = unittest.mock.Mock() - get_k8s().crdapi = unittest.mock.Mock() - get_k8s().is_running_inside_kubernetes_cluster = unittest.mock.Mock( + get_k8s_helper().v1api = unittest.mock.Mock() + get_k8s_helper().crdapi = unittest.mock.Mock() + get_k8s_helper().is_running_inside_kubernetes_cluster = unittest.mock.Mock( return_value=True ) # enable inheriting classes to do the same @@ -124,7 +124,7 @@ def _generate_pod(name, labels, phase=PodPhases.succeeded): ) status = client.V1PodStatus(phase=phase, container_statuses=[container_status]) metadata = client.V1ObjectMeta( - name=name, labels=labels, namespace=get_k8s().resolve_namespace() + name=name, labels=labels, namespace=get_k8s_helper().resolve_namespace() ) pod = client.V1Pod(metadata=metadata, status=status) return pod @@ -132,7 +132,7 @@ def _generate_pod(name, labels, phase=PodPhases.succeeded): @staticmethod def _generate_config_map(name, labels, data=None): metadata = client.V1ObjectMeta( - name=name, labels=labels, namespace=get_k8s().resolve_namespace() + name=name, labels=labels, namespace=get_k8s_helper().resolve_namespace() ) if data is None: data = {"key": "value"} @@ -150,14 +150,16 @@ def _assert_runtime_handler_list_resources( expected_crds=None, expected_pods=None, expected_services=None, - group_by: Optional[mlrun.api.schemas.ListRuntimeResourcesGroupByField] = None, + group_by: Optional[ + mlrun.common.schemas.ListRuntimeResourcesGroupByField + ] = None, ): runtime_handler = get_runtime_handler(runtime_kind) if group_by is None: project = "*" label_selector = runtime_handler._get_default_label_selector() assertion_func = TestRuntimeHandlerBase._assert_list_resources_response - elif group_by == mlrun.api.schemas.ListRuntimeResourcesGroupByField.job: + elif group_by == mlrun.common.schemas.ListRuntimeResourcesGroupByField.job: project = self.project label_selector = ",".join( [ @@ -168,7 +170,7 
@@ def _assert_runtime_handler_list_resources( assertion_func = ( TestRuntimeHandlerBase._assert_list_resources_grouped_by_job_response ) - elif group_by == mlrun.api.schemas.ListRuntimeResourcesGroupByField.project: + elif group_by == mlrun.common.schemas.ListRuntimeResourcesGroupByField.project: project = self.project label_selector = ",".join( [ @@ -183,21 +185,21 @@ def _assert_runtime_handler_list_resources( raise NotImplementedError("Unsupported group by value") resources = runtime_handler.list_resources(project, group_by=group_by) crd_group, crd_version, crd_plural = runtime_handler._get_crd_info() - get_k8s().v1api.list_namespaced_pod.assert_called_once_with( - get_k8s().resolve_namespace(), + get_k8s_helper().v1api.list_namespaced_pod.assert_called_once_with( + get_k8s_helper().resolve_namespace(), label_selector=label_selector, ) if expected_crds: - get_k8s().crdapi.list_namespaced_custom_object.assert_called_once_with( + get_k8s_helper().crdapi.list_namespaced_custom_object.assert_called_once_with( crd_group, crd_version, - get_k8s().resolve_namespace(), + get_k8s_helper().resolve_namespace(), crd_plural, label_selector=label_selector, ) if expected_services: - get_k8s().v1api.list_namespaced_service.assert_called_once_with( - get_k8s().resolve_namespace(), + get_k8s_helper().v1api.list_namespaced_service.assert_called_once_with( + get_k8s_helper().resolve_namespace(), label_selector=label_selector, ) assertion_func( @@ -213,7 +215,7 @@ def _assert_runtime_handler_list_resources( def _assert_list_resources_grouped_by_job_response( self, - resources: mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, + resources: mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput, expected_crds=None, expected_pods=None, expected_services=None, @@ -229,7 +231,7 @@ def _assert_list_resources_grouped_by_job_response( def _assert_list_resources_grouped_by_project_response( self, - resources: mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, + resources: 
mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput, expected_crds=None, expected_pods=None, expected_services=None, @@ -253,7 +255,7 @@ def _extract_project_and_kind_from_runtime_resources_labels( def _assert_list_resources_grouped_by_response( self, - resources: mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, + resources: mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput, group_by_field_extractor, expected_crds=None, expected_pods=None, @@ -285,7 +287,7 @@ def _assert_list_resources_grouped_by_response( def _assert_resource_in_response_resources( expected_resource_type: str, expected_resource: dict, - resources: mlrun.api.schemas.GroupedByJobRuntimeResourcesOutput, + resources: mlrun.common.schemas.GroupedByJobRuntimeResourcesOutput, resources_field_name: str, group_by_field_extractor, ): @@ -323,7 +325,7 @@ def _assert_resource_in_response_resources( def _assert_list_resources_response( self, - resources: mlrun.api.schemas.RuntimeResources, + resources: mlrun.common.schemas.RuntimeResources, expected_crds=None, expected_pods=None, expected_services=None, @@ -359,7 +361,9 @@ def _mock_list_namespaced_pods(list_pods_call_responses: List[List[client.V1Pod] for list_pods_call_response in list_pods_call_responses: pods = client.V1PodList(items=list_pods_call_response) calls.append(pods) - get_k8s().v1api.list_namespaced_pod = unittest.mock.Mock(side_effect=calls) + get_k8s_helper().v1api.list_namespaced_pod = unittest.mock.Mock( + side_effect=calls + ) return calls @staticmethod @@ -376,9 +380,9 @@ def _assert_delete_namespaced_pods( for expected_pod_name in expected_pod_names ] if not expected_pod_names: - assert get_k8s().v1api.delete_namespaced_pod.call_count == 0 + assert get_k8s_helper().v1api.delete_namespaced_pod.call_count == 0 else: - get_k8s().v1api.delete_namespaced_pod.assert_has_calls(calls) + get_k8s_helper().v1api.delete_namespaced_pod.assert_has_calls(calls) @staticmethod def _assert_delete_namespaced_services( @@ -389,9 +393,9 @@ def 
_assert_delete_namespaced_services( for expected_service_name in expected_service_names ] if not expected_service_names: - assert get_k8s().v1api.delete_namespaced_service.call_count == 0 + assert get_k8s_helper().v1api.delete_namespaced_service.call_count == 0 else: - get_k8s().v1api.delete_namespaced_service.assert_has_calls(calls) + get_k8s_helper().v1api.delete_namespaced_service.assert_has_calls(calls) @staticmethod def _assert_delete_namespaced_custom_objects( @@ -411,26 +415,32 @@ def _assert_delete_namespaced_custom_objects( for expected_custom_object_name in expected_custom_object_names ] if not expected_custom_object_names: - assert get_k8s().crdapi.delete_namespaced_custom_object.call_count == 0 + assert ( + get_k8s_helper().crdapi.delete_namespaced_custom_object.call_count == 0 + ) else: - get_k8s().crdapi.delete_namespaced_custom_object.assert_has_calls(calls) + get_k8s_helper().crdapi.delete_namespaced_custom_object.assert_has_calls( + calls + ) @staticmethod def _mock_delete_namespaced_pods(): - get_k8s().v1api.delete_namespaced_pod = unittest.mock.Mock() + get_k8s_helper().v1api.delete_namespaced_pod = unittest.mock.Mock() @staticmethod def _mock_delete_namespaced_custom_objects(): - get_k8s().crdapi.delete_namespaced_custom_object = unittest.mock.Mock() + get_k8s_helper().crdapi.delete_namespaced_custom_object = unittest.mock.Mock() @staticmethod def _mock_delete_namespaced_services(): - get_k8s().v1api.delete_namespaced_service = unittest.mock.Mock() + get_k8s_helper().v1api.delete_namespaced_service = unittest.mock.Mock() @staticmethod def _mock_read_namespaced_pod_log(): log = "Some log string" - get_k8s().v1api.read_namespaced_pod_log = unittest.mock.Mock(return_value=log) + get_k8s_helper().v1api.read_namespaced_pod_log = unittest.mock.Mock( + return_value=log + ) return log @staticmethod @@ -438,7 +448,7 @@ def _mock_list_namespaced_crds(crd_dicts_call_responses: List[List[Dict]]): calls = [] for crd_dicts_call_response in 
crd_dicts_call_responses: calls.append({"items": crd_dicts_call_response}) - get_k8s().crdapi.list_namespaced_custom_object = unittest.mock.Mock( + get_k8s_helper().crdapi.list_namespaced_custom_object = unittest.mock.Mock( side_effect=calls ) return calls @@ -446,7 +456,7 @@ def _mock_list_namespaced_crds(crd_dicts_call_responses: List[List[Dict]]): @staticmethod def _mock_list_namespaced_config_map(config_maps): config_maps_list = client.V1ConfigMapList(items=config_maps) - get_k8s().v1api.list_namespaced_config_map = unittest.mock.Mock( + get_k8s_helper().v1api.list_namespaced_config_map = unittest.mock.Mock( return_value=config_maps_list ) return config_maps @@ -454,7 +464,7 @@ def _mock_list_namespaced_config_map(config_maps): @staticmethod def _mock_list_services(services): services_list = client.V1ServiceList(items=services) - get_k8s().v1api.list_namespaced_service = unittest.mock.Mock( + get_k8s_helper().v1api.list_namespaced_service = unittest.mock.Mock( return_value=services_list ) return services @@ -466,13 +476,15 @@ def _assert_list_namespaced_pods_calls( expected_label_selector: str = None, ): assert ( - get_k8s().v1api.list_namespaced_pod.call_count == expected_number_of_calls + get_k8s_helper().v1api.list_namespaced_pod.call_count + == expected_number_of_calls ) expected_label_selector = ( expected_label_selector or runtime_handler._get_default_label_selector() ) - get_k8s().v1api.list_namespaced_pod.assert_any_call( - get_k8s().resolve_namespace(), label_selector=expected_label_selector + get_k8s_helper().v1api.list_namespaced_pod.assert_any_call( + get_k8s_helper().resolve_namespace(), + label_selector=expected_label_selector, ) @staticmethod @@ -481,13 +493,13 @@ def _assert_list_namespaced_crds_calls( ): crd_group, crd_version, crd_plural = runtime_handler._get_crd_info() assert ( - get_k8s().crdapi.list_namespaced_custom_object.call_count + get_k8s_helper().crdapi.list_namespaced_custom_object.call_count == expected_number_of_calls ) - 
get_k8s().crdapi.list_namespaced_custom_object.assert_any_call( + get_k8s_helper().crdapi.list_namespaced_custom_object.assert_any_call( crd_group, crd_version, - get_k8s().resolve_namespace(), + get_k8s_helper().resolve_namespace(), crd_plural, label_selector=runtime_handler._get_default_label_selector(), ) @@ -501,9 +513,9 @@ async def _assert_run_logs( logger_pod_name: str = None, ): if logger_pod_name is not None: - get_k8s().v1api.read_namespaced_pod_log.assert_called_once_with( + get_k8s_helper().v1api.read_namespaced_pod_log.assert_called_once_with( name=logger_pod_name, - namespace=get_k8s().resolve_namespace(), + namespace=get_k8s_helper().resolve_namespace(), ) _, logs = await crud.Logs().get_logs( db, project, uid, source=LogSources.PERSISTENCY diff --git a/tests/api/runtime_handlers/test_daskjob.py b/tests/api/runtime_handlers/test_daskjob.py index d62dad563814..ec3aa3b1d87d 100644 --- a/tests/api/runtime_handlers/test_daskjob.py +++ b/tests/api/runtime_handlers/test_daskjob.py @@ -16,7 +16,7 @@ from kubernetes import client from sqlalchemy.orm import Session -import mlrun.api.schemas +import mlrun.common.schemas from mlrun.api.utils.singletons.db import get_db from mlrun.runtimes import RuntimeKinds, get_runtime_handler from mlrun.runtimes.constants import PodPhases @@ -104,7 +104,7 @@ def test_list_resources(self, db: Session, client: TestClient): def test_list_resources_grouped_by(self, db: Session, client: TestClient): for group_by in [ - mlrun.api.schemas.ListRuntimeResourcesGroupByField.project, + mlrun.common.schemas.ListRuntimeResourcesGroupByField.project, ]: pods = self._mock_list_resources_pods() services = self._mock_list_services([self.cluster_service]) @@ -127,7 +127,7 @@ def test_build_output_from_runtime_resources(self, db: Session, client: TestClie runtime_handler = get_runtime_handler(RuntimeKinds.dask) resources = runtime_handler.list_resources( self.project, - group_by=mlrun.api.schemas.ListRuntimeResourcesGroupByField.project, + 
group_by=mlrun.common.schemas.ListRuntimeResourcesGroupByField.project, ) runtime_handler.build_output_from_runtime_resources( [resources[self.project][RuntimeKinds.dask]] diff --git a/tests/api/runtime_handlers/test_kubejob.py b/tests/api/runtime_handlers/test_kubejob.py index 2124a7258878..a1cd686348a1 100644 --- a/tests/api/runtime_handlers/test_kubejob.py +++ b/tests/api/runtime_handlers/test_kubejob.py @@ -20,7 +20,7 @@ from sqlalchemy.orm import Session import mlrun.api.crud -import mlrun.api.schemas +import mlrun.common.schemas import tests.conftest from mlrun.api.utils.singletons.db import get_db from mlrun.config import config @@ -81,8 +81,8 @@ def test_list_resources(self, db: Session, client: TestClient): def test_list_resources_grouped_by(self, db: Session, client: TestClient): for group_by in [ - mlrun.api.schemas.ListRuntimeResourcesGroupByField.job, - mlrun.api.schemas.ListRuntimeResourcesGroupByField.project, + mlrun.common.schemas.ListRuntimeResourcesGroupByField.job, + mlrun.common.schemas.ListRuntimeResourcesGroupByField.project, ]: pods = self._mock_list_resources_pods() self._assert_runtime_handler_list_resources( @@ -98,7 +98,7 @@ def test_list_resources_grouped_by_project_with_non_project_resources( resources = self._assert_runtime_handler_list_resources( RuntimeKinds.job, expected_pods=pods, - group_by=mlrun.api.schemas.ListRuntimeResourcesGroupByField.project, + group_by=mlrun.common.schemas.ListRuntimeResourcesGroupByField.project, ) # the legacy builder pod does not have a project label, verify it is listed under the empty key # so it will be removed on cleanup diff --git a/tests/api/runtime_handlers/test_mpijob.py b/tests/api/runtime_handlers/test_mpijob.py index 2269ab34666c..1758d95b36c4 100644 --- a/tests/api/runtime_handlers/test_mpijob.py +++ b/tests/api/runtime_handlers/test_mpijob.py @@ -18,9 +18,9 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session -import mlrun.api.schemas +import mlrun.common.schemas 
from mlrun.api.utils.singletons.db import get_db -from mlrun.api.utils.singletons.k8s import get_k8s +from mlrun.api.utils.singletons.k8s import get_k8s_helper from mlrun.runtimes import RuntimeKinds, get_runtime_handler from mlrun.runtimes.constants import PodPhases, RunStates from tests.api.runtime_handlers.base import TestRuntimeHandlerBase @@ -123,8 +123,8 @@ def test_list_resources_with_crds_without_status( def test_list_resources_grouped_by_job(self, db: Session, client: TestClient): for group_by in [ - mlrun.api.schemas.ListRuntimeResourcesGroupByField.job, - mlrun.api.schemas.ListRuntimeResourcesGroupByField.project, + mlrun.common.schemas.ListRuntimeResourcesGroupByField.job, + mlrun.common.schemas.ListRuntimeResourcesGroupByField.project, ]: mocked_responses = self._mock_list_namespaced_crds( [[self.succeeded_crd_dict]] @@ -362,7 +362,7 @@ def _generate_mpijob_crd(project, uid, status=None): crd_dict = { "metadata": { "name": "train-eaf63df8", - "namespace": get_k8s().resolve_namespace(), + "namespace": get_k8s_helper().resolve_namespace(), "labels": { "mlrun/class": "mpijob", "mlrun/function": "trainer", diff --git a/tests/api/runtime_handlers/test_sparkjob.py b/tests/api/runtime_handlers/test_sparkjob.py index fae1b483c06a..aca0a6a8e627 100644 --- a/tests/api/runtime_handlers/test_sparkjob.py +++ b/tests/api/runtime_handlers/test_sparkjob.py @@ -18,9 +18,9 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session -import mlrun.api.schemas +import mlrun.common.schemas from mlrun.api.utils.singletons.db import get_db -from mlrun.api.utils.singletons.k8s import get_k8s +from mlrun.api.utils.singletons.k8s import get_k8s_helper from mlrun.runtimes import RuntimeKinds, get_runtime_handler from mlrun.runtimes.constants import PodPhases, RunStates from tests.api.runtime_handlers.base import TestRuntimeHandlerBase @@ -115,8 +115,8 @@ def test_list_resources(self, db: Session, client: TestClient): def 
test_list_resources_grouped_by_job(self, db: Session, client: TestClient): for group_by in [ - mlrun.api.schemas.ListRuntimeResourcesGroupByField.job, - mlrun.api.schemas.ListRuntimeResourcesGroupByField.project, + mlrun.common.schemas.ListRuntimeResourcesGroupByField.job, + mlrun.common.schemas.ListRuntimeResourcesGroupByField.project, ]: mocked_responses = self._mock_list_namespaced_crds( [[self.completed_crd_dict]] @@ -360,7 +360,7 @@ def _generate_sparkjob_crd(project, uid, status=None): crd_dict = { "metadata": { "name": "my-spark-jdbc-2ea432f1", - "namespace": get_k8s().resolve_namespace(), + "namespace": get_k8s_helper().resolve_namespace(), "labels": { "mlrun/class": "spark", "mlrun/function": "my-spark-jdbc", diff --git a/tests/api/runtimes/base.py b/tests/api/runtimes/base.py index 1ae46684a4e7..715e30149e2c 100644 --- a/tests/api/runtimes/base.py +++ b/tests/api/runtimes/base.py @@ -30,16 +30,18 @@ from kubernetes import client as k8s_client from kubernetes.client import V1EnvVar -import mlrun.api.schemas +import mlrun.api.api.endpoints.functions +import mlrun.api.crud +import mlrun.common.schemas import mlrun.k8s_utils import mlrun.runtimes.pod -from mlrun.api.utils.singletons.k8s import get_k8s +import tests.api.api.utils +from mlrun.api.utils.singletons.k8s import get_k8s_helper from mlrun.config import config as mlconf from mlrun.model import new_task from mlrun.runtimes.constants import PodPhases from mlrun.utils import create_logger from mlrun.utils.azure_vault import AzureVaultStore -from mlrun.utils.vault import VaultStore logger = create_logger(level="debug", name="test-runtime") @@ -47,7 +49,7 @@ class TestRuntimeBase: def setup_method(self, method): self.namespace = mlconf.namespace = "test-namespace" - get_k8s().namespace = self.namespace + get_k8s_helper().namespace = self.namespace # set auto-mount to work as if this is an Iguazio system (otherwise it may try to mount PVC) mlconf.igz_version = "1.1.1" @@ -65,8 +67,9 @@ def 
setup_method(self, method): self.requirements_file = str(self.assets_path / "requirements.txt") self.vault_secrets = ["secret1", "secret2", "AWS_KEY"] - self.vault_secret_value = "secret123!@" - self.vault_secret_name = "vault-secret" + # TODO: Vault: uncomment when vault returns to be relevant + # self.vault_secret_value = "secret123!@" + # self.vault_secret_name = "vault-secret" self.azure_vault_secrets = ["azure_secret1", "azure_secret2"] self.azure_secret_value = "azure-secret-123!@" @@ -91,12 +94,13 @@ def setup_method_fixture( # We want this mock for every test, ideally we would have simply put it in the setup_method # but it is happening before the fixtures initialization. We need the client fixture (which needs the db one) # in order to be able to mock k8s stuff - get_k8s().get_project_secret_keys = unittest.mock.Mock(return_value=[]) - get_k8s().v1api = unittest.mock.Mock() - get_k8s().crdapi = unittest.mock.Mock() - get_k8s().is_running_inside_kubernetes_cluster = unittest.mock.Mock( + get_k8s_helper().get_project_secret_keys = unittest.mock.Mock(return_value=[]) + get_k8s_helper().v1api = unittest.mock.Mock() + get_k8s_helper().crdapi = unittest.mock.Mock() + get_k8s_helper().is_running_inside_kubernetes_cluster = unittest.mock.Mock( return_value=True ) + self._create_project(client) # enable inheriting classes to do the same self.custom_setup_after_fixtures() @@ -142,6 +146,11 @@ def custom_setup_after_fixtures(self): def custom_teardown(self): pass + def _create_project( + self, client: fastapi.testclient.TestClient, project_name: str = None + ): + tests.api.api.utils.create_project(client, project_name or self.project) + def _generate_task(self): return new_task( name=self.name, project=self.project, artifact_path=self.artifact_path @@ -328,7 +337,7 @@ def _generate_pod(namespace, pod): response_pod.metadata.namespace = namespace return response_pod - get_k8s().v1api.create_namespaced_pod = unittest.mock.Mock( + 
get_k8s_helper().v1api.create_namespaced_pod = unittest.mock.Mock( side_effect=_generate_pod ) @@ -336,10 +345,10 @@ def _generate_pod(namespace, pod): def _mock_get_logger_pods(self): # Our purpose is not to test the client watching on logs, mock empty list (used in get_logger_pods) - get_k8s().v1api.list_namespaced_pod = unittest.mock.Mock( + get_k8s_helper().v1api.list_namespaced_pod = unittest.mock.Mock( return_value=client.V1PodList(items=[]) ) - get_k8s().v1api.read_namespaced_pod_log = unittest.mock.Mock( + get_k8s_helper().v1api.read_namespaced_pod_log = unittest.mock.Mock( return_value="Mocked pod logs" ) @@ -354,15 +363,16 @@ def _generate_custom_object( ): return deepcopy(body) - get_k8s().crdapi.create_namespaced_custom_object = unittest.mock.Mock( + get_k8s_helper().crdapi.create_namespaced_custom_object = unittest.mock.Mock( side_effect=_generate_custom_object ) self._mock_get_logger_pods() # Vault now supported in KubeJob and Serving, so moved to base. def _mock_vault_functionality(self): - secret_dict = {key: self.vault_secret_value for key in self.vault_secrets} - VaultStore.get_secrets = unittest.mock.Mock(return_value=secret_dict) + # TODO: Vault: uncomment when vault returns to be relevant + # secret_dict = {key: self.vault_secret_value for key in self.vault_secrets} + # VaultStore.get_secrets = unittest.mock.Mock(return_value=secret_dict) azure_secret_dict = { key: self.azure_secret_value for key in self.azure_vault_secrets @@ -378,7 +388,7 @@ def _mock_vault_functionality(self): service_account = client.V1ServiceAccount( metadata=object_meta, secrets=[secret] ) - get_k8s().v1api.read_namespaced_service_account = unittest.mock.Mock( + get_k8s_helper().v1api.read_namespaced_service_account = unittest.mock.Mock( return_value=service_account ) @@ -389,14 +399,21 @@ def execute_function(self, runtime, **kwargs): kwargs.update({"watch": False}) self._execute_run(runtime, **kwargs) + @staticmethod + def deploy(db_session, runtime, with_mlrun=True): + 
auth_info = mlrun.common.schemas.AuthInfo() + mlrun.api.api.endpoints.functions._build_function( + db_session, auth_info, runtime, with_mlrun=with_mlrun + ) + def _reset_mocks(self): - get_k8s().v1api.create_namespaced_pod.reset_mock() - get_k8s().v1api.list_namespaced_pod.reset_mock() - get_k8s().v1api.read_namespaced_pod_log.reset_mock() + get_k8s_helper().v1api.create_namespaced_pod.reset_mock() + get_k8s_helper().v1api.list_namespaced_pod.reset_mock() + get_k8s_helper().v1api.read_namespaced_pod_log.reset_mock() def _reset_custom_object_mocks(self): - mlrun.api.utils.singletons.k8s.get_k8s().crdapi.create_namespaced_custom_object.reset_mock() - get_k8s().v1api.list_namespaced_pod.reset_mock() + mlrun.api.utils.singletons.k8s.get_k8s_helper().crdapi.create_namespaced_custom_object.reset_mock() + get_k8s_helper().v1api.list_namespaced_pod.reset_mock() def _execute_run(self, runtime, **kwargs): # Reset the mock, so that when checking is create_pod was called, no leftovers are there (in case running @@ -522,7 +539,7 @@ def _assert_pod_env_from_secrets(pod_env, expected_variables): assert len(expected_variables) == 0 def _get_pod_creation_args(self): - args, _ = get_k8s().v1api.create_namespaced_pod.call_args + args, _ = get_k8s_helper().v1api.create_namespaced_pod.call_args return args[1] def _get_custom_object_creation_body(self): @@ -530,7 +547,7 @@ def _get_custom_object_creation_body(self): _, kwargs, ) = ( - mlrun.api.utils.singletons.k8s.get_k8s().crdapi.create_namespaced_custom_object.call_args + mlrun.api.utils.singletons.k8s.get_k8s_helper().crdapi.create_namespaced_custom_object.call_args ) return kwargs["body"] @@ -539,12 +556,12 @@ def _get_create_custom_object_namespace_arg(self): _, kwargs, ) = ( - mlrun.api.utils.singletons.k8s.get_k8s().crdapi.create_namespaced_custom_object.call_args + mlrun.api.utils.singletons.k8s.get_k8s_helper().crdapi.create_namespaced_custom_object.call_args ) return kwargs["namespace"] def 
_get_create_pod_namespace_arg(self): - args, _ = get_k8s().v1api.create_namespaced_pod.call_args + args, _ = get_k8s_helper().v1api.create_namespaced_pod.call_args return args[0] def _assert_v3io_mount_or_creds_configured( @@ -664,7 +681,7 @@ def _assert_pod_creation_config( expected_args=None, ): if assert_create_pod_called: - create_pod_mock = get_k8s().v1api.create_namespaced_pod + create_pod_mock = get_k8s_helper().v1api.create_namespaced_pod create_pod_mock.assert_called_once() assert self._get_create_pod_namespace_arg() == self.namespace @@ -819,7 +836,7 @@ def assert_run_preemption_mode_with_preemptible_node_selector_and_tolerations( json.dumps(preemptible_node_selector).encode("utf-8") ) mlrun.mlconf.function_defaults.preemption_mode = ( - mlrun.api.schemas.PreemptionModes.prevent.value + mlrun.common.schemas.PreemptionModes.prevent.value ) # set default preemptible tolerations @@ -838,21 +855,25 @@ def assert_run_preemption_mode_with_preemptible_node_selector_and_tolerations( preemptible_affinity = self._generate_preemptible_affinity() preemptible_tolerations = self._generate_preemptible_tolerations() logger.info("prevent -> constrain, expecting preemptible affinity") - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.constrain.value) + runtime.with_preemption_mode( + mlrun.common.schemas.PreemptionModes.constrain.value + ) self.execute_function(runtime) self.assert_node_selection( affinity=preemptible_affinity, tolerations=preemptible_tolerations ) logger.info("constrain -> allow, expecting only preemption tolerations to stay") - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.allow.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.allow.value) self.execute_function(runtime) self.assert_node_selection(tolerations=preemptible_tolerations) logger.info( "allow -> constrain, expecting preemptible affinity with tolerations" ) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.constrain.value) 
+ runtime.with_preemption_mode( + mlrun.common.schemas.PreemptionModes.constrain.value + ) self.execute_function(runtime) self.assert_node_selection( affinity=preemptible_affinity, tolerations=preemptible_tolerations @@ -861,19 +882,19 @@ def assert_run_preemption_mode_with_preemptible_node_selector_and_tolerations( logger.info( "constrain -> prevent, expecting affinity and tolerations to be removed" ) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.prevent.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.prevent.value) self.execute_function(runtime) self.assert_node_selection() logger.info("prevent -> allow, expecting preemptible tolerations") - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.allow.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.allow.value) self.execute_function(runtime) self.assert_node_selection(tolerations=preemptible_tolerations) logger.info( "allow -> prevent, expecting affinity and tolerations to be removed" ) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.prevent.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.prevent.value) self.execute_function(runtime) self.assert_node_selection() @@ -885,7 +906,7 @@ def assert_run_preemption_mode_with_preemptible_node_selector_and_tolerations_wi json.dumps(preemptible_node_selector).encode("utf-8") ) mlrun.mlconf.function_defaults.preemption_mode = ( - mlrun.api.schemas.PreemptionModes.prevent.value + mlrun.common.schemas.PreemptionModes.prevent.value ) # set default preemptible tolerations @@ -910,7 +931,9 @@ def assert_run_preemption_mode_with_preemptible_node_selector_and_tolerations_wi "and preemptible anti-affinity to be removed and preemptible affinity to be added" ) runtime.with_node_selection(node_selector=self._generate_node_selector()) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.constrain.value) + runtime.with_preemption_mode( + 
mlrun.common.schemas.PreemptionModes.constrain.value + ) self.execute_function(runtime) self.assert_node_selection( node_selector=preemptible_node_selector, @@ -921,7 +944,7 @@ def assert_run_preemption_mode_with_preemptible_node_selector_and_tolerations_wi "constrain -> allow, with preemptible node selector and affinity and tolerations," " expecting affinity and node selector to be removed and only preemptible tolerations to stay" ) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.allow.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.allow.value) self.execute_function(runtime) self.assert_node_selection(tolerations=preemptible_tolerations) @@ -939,7 +962,7 @@ def assert_run_preemption_mode_with_preemptible_node_selector_and_tolerations_wi logger.info( "allow -> prevent, with not preemptible node selector, expecting to stay" ) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.prevent.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.prevent.value) self.execute_function(runtime) self.assert_node_selection( node_selector=not_preemptible_node_selector, @@ -949,7 +972,9 @@ def assert_run_preemption_mode_with_preemptible_node_selector_and_tolerations_wi "prevent -> constrain, with not preemptible node selector, expecting to stay and" " preemptible affinity and tolerations to be added" ) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.constrain.value) + runtime.with_preemption_mode( + mlrun.common.schemas.PreemptionModes.constrain.value + ) self.execute_function(runtime) self.assert_node_selection( node_selector=not_preemptible_node_selector, @@ -969,14 +994,18 @@ def assert_run_preemption_mode_with_preemptible_node_selector_and_tolerations_wi "prevent -> constrain, with not preemptible affinity," " expecting to override affinity with preemptible affinity and add tolerations" ) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.constrain.value) + 
runtime.with_preemption_mode( + mlrun.common.schemas.PreemptionModes.constrain.value + ) self.execute_function(runtime) self.assert_node_selection( affinity=preemptible_affinity, tolerations=preemptible_tolerations ) logger.info("constrain > constrain, expecting to stay the same") - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.constrain.value) + runtime.with_preemption_mode( + mlrun.common.schemas.PreemptionModes.constrain.value + ) self.execute_function(runtime) self.assert_node_selection( affinity=preemptible_affinity, tolerations=preemptible_tolerations @@ -988,7 +1017,7 @@ def assert_run_preemption_mode_with_preemptible_node_selector_and_tolerations_wi ) runtime = self._generate_runtime() runtime.with_node_selection(affinity=self._generate_not_preemptible_affinity()) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.allow.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.allow.value) self.execute_function(runtime) self.assert_node_selection( affinity=self._generate_not_preemptible_affinity(), @@ -996,7 +1025,7 @@ def assert_run_preemption_mode_with_preemptible_node_selector_and_tolerations_wi ) logger.info("allow -> allow, expecting to stay the same") - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.allow.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.allow.value) self.execute_function(runtime) self.assert_node_selection( affinity=self._generate_not_preemptible_affinity(), @@ -1006,14 +1035,14 @@ def assert_run_preemption_mode_with_preemptible_node_selector_and_tolerations_wi logger.info( "allow -> prevent, with not preemptible affinity expecting tolerations to be removed" ) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.prevent.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.prevent.value) self.execute_function(runtime) self.assert_node_selection(affinity=self._generate_not_preemptible_affinity()) logger.info( 
"prevent -> prevent, with not preemptible affinity expecting to stay the same" ) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.prevent.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.prevent.value) self.execute_function(runtime) self.assert_node_selection(affinity=self._generate_not_preemptible_affinity()) @@ -1025,7 +1054,9 @@ def assert_run_preemption_mode_with_preemptible_node_selector_and_tolerations_wi ) runtime = self._generate_runtime() runtime.with_node_selection(affinity=self._generate_affinity()) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.constrain.value) + runtime.with_preemption_mode( + mlrun.common.schemas.PreemptionModes.constrain.value + ) self.execute_function(runtime) expected_affinity = self._generate_affinity() expected_affinity.node_affinity.required_during_scheduling_ignored_during_execution = k8s_client.V1NodeSelector( @@ -1060,7 +1091,7 @@ def assert_run_preemption_mode_with_preemptible_node_selector_and_tolerations_wi + self._generate_preemptible_tolerations() ) runtime.with_preemption_mode( - mode=mlrun.api.schemas.PreemptionModes.constrain.value + mode=mlrun.common.schemas.PreemptionModes.constrain.value ) self.execute_function(runtime) self.assert_node_selection( @@ -1072,7 +1103,9 @@ def assert_run_preemption_mode_with_preemptible_node_selector_and_tolerations_wi "constrain -> allow, with merged preemptible tolerations and preemptible affinity, " "expecting only merged preemptible tolerations" ) - runtime.with_preemption_mode(mode=mlrun.api.schemas.PreemptionModes.allow.value) + runtime.with_preemption_mode( + mode=mlrun.common.schemas.PreemptionModes.allow.value + ) self.execute_function(runtime) self.assert_node_selection( tolerations=merged_preemptible_tolerations, @@ -1087,7 +1120,7 @@ def assert_run_preemption_mode_with_preemptible_node_selector_without_preemptibl json.dumps(preemptible_node_selector).encode("utf-8") ) mlrun.mlconf.function_defaults.preemption_mode 
= ( - mlrun.api.schemas.PreemptionModes.prevent.value + mlrun.common.schemas.PreemptionModes.prevent.value ) logger.info( "prevent, without setting any node selection expecting preemptible anti-affinity to be set" @@ -1097,32 +1130,36 @@ def assert_run_preemption_mode_with_preemptible_node_selector_without_preemptibl self.assert_node_selection(affinity=self._generate_preemptible_anti_affinity()) logger.info("prevent -> constrain, expecting preemptible affinity") - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.constrain.value) + runtime.with_preemption_mode( + mlrun.common.schemas.PreemptionModes.constrain.value + ) self.execute_function(runtime) self.assert_node_selection(affinity=self._generate_preemptible_affinity()) logger.info("constrain -> allow, expecting no node selection to be set") - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.allow.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.allow.value) self.execute_function(runtime) self.assert_node_selection() logger.info("allow -> constrain, expecting preemptible affinity") - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.constrain.value) + runtime.with_preemption_mode( + mlrun.common.schemas.PreemptionModes.constrain.value + ) self.execute_function(runtime) self.assert_node_selection(affinity=self._generate_preemptible_affinity()) logger.info("constrain -> prevent, expecting preemptible anti-affinity") - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.prevent.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.prevent.value) self.execute_function(runtime) self.assert_node_selection(affinity=self._generate_preemptible_anti_affinity()) logger.info("prevent -> allow, expecting no node selection to be set") - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.allow.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.allow.value) self.execute_function(runtime) 
self.assert_node_selection() logger.info("allow -> prevent, expecting preemptible anti-affinity") - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.prevent.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.prevent.value) self.execute_function(runtime) self.assert_node_selection(affinity=self._generate_preemptible_anti_affinity()) @@ -1135,7 +1172,7 @@ def assert_run_preemption_mode_with_preemptible_node_selector_without_preemptibl json.dumps(preemptible_node_selector).encode("utf-8") ) mlrun.mlconf.function_defaults.preemption_mode = ( - mlrun.api.schemas.PreemptionModes.prevent.value + mlrun.common.schemas.PreemptionModes.prevent.value ) logger.info( @@ -1151,7 +1188,9 @@ def assert_run_preemption_mode_with_preemptible_node_selector_without_preemptibl "and preemptible anti-affinity to be removed and preemptible affinity to be added" ) runtime.with_node_selection(node_selector=preemptible_node_selector) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.constrain.value) + runtime.with_preemption_mode( + mlrun.common.schemas.PreemptionModes.constrain.value + ) self.execute_function(runtime) self.assert_node_selection( node_selector=preemptible_node_selector, @@ -1160,7 +1199,7 @@ def assert_run_preemption_mode_with_preemptible_node_selector_without_preemptibl logger.info( "constrain -> allow with preemptible node selector and affinity, expecting both to be removed" ) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.allow.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.allow.value) self.execute_function(runtime) self.assert_node_selection() @@ -1176,7 +1215,7 @@ def assert_run_preemption_mode_with_preemptible_node_selector_without_preemptibl "allow -> prevent, with not preemptible node selector, expecting to stay and preemptible" " anti-affinity" ) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.prevent.value) + 
runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.prevent.value) self.execute_function(runtime) self.assert_node_selection( node_selector=not_preemptible_node_selector, @@ -1186,7 +1225,9 @@ def assert_run_preemption_mode_with_preemptible_node_selector_without_preemptibl "prevent -> constrain, with not preemptible node selector, expecting to stay and" " preemptible affinity to be add and anti affinity to be remove" ) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.constrain.value) + runtime.with_preemption_mode( + mlrun.common.schemas.PreemptionModes.constrain.value + ) self.execute_function(runtime) self.assert_node_selection( node_selector=not_preemptible_node_selector, @@ -1206,12 +1247,16 @@ def assert_run_preemption_mode_with_preemptible_node_selector_without_preemptibl "prevent -> constrain, with preemptible anti-affinity," " expecting to override anti-affinity with preemptible affinity" ) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.constrain.value) + runtime.with_preemption_mode( + mlrun.common.schemas.PreemptionModes.constrain.value + ) self.execute_function(runtime) self.assert_node_selection(affinity=self._generate_preemptible_affinity()) logger.info("constrain > constrain, expecting to stay the same") - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.constrain.value) + runtime.with_preemption_mode( + mlrun.common.schemas.PreemptionModes.constrain.value + ) self.execute_function(runtime) self.assert_node_selection(affinity=self._generate_preemptible_affinity()) @@ -1220,26 +1265,26 @@ def assert_run_preemption_mode_with_preemptible_node_selector_without_preemptibl logger.info("prevent -> allow, with not preemptible affinity expecting to stay") runtime = self._generate_runtime() runtime.with_node_selection(affinity=self._generate_not_preemptible_affinity()) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.allow.value) + 
runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.allow.value) self.execute_function(runtime) self.assert_node_selection(affinity=self._generate_not_preemptible_affinity()) logger.info("allow -> allow, expecting to stay the same") - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.allow.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.allow.value) self.execute_function(runtime) self.assert_node_selection(affinity=self._generate_not_preemptible_affinity()) logger.info( "allow -> prevent, with not preemptible affinity expecting to be overridden with anti-affinity" ) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.prevent.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.prevent.value) self.execute_function(runtime) self.assert_node_selection(affinity=self._generate_preemptible_anti_affinity()) logger.info( "prevent -> prevent, with anti-affinity, expecting to stay the same" ) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.prevent.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.prevent.value) self.execute_function(runtime) self.assert_node_selection(affinity=self._generate_preemptible_anti_affinity()) @@ -1250,7 +1295,9 @@ def assert_run_preemption_mode_with_preemptible_node_selector_without_preemptibl ) runtime = self._generate_runtime() runtime.with_node_selection(affinity=self._generate_affinity()) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.constrain.value) + runtime.with_preemption_mode( + mlrun.common.schemas.PreemptionModes.constrain.value + ) self.execute_function(runtime) expected_affinity = self._generate_affinity() expected_affinity.node_affinity.required_during_scheduling_ignored_during_execution = k8s_client.V1NodeSelector( @@ -1280,7 +1327,7 @@ def assert_run_preemption_mode_with_preemptible_node_selector_without_preemptibl ) runtime.with_preemption_mode( - 
mode=mlrun.api.schemas.PreemptionModes.constrain.value + mode=mlrun.common.schemas.PreemptionModes.constrain.value ) self.execute_function(runtime) self.assert_node_selection( @@ -1292,7 +1339,9 @@ def assert_run_preemption_mode_with_preemptible_node_selector_without_preemptibl "constrain -> allow, with not preemptible tolerations and preemptible affinity, " "expecting only not preemptible tolerations" ) - runtime.with_preemption_mode(mode=mlrun.api.schemas.PreemptionModes.allow.value) + runtime.with_preemption_mode( + mode=mlrun.common.schemas.PreemptionModes.allow.value + ) self.execute_function(runtime) self.assert_node_selection( tolerations=self._generate_not_preemptible_tolerations(), @@ -1305,7 +1354,7 @@ def assert_run_with_preemption_mode_none_transitions(self): json.dumps(preemptible_node_selector).encode("utf-8") ) mlrun.mlconf.function_defaults.preemption_mode = ( - mlrun.api.schemas.PreemptionModes.prevent.value + mlrun.common.schemas.PreemptionModes.prevent.value ) logger.info("prevent, expecting anti affinity") @@ -1315,7 +1364,7 @@ def assert_run_with_preemption_mode_none_transitions(self): self.assert_node_selection(affinity=self._generate_preemptible_anti_affinity()) logger.info("prevent -> none, expecting to stay the same") - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.none.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.none.value) self.execute_function(runtime) self.assert_node_selection(affinity=self._generate_preemptible_anti_affinity()) @@ -1332,7 +1381,9 @@ def assert_run_with_preemption_mode_none_transitions(self): logger.info( "none -> constrain, expecting preemptible affinity and user's tolerations" ) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.constrain.value) + runtime.with_preemption_mode( + mlrun.common.schemas.PreemptionModes.constrain.value + ) self.execute_function(runtime) self.assert_node_selection( affinity=self._generate_preemptible_affinity(), @@ -1342,7 
+1393,7 @@ def assert_run_with_preemption_mode_none_transitions(self): logger.info( "constrain -> none, expecting preemptible affinity to stay and user's tolerations" ) - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.none.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.none.value) self.execute_function(runtime) self.assert_node_selection( affinity=self._generate_preemptible_affinity(), @@ -1350,12 +1401,12 @@ def assert_run_with_preemption_mode_none_transitions(self): ) logger.info("none -> allow, expecting user's tolerations to stay") - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.allow.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.allow.value) self.execute_function(runtime) self.assert_node_selection(tolerations=self._generate_tolerations()) logger.info("allow -> none, expecting user's tolerations to stay") - runtime.with_preemption_mode(mlrun.api.schemas.PreemptionModes.none.value) + runtime.with_preemption_mode(mlrun.common.schemas.PreemptionModes.none.value) self.execute_function(runtime) self.assert_node_selection(tolerations=self._generate_tolerations()) @@ -1404,7 +1455,7 @@ def assert_run_with_preemption_mode_without_preemptible_configuration(self): if test_case.get("tolerations", False) else None ) - for preemption_mode in mlrun.api.schemas.PreemptionModes: + for preemption_mode in mlrun.common.schemas.PreemptionModes: runtime = self._generate_runtime() runtime.with_node_selection( node_name=node_name, diff --git a/tests/api/runtimes/test_dask.py b/tests/api/runtimes/test_dask.py index 7489a7ac4cb2..32d99fcf6f4b 100644 --- a/tests/api/runtimes/test_dask.py +++ b/tests/api/runtimes/test_dask.py @@ -24,7 +24,7 @@ import mlrun import mlrun.api.api.endpoints.functions -import mlrun.api.schemas +import mlrun.common.schemas from mlrun import mlconf from mlrun.platforms import auto_mount from mlrun.runtimes.utils import generate_resources @@ -437,10 +437,10 @@ def 
test_deploy_dask_function_with_enriched_security_context( ): runtime = self._generate_runtime() user_unix_id = 1000 - auth_info = mlrun.api.schemas.AuthInfo(user_unix_id=user_unix_id) + auth_info = mlrun.common.schemas.AuthInfo(user_unix_id=user_unix_id) mlrun.mlconf.igz_version = "3.6" mlrun.mlconf.function.spec.security_context.enrichment_mode = ( - mlrun.api.schemas.function.SecurityContextEnrichmentModes.disabled.value + mlrun.common.schemas.function.SecurityContextEnrichmentModes.disabled.value ) _ = mlrun.api.api.endpoints.functions._start_function(runtime, auth_info) pod = self._get_pod_creation_args() @@ -448,7 +448,7 @@ def test_deploy_dask_function_with_enriched_security_context( self.assert_security_context() mlrun.mlconf.function.spec.security_context.enrichment_mode = ( - mlrun.api.schemas.function.SecurityContextEnrichmentModes.override.value + mlrun.common.schemas.function.SecurityContextEnrichmentModes.override.value ) runtime = self._generate_runtime() _ = mlrun.api.api.endpoints.functions._start_function(runtime, auth_info) diff --git a/tests/api/runtimes/test_kubejob.py b/tests/api/runtimes/test_kubejob.py index 7a4356e05305..fa32ecdccead 100644 --- a/tests/api/runtimes/test_kubejob.py +++ b/tests/api/runtimes/test_kubejob.py @@ -23,11 +23,12 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session -import mlrun.api.schemas -import mlrun.builder +import mlrun.api.api.endpoints.functions +import mlrun.api.utils.builder +import mlrun.common.schemas import mlrun.errors import mlrun.k8s_utils -from mlrun.api.schemas import SecurityContextEnrichmentModes +from mlrun.common.schemas import SecurityContextEnrichmentModes from mlrun.config import config as mlconf from mlrun.platforms import auto_mount from mlrun.runtimes.utils import generate_resources @@ -47,6 +48,7 @@ def custom_setup(self): def _generate_runtime(self) -> mlrun.runtimes.KubejobRuntime: runtime = mlrun.runtimes.KubejobRuntime() runtime.spec.image = self.image_name 
+ runtime.metadata.project = self.project return runtime def test_run_without_runspec(self, db: Session, client: TestClient): @@ -414,35 +416,36 @@ def test_run_with_global_secrets( expected_env_from_secrets=expected_env_from_secrets, ) - def test_run_with_vault_secrets(self, db: Session, client: TestClient): - self._mock_vault_functionality() - runtime = self._generate_runtime() - - task = self._generate_task() - - task.metadata.project = self.project - secret_source = { - "kind": "vault", - "source": {"project": self.project, "secrets": self.vault_secrets}, - } - task.with_secrets(secret_source["kind"], self.vault_secrets) - vault_url = "/url/for/vault" - mlconf.secret_stores.vault.remote_url = vault_url - mlconf.secret_stores.vault.token_path = vault_url - - self.execute_function(runtime, runspec=task) - - self._assert_pod_creation_config( - expected_secrets=secret_source, - expected_env={ - "MLRUN_SECRET_STORES__VAULT__ROLE": f"project:{self.project}", - "MLRUN_SECRET_STORES__VAULT__URL": vault_url, - }, - ) - - self._assert_secret_mount( - "vault-secret", self.vault_secret_name, 420, vault_url - ) + # TODO: Vault: uncomment when vault returns to be relevant + # def test_run_with_vault_secrets(self, db: Session, client: TestClient): + # self._mock_vault_functionality() + # runtime = self._generate_runtime() + # + # task = self._generate_task() + # + # task.metadata.project = self.project + # secret_source = { + # "kind": "vault", + # "source": {"project": self.project, "secrets": self.vault_secrets}, + # } + # task.with_secrets(secret_source["kind"], self.vault_secrets) + # vault_url = "/url/for/vault" + # mlconf.secret_stores.vault.remote_url = vault_url + # mlconf.secret_stores.vault.token_path = vault_url + # + # self.execute_function(runtime, runspec=task) + # + # self._assert_pod_creation_config( + # expected_secrets=secret_source, + # expected_env={ + # "MLRUN_SECRET_STORES__VAULT__ROLE": f"project:{self.project}", + # "MLRUN_SECRET_STORES__VAULT__URL": 
vault_url, + # }, + # ) + # + # self._assert_secret_mount( + # "vault-secret", self.vault_secret_name, 420, vault_url + # ) def test_run_with_code(self, db: Session, client: TestClient): runtime = self._generate_runtime() @@ -554,13 +557,11 @@ def test_with_image_pull_configuration(self, db: Session, client: TestClient): def test_with_requirements(self, db: Session, client: TestClient): runtime = self._generate_runtime() runtime.with_requirements(self.requirements_file) - expected_commands = [ - "python -m pip install faker python-dotenv 'chardet>=3.0.2, <4.0'" - ] + expected_requirements = ["faker", "python-dotenv", "chardet>=3.0.2, <4.0"] assert ( deepdiff.DeepDiff( - expected_commands, - runtime.spec.build.commands, + expected_requirements, + runtime.spec.build.requirements, ignore_order=True, ) == {} @@ -665,27 +666,59 @@ def test_build_config(self, db: Session, client: TestClient): ) runtime.build_config(requirements=["pandas", "numpy"]) - expected_commands = [ - "python -m pip install scikit-learn", - "python -m pip install pandas numpy", + expected_requirements = [ + "pandas", + "numpy", ] - print(runtime.spec.build.commands) + assert ( + deepdiff.DeepDiff( + expected_requirements, + runtime.spec.build.requirements, + ignore_order=False, + ) + == {} + ) + expected_commands = ["python -m pip install scikit-learn"] assert ( deepdiff.DeepDiff( expected_commands, runtime.spec.build.commands, - ignore_order=False, + ignore_order=True, ) == {} ) runtime.build_config(requirements=["scikit-learn"], overwrite=True) - expected_commands = ["python -m pip install scikit-learn"] + expected_requirements = ["scikit-learn"] + assert ( + deepdiff.DeepDiff( + expected_requirements, + runtime.spec.build.requirements, + ignore_order=True, + ) + == {} + ) + + def test_build_config_commands_and_requirements_order( + self, db: Session, client: TestClient + ): + runtime = self._generate_runtime() + runtime.build_config(commands=["apt-get update"], requirements=["scikit-learn"]) + 
expected_commands = ["apt-get update"] + expected_requirements = ["scikit-learn"] assert ( deepdiff.DeepDiff( expected_commands, runtime.spec.build.commands, - ignore_order=True, + ignore_order=False, + ) + == {} + ) + assert ( + deepdiff.DeepDiff( + expected_requirements, + runtime.spec.build.requirements, + ignore_order=False, ) == {} ) @@ -719,27 +752,208 @@ def test_deploy_upgrade_pip( expected_to_upgrade, ): mlrun.mlconf.httpdb.builder.docker_registry = "localhost:5000" - mlrun.builder.make_kaniko_pod = unittest.mock.MagicMock() - - runtime = self._generate_runtime() - runtime.spec.build.base_image = "some/image" - runtime.spec.build.commands = copy.deepcopy(commands) - runtime.deploy(with_mlrun=with_mlrun, watch=False) - dockerfile = mlrun.builder.make_kaniko_pod.call_args[1]["dockertext"] - if expected_to_upgrade: - expected_str = "" - if commands: - expected_str += "\nRUN " - expected_str += "\nRUN ".join(commands) - expected_str += f"\nRUN python -m pip install --upgrade pip{mlrun.mlconf.httpdb.builder.pip_version}" - if with_mlrun: - expected_str += '\nRUN python -m pip install "mlrun[complete]' - assert expected_str in dockerfile - else: - assert ( - f"pip install --upgrade pip{mlrun.mlconf.httpdb.builder.pip_version}" - not in dockerfile + with unittest.mock.patch( + "mlrun.api.utils.builder.make_kaniko_pod", unittest.mock.MagicMock() + ): + runtime = self._generate_runtime() + runtime.spec.build.base_image = "some/image" + runtime.spec.build.commands = copy.deepcopy(commands) + self.deploy(db, runtime, with_mlrun=with_mlrun) + dockerfile = mlrun.api.utils.builder.make_kaniko_pod.call_args[1][ + "dockertext" + ] + if expected_to_upgrade: + expected_str = "" + if commands: + expected_str += "\nRUN " + expected_str += "\nRUN ".join(commands) + expected_str += f"\nRUN python -m pip install --upgrade pip{mlrun.mlconf.httpdb.builder.pip_version}" + + # assert that mlrun was added to the requirements file + if with_mlrun: + expected_str += ( + "\nRUN echo 
'Installing /empty/requirements.txt...'; cat /empty/requirements.txt" + "\nRUN python -m pip install -r /empty/requirements.txt" + ) + kaniko_pod_requirements = ( + mlrun.api.utils.builder.make_kaniko_pod.call_args[1][ + "requirements" + ] + ) + assert kaniko_pod_requirements == [ + "mlrun[complete] @ git+https://github.com/mlrun/mlrun@development" + ] + assert expected_str in dockerfile + else: + assert ( + f"pip install --upgrade pip{mlrun.mlconf.httpdb.builder.pip_version}" + not in dockerfile + ) + + @pytest.mark.parametrize( + "with_mlrun, requirements, with_requirements_file, expected_requirements", + [ + ( + True, + [], + False, + ["mlrun[complete] @ git+https://github.com/mlrun/mlrun@development"], + ), + ( + True, + ["pandas"], + False, + [ + "mlrun[complete] @ git+https://github.com/mlrun/mlrun@development", + "pandas", + ], + ), + ( + True, + ["pandas", "tensorflow"], + False, + [ + "mlrun[complete] @ git+https://github.com/mlrun/mlrun@development", + "pandas", + "tensorflow", + ], + ), + (False, [], True, ["faker", "python-dotenv", "chardet>=3.0.2, <4.0"]), + (False, ["pandas", "tensorflow"], False, ["pandas", "tensorflow"]), + ( + False, + ["pandas", "tensorflow"], + True, + [ + "faker", + "python-dotenv", + "chardet>=3.0.2, <4.0", + "pandas", + "tensorflow", + ], + ), + ( + True, + ["pandas", "tensorflow"], + True, + [ + "mlrun[complete] @ git+https://github.com/mlrun/mlrun@development", + "faker", + "python-dotenv", + "chardet>=3.0.2, <4.0", + "pandas", + "tensorflow", + ], + ), + ( + True, + [], + True, + [ + "mlrun[complete] @ git+https://github.com/mlrun/mlrun@development", + "faker", + "python-dotenv", + "chardet>=3.0.2, <4.0", + ], + ), + ], + ) + def test_deploy_with_mlrun( + self, + db: Session, + client: TestClient, + with_mlrun, + requirements, + with_requirements_file, + expected_requirements, + ): + mlrun.mlconf.httpdb.builder.docker_registry = "localhost:5000" + with unittest.mock.patch( + "mlrun.api.utils.builder.make_kaniko_pod", 
unittest.mock.MagicMock() + ): + runtime = self._generate_runtime() + runtime.spec.build.base_image = "some/image" + + requirements_file = ( + "" if not with_requirements_file else self.requirements_file + ) + runtime.with_requirements( + requirements=requirements, requirements_file=requirements_file + ) + + self.deploy(db, runtime, with_mlrun=with_mlrun) + dockerfile = mlrun.api.utils.builder.make_kaniko_pod.call_args[1][ + "dockertext" + ] + + install_requirements_commands = ( + "\nRUN echo 'Installing /empty/requirements.txt...'; cat /empty/requirements.txt" + "\nRUN python -m pip install -r /empty/requirements.txt" ) + kaniko_pod_requirements = mlrun.api.utils.builder.make_kaniko_pod.call_args[ + 1 + ]["requirements"] + if with_mlrun: + expected_str = f"\nRUN python -m pip install --upgrade pip{mlrun.mlconf.httpdb.builder.pip_version}" + expected_str += install_requirements_commands + assert kaniko_pod_requirements == expected_requirements + assert expected_str in dockerfile + + else: + assert ( + f"pip install --upgrade pip{mlrun.mlconf.httpdb.builder.pip_version}" + not in dockerfile + ) + + # assert that install requirements commands are in the dockerfile + if with_requirements_file or requirements: + expected_str = install_requirements_commands + assert expected_str in dockerfile + + # assert mlrun is not in the requirements + for requirement in kaniko_pod_requirements: + assert "mlrun" not in requirement + + assert kaniko_pod_requirements == expected_requirements + + @pytest.mark.parametrize( + "workdir, source, pull_at_runtime, target_dir, expected_workdir", + [ + ("", "git://bla", True, None, None), + ("", "git://bla", False, None, None), + ("", "git://bla", False, "/a/b/c", "/a/b/c/"), + ("subdir", "git://bla", False, "/a/b/c", "/a/b/c/subdir"), + ("./subdir", "git://bla", False, "/a/b/c", "/a/b/c/subdir"), + ("./subdir", "git://bla", True, "/a/b/c", None), + ("/abs/subdir", "git://bla", False, "/a/b/c", "/abs/subdir"), + ("/abs/subdir", "git://bla", 
False, None, "/abs/subdir"), + ], + ) + def test_resolve_workdir( + self, workdir, source, pull_at_runtime, target_dir, expected_workdir + ): + runtime = self._generate_runtime() + runtime.with_source_archive( + source, workdir, pull_at_runtime=pull_at_runtime, target_dir=target_dir + ) + + # mock the build + runtime.spec.image = "some/image" + self.execute_function(runtime) + pod = self._get_pod_creation_args() + assert pod.spec.containers[0].working_dir == expected_workdir + + def test_with_source_archive_validation(self): + runtime = self._generate_runtime() + source = "./some/relative/path" + with pytest.raises(mlrun.errors.MLRunInvalidArgumentError) as e: + runtime.with_source_archive(source, pull_at_runtime=False) + assert ( + f"Source '{source}' must be a valid URL or absolute path when 'pull_at_runtime' is False " + "set 'source' to a remote URL to clone/copy the source to the base image, " + "or set 'pull_at_runtime' to True to pull the source at runtime." + in str(e.value) + ) @staticmethod def _assert_build_commands(expected_commands, runtime): diff --git a/tests/api/runtimes/test_mpijob.py b/tests/api/runtimes/test_mpijob.py index e220e70033b2..fbc164a88a96 100644 --- a/tests/api/runtimes/test_mpijob.py +++ b/tests/api/runtimes/test_mpijob.py @@ -15,11 +15,14 @@ import typing import unittest.mock +from fastapi.testclient import TestClient from kubernetes import client as k8s_client +from sqlalchemy.orm import Session +import mlrun.api.utils.builder import mlrun.runtimes.pod from mlrun import code_to_function, mlconf -from mlrun.api.utils.singletons.k8s import get_k8s +from mlrun.api.utils.singletons.k8s import get_k8s_helper from mlrun.runtimes.constants import MPIJobCRDVersions from tests.api.runtimes.base import TestRuntimeBase @@ -31,21 +34,25 @@ def custom_setup(self): self.name = "test-mpi-v1" mlconf.mpijob_crd_version = MPIJobCRDVersions.v1 - def test_run_v1_sanity(self): - self._mock_list_pods() - self._mock_create_namespaced_custom_object() - 
self._mock_get_namespaced_custom_object() - mpijob_function = self._generate_runtime(self.runtime_kind) - mpijob_function.deploy() - run = mpijob_function.run( - artifact_path="v3io:///mypath", - watch=False, - ) + def test_run_v1_sanity(self, db: Session, client: TestClient): + mlconf.httpdb.builder.docker_registry = "localhost:5000" + with unittest.mock.patch( + "mlrun.api.utils.builder.make_kaniko_pod", unittest.mock.MagicMock() + ): + self._mock_list_pods() + self._mock_create_namespaced_custom_object() + self._mock_get_namespaced_custom_object() + mpijob_function = self._generate_runtime(self.runtime_kind) + self.deploy(db, mpijob_function) + run = mpijob_function.run( + artifact_path="v3io:///mypath", + watch=False, + ) - assert run.status.state == "running" + assert run.status.state == "running" def _mock_get_namespaced_custom_object(self, workers=1): - get_k8s().crdapi.get_namespaced_custom_object = unittest.mock.Mock( + get_k8s_helper().crdapi.get_namespaced_custom_object = unittest.mock.Mock( return_value={ "status": { "replicaStatuses": { @@ -64,7 +71,7 @@ def _mock_list_pods(self, workers=1, pods=None, phase="Running"): if pods is None: pods = [self._get_worker_pod(phase=phase)] * workers pods += [self._get_launcher_pod(phase=phase)] - get_k8s().list_pods = unittest.mock.Mock(return_value=pods) + get_k8s_helper().list_pods = unittest.mock.Mock(return_value=pods) def _get_worker_pod(self, phase="Running"): return k8s_client.V1Pod( diff --git a/tests/api/runtimes/test_nuclio.py b/tests/api/runtimes/test_nuclio.py index c02fd3e9cbe2..31937b7860cc 100644 --- a/tests/api/runtimes/test_nuclio.py +++ b/tests/api/runtimes/test_nuclio.py @@ -28,23 +28,16 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session -import mlrun.api.schemas +import mlrun.api.crud.runtimes.nuclio.function +import mlrun.api.crud.runtimes.nuclio.helpers +import mlrun.common.schemas import mlrun.errors +import mlrun.runtimes.function import mlrun.runtimes.pod from 
mlrun import code_to_function, mlconf from mlrun.api.api.endpoints.functions import _build_function from mlrun.platforms.iguazio import split_path from mlrun.runtimes.constants import NuclioIngressAddTemplatedIngressModes -from mlrun.runtimes.function import ( - _compile_nuclio_archive_config, - compile_function_config, - deploy_nuclio_function, - enrich_function_with_ingress, - is_nuclio_version_in_range, - min_nuclio_versions, - resolve_function_ingresses, - validate_nuclio_version_compatibility, -) from mlrun.utils import logger from tests.api.conftest import K8sSecretsMock from tests.api.runtimes.base import TestRuntimeBase @@ -124,7 +117,11 @@ def _get_expected_struct_for_v3io_trigger(self, parameters): } def _execute_run(self, runtime, **kwargs): - deploy_nuclio_function(runtime, **kwargs) + # deploy_nuclio_function doesn't accept watch, so we need to remove it + kwargs.pop("watch", None) + mlrun.api.crud.runtimes.nuclio.function.deploy_nuclio_function( + runtime, **kwargs + ) def _generate_runtime( self, kind=None, labels=None @@ -156,6 +153,7 @@ def _assert_deploy_called_basic_config( expected_build_base_image=None, expected_nuclio_runtime=None, expected_env=None, + expected_build_commands=None, ): if expected_labels is None: expected_labels = {} @@ -221,6 +219,13 @@ def _assert_deploy_called_basic_config( if expected_nuclio_runtime: assert deploy_config["spec"]["runtime"] == expected_nuclio_runtime + + if expected_build_commands: + assert ( + deploy_config["spec"]["build"]["commands"] + == expected_build_commands + ) + return deploy_configs def _assert_triggers(self, http_trigger=None, v3io_trigger=None): @@ -363,7 +368,9 @@ def test_compile_function_config_with_special_character_labels( function = self._generate_runtime(self.runtime_kind) key, val = "test.label.com/env", "test" function.set_label(key, val) - _, _, config = compile_function_config(function) + _, _, config = mlrun.api.crud.runtimes.nuclio.function._compile_function_config( + function + ) 
assert config["metadata"]["labels"].get(key) == val def test_enrich_with_ingress_no_overriding(self, db: Session, client: TestClient): @@ -376,12 +383,18 @@ def test_enrich_with_ingress_no_overriding(self, db: Session, client: TestClient # both ingress and node port ingress_host = "something.com" function.with_http(host=ingress_host, paths=["/"], port=30030) - function_name, project_name, config = compile_function_config(function) + ( + function_name, + project_name, + config, + ) = mlrun.api.crud.runtimes.nuclio.function._compile_function_config(function) service_type = "NodePort" - enrich_function_with_ingress( + mlrun.api.crud.runtimes.nuclio.helpers.enrich_function_with_ingress( config, NuclioIngressAddTemplatedIngressModes.always, service_type ) - ingresses = resolve_function_ingresses(config["spec"]) + ingresses = mlrun.api.crud.runtimes.nuclio.helpers.resolve_function_ingresses( + config["spec"] + ) assert len(ingresses) > 0, "Expected one ingress to be created" for ingress in ingresses: assert "hostTemplate" not in ingress, "No host template should be added" @@ -392,12 +405,18 @@ def test_enrich_with_ingress_always(self, db: Session, client: TestClient): Expect ingress template to be created as the configuration templated ingress mode is "always" """ function = self._generate_runtime(self.runtime_kind) - function_name, project_name, config = compile_function_config(function) + ( + function_name, + project_name, + config, + ) = mlrun.api.crud.runtimes.nuclio.function._compile_function_config(function) service_type = "NodePort" - enrich_function_with_ingress( + mlrun.api.crud.runtimes.nuclio.helpers.enrich_function_with_ingress( config, NuclioIngressAddTemplatedIngressModes.always, service_type ) - ingresses = resolve_function_ingresses(config["spec"]) + ingresses = mlrun.api.crud.runtimes.nuclio.helpers.resolve_function_ingresses( + config["spec"] + ) assert ingresses[0]["hostTemplate"] != "" def test_enrich_with_ingress_on_cluster_ip(self, db: Session, 
client: TestClient): @@ -406,14 +425,20 @@ def test_enrich_with_ingress_on_cluster_ip(self, db: Session, client: TestClient function service type is ClusterIP """ function = self._generate_runtime(self.runtime_kind) - function_name, project_name, config = compile_function_config(function) + ( + function_name, + project_name, + config, + ) = mlrun.api.crud.runtimes.nuclio.function._compile_function_config(function) service_type = "ClusterIP" - enrich_function_with_ingress( + mlrun.api.crud.runtimes.nuclio.helpers.enrich_function_with_ingress( config, NuclioIngressAddTemplatedIngressModes.on_cluster_ip, service_type, ) - ingresses = resolve_function_ingresses(config["spec"]) + ingresses = mlrun.api.crud.runtimes.nuclio.helpers.resolve_function_ingresses( + config["spec"] + ) assert ingresses[0]["hostTemplate"] != "" def test_enrich_with_ingress_never(self, db: Session, client: TestClient): @@ -421,12 +446,18 @@ def test_enrich_with_ingress_never(self, db: Session, client: TestClient): Expect no ingress to be created automatically as the configuration templated ingress mode is "never" """ function = self._generate_runtime(self.runtime_kind) - function_name, project_name, config = compile_function_config(function) + ( + function_name, + project_name, + config, + ) = mlrun.api.crud.runtimes.nuclio.function._compile_function_config(function) service_type = "DoesNotMatter" - enrich_function_with_ingress( + mlrun.api.crud.runtimes.nuclio.helpers.enrich_function_with_ingress( config, NuclioIngressAddTemplatedIngressModes.never, service_type ) - ingresses = resolve_function_ingresses(config["spec"]) + ingresses = mlrun.api.crud.runtimes.nuclio.helpers.resolve_function_ingresses( + config["spec"] + ) assert ingresses == [] def test_nuclio_config_spec_env(self, db: Session, client: TestClient): @@ -449,7 +480,11 @@ def test_nuclio_config_spec_env(self, db: Session, client: TestClient): {"name": name2, "value": value2}, ] - function_name, project_name, config = 
compile_function_config(function) + ( + function_name, + project_name, + config, + ) = mlrun.api.crud.runtimes.nuclio.function._compile_function_config(function) for expected_env_var in expected_env_vars: assert expected_env_var in config["spec"]["env"] assert isinstance(function.spec.env[0], kubernetes.client.V1EnvVar) @@ -457,7 +492,11 @@ def test_nuclio_config_spec_env(self, db: Session, client: TestClient): # simulating sending to API - serialization through dict function = function.from_dict(function.to_dict()) - function_name, project_name, config = compile_function_config(function) + ( + function_name, + project_name, + config, + ) = mlrun.api.crud.runtimes.nuclio.function._compile_function_config(function) for expected_env_var in expected_env_vars: assert expected_env_var in config["spec"]["env"] @@ -484,7 +523,7 @@ def test_deploy_with_project_service_accounts( self, db: Session, k8s_secrets_mock: K8sSecretsMock ): k8s_secrets_mock.set_service_account_keys(self.project, "sa1", ["sa1", "sa2"]) - auth_info = mlrun.api.schemas.AuthInfo() + auth_info = mlrun.common.schemas.AuthInfo() function = self._generate_runtime(self.runtime_kind) # Need to call _build_function, since service-account enrichment is happening only on server side, before the # call to deploy_nuclio_function @@ -511,17 +550,17 @@ def test_deploy_with_security_context_enrichment( self, db: Session, k8s_secrets_mock: K8sSecretsMock ): user_unix_id = 1000 - auth_info = mlrun.api.schemas.AuthInfo(user_unix_id=user_unix_id) + auth_info = mlrun.common.schemas.AuthInfo(user_unix_id=user_unix_id) mlrun.mlconf.igz_version = "3.6" mlrun.mlconf.function.spec.security_context.enrichment_mode = ( - mlrun.api.schemas.function.SecurityContextEnrichmentModes.disabled.value + mlrun.common.schemas.function.SecurityContextEnrichmentModes.disabled.value ) function = self._generate_runtime(self.runtime_kind) _build_function(db, auth_info, function) self.assert_security_context({}) 
mlrun.mlconf.function.spec.security_context.enrichment_mode = ( - mlrun.api.schemas.function.SecurityContextEnrichmentModes.override.value + mlrun.common.schemas.function.SecurityContextEnrichmentModes.override.value ) function = self._generate_runtime(self.runtime_kind) _build_function(db, auth_info, function) @@ -537,7 +576,7 @@ def test_deploy_with_global_service_account( ): service_account_name = "default-sa" mlconf.function.spec.service_account.default = service_account_name - auth_info = mlrun.api.schemas.AuthInfo() + auth_info = mlrun.common.schemas.AuthInfo() function = self._generate_runtime(self.runtime_kind) # Need to call _build_function, since service-account enrichment is happening only on server side, before the # call to deploy_nuclio_function @@ -611,6 +650,71 @@ def test_deploy_without_image_and_build_base_image( self._assert_deploy_called_basic_config(expected_class=self.class_name) + def test_deploy_image_with_enrich_registry_prefix(self): + function = self._generate_runtime(self.runtime_kind) + function.spec.image = ".my/image:latest" + + with unittest.mock.patch( + "mlrun.utils.get_parsed_docker_registry", + return_value=["some.registry", "some-repository"], + ): + self.execute_function(function) + self._assert_deploy_called_basic_config( + expected_class=self.class_name, + expected_build_base_image="some.registry/some-repository/my/image:latest", + ) + + @pytest.mark.parametrize( + "requirements,expected_commands", + [ + (["pandas", "numpy"], ["python -m pip install pandas numpy"]), + ( + ["-r requirements.txt", "numpy"], + ["python -m pip install -r requirements.txt numpy"], + ), + (["pandas>=1.0.0, <2"], ["python -m pip install 'pandas>=1.0.0, <2'"]), + (["pandas>=1.0.0,<2"], ["python -m pip install 'pandas>=1.0.0,<2'"]), + ( + ["-r somewhere/requirements.txt"], + ["python -m pip install -r somewhere/requirements.txt"], + ), + ( + ["something @ git+https://somewhere.com/a/b.git@v0.0.0#egg=something"], + [ + "python -m pip install 'something 
@ git+https://somewhere.com/a/b.git@v0.0.0#egg=something'" + ], + ), + ], + ) + def test_deploy_function_with_requirements( + self, + requirements: list, + expected_commands: list, + db: Session, + client: TestClient, + ): + function = self._generate_runtime(self.runtime_kind) + function.with_requirements(requirements) + self.execute_function(function) + self._assert_deploy_called_basic_config( + expected_class=self.class_name, expected_build_commands=expected_commands + ) + + def test_deploy_function_with_commands_and_requirements( + self, db: Session, client: TestClient + ): + function = self._generate_runtime(self.runtime_kind) + function.with_commands(["python -m pip install scikit-learn"]) + function.with_requirements(["pandas", "numpy"]) + self.execute_function(function) + expected_commands = [ + "python -m pip install scikit-learn", + "python -m pip install pandas numpy", + ] + self._assert_deploy_called_basic_config( + expected_class=self.class_name, expected_build_commands=expected_commands + ) + def test_deploy_function_with_labels(self, db: Session, client: TestClient): labels = { "key": "value", @@ -937,43 +1041,75 @@ def test_deploy_python_decode_string_env_var_enrichment( def test_is_nuclio_version_in_range(self): mlrun.runtimes.utils.cached_nuclio_version = "1.7.2" - assert not is_nuclio_version_in_range("1.6.11", "1.7.2") - assert not is_nuclio_version_in_range("1.7.0", "1.3.1") - assert not is_nuclio_version_in_range("1.7.3", "1.8.5") - assert not is_nuclio_version_in_range("1.7.2", "1.7.2") - assert is_nuclio_version_in_range("1.7.2", "1.7.3") - assert is_nuclio_version_in_range("1.7.0", "1.7.3") - assert is_nuclio_version_in_range("1.5.5", "1.7.3") - assert is_nuclio_version_in_range("1.5.5", "2.3.4") + assert not mlrun.api.crud.runtimes.nuclio.helpers.is_nuclio_version_in_range( + "1.6.11", "1.7.2" + ) + assert not mlrun.api.crud.runtimes.nuclio.helpers.is_nuclio_version_in_range( + "1.7.0", "1.3.1" + ) + assert not 
mlrun.api.crud.runtimes.nuclio.helpers.is_nuclio_version_in_range( + "1.7.3", "1.8.5" + ) + assert not mlrun.api.crud.runtimes.nuclio.helpers.is_nuclio_version_in_range( + "1.7.2", "1.7.2" + ) + assert mlrun.api.crud.runtimes.nuclio.helpers.is_nuclio_version_in_range( + "1.7.2", "1.7.3" + ) + assert mlrun.api.crud.runtimes.nuclio.helpers.is_nuclio_version_in_range( + "1.7.0", "1.7.3" + ) + assert mlrun.api.crud.runtimes.nuclio.helpers.is_nuclio_version_in_range( + "1.5.5", "1.7.3" + ) + assert mlrun.api.crud.runtimes.nuclio.helpers.is_nuclio_version_in_range( + "1.5.5", "2.3.4" + ) # best effort - assumes compatibility mlrun.runtimes.utils.cached_nuclio_version = "" - assert is_nuclio_version_in_range("1.5.5", "2.3.4") - assert is_nuclio_version_in_range("1.7.2", "1.7.2") + assert mlrun.api.crud.runtimes.nuclio.helpers.is_nuclio_version_in_range( + "1.5.5", "2.3.4" + ) + assert mlrun.api.crud.runtimes.nuclio.helpers.is_nuclio_version_in_range( + "1.7.2", "1.7.2" + ) def test_validate_nuclio_version_compatibility(self): # nuclio version we have mlconf.nuclio_version = "1.6.10" - # validate_nuclio_version_compatibility receives the min nuclio version required - assert not validate_nuclio_version_compatibility("1.6.11") - assert not validate_nuclio_version_compatibility("1.5.9", "1.6.11") - assert not validate_nuclio_version_compatibility("1.6.11", "1.5.9") - assert not validate_nuclio_version_compatibility("2.0.0") - assert validate_nuclio_version_compatibility("1.6.9") - assert validate_nuclio_version_compatibility("1.5.9") + # mlrun.runtimes.function.validate_nuclio_version_compatibility receives the min nuclio version required + assert not mlrun.runtimes.function.validate_nuclio_version_compatibility( + "1.6.11" + ) + assert not mlrun.runtimes.function.validate_nuclio_version_compatibility( + "1.5.9", "1.6.11" + ) + assert not mlrun.runtimes.function.validate_nuclio_version_compatibility( + "1.6.11", "1.5.9" + ) + assert not 
mlrun.runtimes.function.validate_nuclio_version_compatibility( + "2.0.0" + ) + assert mlrun.runtimes.function.validate_nuclio_version_compatibility("1.6.9") + assert mlrun.runtimes.function.validate_nuclio_version_compatibility("1.5.9") mlconf.nuclio_version = "2.0.0" - assert validate_nuclio_version_compatibility("1.6.11") - assert validate_nuclio_version_compatibility("1.5.9", "1.6.11") + assert mlrun.runtimes.function.validate_nuclio_version_compatibility("1.6.11") + assert mlrun.runtimes.function.validate_nuclio_version_compatibility( + "1.5.9", "1.6.11" + ) # best effort - assumes compatibility mlconf.nuclio_version = "" - assert validate_nuclio_version_compatibility("1.6.11") - assert validate_nuclio_version_compatibility("1.5.9", "1.6.11") + assert mlrun.runtimes.function.validate_nuclio_version_compatibility("1.6.11") + assert mlrun.runtimes.function.validate_nuclio_version_compatibility( + "1.5.9", "1.6.11" + ) with pytest.raises(ValueError): - validate_nuclio_version_compatibility("") + mlrun.runtimes.function.validate_nuclio_version_compatibility("") def test_min_nuclio_versions_decorator_failure(self): mlconf.nuclio_version = "1.6.10" @@ -984,7 +1120,7 @@ def test_min_nuclio_versions_decorator_failure(self): ["1.5.9", "1.6.11"], ]: - @min_nuclio_versions(*case) + @mlrun.runtimes.function.min_nuclio_versions(*case) def fail(): pytest.fail("Should not enter this function") @@ -1001,7 +1137,7 @@ def test_min_nuclio_versions_decorator_success(self): ["1.0.0", "0.9.81", "1.4.1"], ]: - @min_nuclio_versions(*case) + @mlrun.runtimes.function.min_nuclio_versions(*case) def success(): pass @@ -1237,26 +1373,30 @@ def test_deploy_function_with_image_pull_secret( if build_secret_name is not None: fn.spec.build.secret = build_secret_name - _, _, deployed_config = compile_function_config(fn) + ( + _, + _, + deployed_config, + ) = mlrun.api.crud.runtimes.nuclio.function._compile_function_config(fn) assert deployed_config["spec"].get("imagePullSecrets") == 
expected_secret_name def test_nuclio_with_preemption_mode(self): fn = self._generate_runtime(self.runtime_kind) assert fn.spec.preemption_mode == "prevent" - fn.with_preemption_mode(mlrun.api.schemas.PreemptionModes.allow.value) + fn.with_preemption_mode(mlrun.common.schemas.PreemptionModes.allow.value) assert fn.spec.preemption_mode == "allow" - fn.with_preemption_mode(mlrun.api.schemas.PreemptionModes.constrain.value) + fn.with_preemption_mode(mlrun.common.schemas.PreemptionModes.constrain.value) assert fn.spec.preemption_mode == "constrain" - fn.with_preemption_mode(mlrun.api.schemas.PreemptionModes.allow.value) + fn.with_preemption_mode(mlrun.common.schemas.PreemptionModes.allow.value) assert fn.spec.preemption_mode == "allow" mlconf.nuclio_version = "1.7.5" with pytest.raises(mlrun.errors.MLRunIncompatibleVersionError): - fn.with_preemption_mode(mlrun.api.schemas.PreemptionModes.allow.value) + fn.with_preemption_mode(mlrun.common.schemas.PreemptionModes.allow.value) mlconf.nuclio_version = "1.8.6" - fn.with_preemption_mode(mlrun.api.schemas.PreemptionModes.allow.value) + fn.with_preemption_mode(mlrun.common.schemas.PreemptionModes.allow.value) assert fn.spec.preemption_mode == "allow" def test_preemption_mode_without_preemptible_configuration( @@ -1417,11 +1557,19 @@ def test_deploy_with_service_type( if expected_ingress_host_template is None: # never - ingresses = resolve_function_ingresses(deploy_spec) + ingresses = ( + mlrun.api.crud.runtimes.nuclio.helpers.resolve_function_ingresses( + deploy_spec + ) + ) assert ingresses == [] else: - ingresses = resolve_function_ingresses(deploy_spec) + ingresses = ( + mlrun.api.crud.runtimes.nuclio.helpers.resolve_function_ingresses( + deploy_spec + ) + ) assert ingresses[0]["hostTemplate"] == expected_ingress_host_template @@ -1436,6 +1584,8 @@ def runtime_kind(self): def get_archive_spec(function, secrets): spec = nuclio.ConfigSpec() config = {} - _compile_nuclio_archive_config(spec, function, secrets) + 
mlrun.api.crud.runtimes.nuclio.helpers.compile_nuclio_archive_config( + spec, function, secrets + ) spec.merge(config) return config diff --git a/tests/api/runtimes/test_serving.py b/tests/api/runtimes/test_serving.py index ae6671eb3b29..c71023177b79 100644 --- a/tests/api/runtimes/test_serving.py +++ b/tests/api/runtimes/test_serving.py @@ -25,15 +25,11 @@ from sqlalchemy.orm import Session import mlrun.api.api.utils -import tests.api.api.utils +import mlrun.api.crud.runtimes.nuclio.function from mlrun import mlconf, new_function -from mlrun.api.utils.singletons.k8s import get_k8s +from mlrun.api.utils.singletons.k8s import get_k8s_helper from mlrun.db import SQLDB -from mlrun.runtimes.function import ( - NuclioStatus, - compile_function_config, - deploy_nuclio_function, -) +from mlrun.runtimes.function import NuclioStatus from .assets.serving_child_functions import * # noqa @@ -55,7 +51,8 @@ def class_name(self): def custom_setup_after_fixtures(self): self._mock_nuclio_deploy_config() - self._mock_vault_functionality() + # TODO: Vault: uncomment when vault returns to be relevant + # self._mock_vault_functionality() # Since most of the Serving runtime handling is done client-side, we'll mock the calls to remote-build # and instead just call the deploy_nuclio_function() API which actually performs the # deployment in this case. 
This will keep the tests' code mostly client-side oriented, but validations @@ -75,7 +72,7 @@ def custom_setup(self): @staticmethod def _mock_db_remote_deploy_functions(): def _remote_db_mock_function(func, with_mlrun, builder_env=None): - deploy_nuclio_function(func) + mlrun.api.crud.runtimes.nuclio.function.deploy_nuclio_function(func) return { "data": { "status": NuclioStatus( @@ -121,21 +118,22 @@ def _assert_deploy_spec_has_secrets_config(self, expected_secret_sources): args, _ = single_call_args deploy_spec = args[0]["spec"] - token_path = mlconf.secret_stores.vault.token_path.replace("~", "/root") azure_secret_path = mlconf.secret_stores.azure_vault.secret_path.replace( "~", "/root" ) + # TODO: Vault: uncomment when vault returns to be relevant + # token_path = mlconf.secret_stores.vault.token_path.replace("~", "/root") expected_volumes = [ - { - "volume": { - "name": "vault-secret", - "secret": { - "defaultMode": 420, - "secretName": self.vault_secret_name, - }, - }, - "volumeMount": {"name": "vault-secret", "mountPath": token_path}, - }, + # { + # "volume": { + # "name": "vault-secret", + # "secret": { + # "defaultMode": 420, + # "secretName": self.vault_secret_name, + # }, + # }, + # "volumeMount": {"name": "vault-secret", "mountPath": token_path}, + # }, { "volume": { "name": "azure-vault-secret", @@ -158,8 +156,9 @@ def _assert_deploy_spec_has_secrets_config(self, expected_secret_sources): ) expected_env = { - "MLRUN_SECRET_STORES__VAULT__ROLE": f"project:{self.project}", - "MLRUN_SECRET_STORES__VAULT__URL": mlconf.secret_stores.vault.url, + # TODO: Vault: uncomment when vault returns to be relevant + # "MLRUN_SECRET_STORES__VAULT__ROLE": f"project:{self.project}", + # "MLRUN_SECRET_STORES__VAULT__URL": mlconf.secret_stores.vault.url, # For now, just checking the variable exists, later we check specific contents "SERVING_SPEC_ENV": None, } @@ -182,10 +181,11 @@ def _generate_expected_secret_sources(self): full_inline_secrets["ENV_SECRET1"] = 
os.environ["ENV_SECRET1"] expected_secret_sources = [ {"kind": "inline", "source": full_inline_secrets}, - { - "kind": "vault", - "source": {"project": self.project, "secrets": self.vault_secrets}, - }, + # TODO: Vault: uncomment when vault returns to be relevant + # { + # "kind": "vault", + # "source": {"project": self.project, "secrets": self.vault_secrets}, + # }, { "kind": "azure_vault", "source": { @@ -212,9 +212,10 @@ def test_mock_server_secrets(self, db: Session, client: TestClient): server = function.to_mock_server() + # TODO: Vault: uncomment when vault returns to be relevant # Verify all secrets are in the context - for secret_key in self.vault_secrets: - assert server.context.get_secret(secret_key) == self.vault_secret_value + # for secret_key in self.vault_secrets: + # assert server.context.get_secret(secret_key) == self.vault_secret_value for secret_key in self.inline_secrets: assert ( server.context.get_secret(secret_key) == self.inline_secrets[secret_key] @@ -226,7 +227,9 @@ def test_mock_server_secrets(self, db: Session, client: TestClient): expected_response = [ {"inline_secret1": self.inline_secrets["inline_secret1"]}, {"ENV_SECRET1": os.environ["ENV_SECRET1"]}, - {"AWS_KEY": self.vault_secret_value}, + # TODO: Vault: uncomment when vault returns to be relevant, and replace the AWS_KEY with the current key + # {"AWS_KEY": self.vault_secret_value}, + {"AWS_KEY": None}, ] assert deepdiff.DeepDiff(resp, expected_response) == {} @@ -244,12 +247,13 @@ def test_mock_bad_step(self, db: Session, client: TestClient): server.test() def test_serving_with_secrets_remote_build(self, db: Session, client: TestClient): - orig_function = get_k8s()._get_project_secrets_raw_data - get_k8s()._get_project_secrets_raw_data = unittest.mock.Mock(return_value={}) + orig_function = get_k8s_helper()._get_project_secrets_raw_data + get_k8s_helper()._get_project_secrets_raw_data = unittest.mock.Mock( + return_value={} + ) mlrun.api.api.utils.mask_function_sensitive_data = 
unittest.mock.Mock() function = self._create_serving_function() - tests.api.api.utils.create_project(client, self.project) # Simulate a remote build by issuing client's API. Code below is taken from httpdb. req = { @@ -263,7 +267,7 @@ def test_serving_with_secrets_remote_build(self, db: Session, client: TestClient self._assert_deploy_called_basic_config(expected_class=self.class_name) - get_k8s()._get_project_secrets_raw_data = orig_function + get_k8s_helper()._get_project_secrets_raw_data = orig_function def test_child_functions_with_secrets(self, db: Session, client: TestClient): function = self._create_serving_function() @@ -315,7 +319,9 @@ def test_empty_function(self): # test simple function (no source) function = new_function("serving", kind="serving", image="mlrun/mlrun") function.set_topology("flow") - _, _, config = compile_function_config(function) + _, _, config = mlrun.api.crud.runtimes.nuclio.function._compile_function_config( + function + ) # verify the code is filled with the mlrun serving wrapper assert config["spec"]["build"]["functionSourceCode"] @@ -326,10 +332,14 @@ def test_empty_function(self): function.set_topology("flow") # mock secrets for the source (so it will not fail) - orig_function = get_k8s()._get_project_secrets_raw_data - get_k8s()._get_project_secrets_raw_data = unittest.mock.Mock(return_value={}) - _, _, config = compile_function_config(function, builder_env={}) - get_k8s()._get_project_secrets_raw_data = orig_function + orig_function = get_k8s_helper()._get_project_secrets_raw_data + get_k8s_helper()._get_project_secrets_raw_data = unittest.mock.Mock( + return_value={} + ) + _, _, config = mlrun.api.crud.runtimes.nuclio.function._compile_function_config( + function, builder_env={} + ) + get_k8s_helper()._get_project_secrets_raw_data = orig_function # verify the handler points to mlrun serving wrapper handler assert config["spec"]["handler"].startswith("mlrun.serving") diff --git a/tests/api/runtimes/test_spark.py 
b/tests/api/runtimes/test_spark.py index 0fac7386c1d8..46b24ca7d211 100644 --- a/tests/api/runtimes/test_spark.py +++ b/tests/api/runtimes/test_spark.py @@ -23,8 +23,8 @@ import pytest import sqlalchemy.orm -import mlrun.api.schemas import mlrun.api.utils.singletons.k8s +import mlrun.common.schemas import mlrun.errors import mlrun.runtimes.pod import tests.api.runtimes.base @@ -89,7 +89,7 @@ def _assert_custom_object_creation_config( expected_code: typing.Optional[str] = None, ): if assert_create_custom_object_called: - mlrun.api.utils.singletons.k8s.get_k8s().crdapi.create_namespaced_custom_object.assert_called_once() + mlrun.api.utils.singletons.k8s.get_k8s_helper().crdapi.create_namespaced_custom_object.assert_called_once() assert self._get_create_custom_object_namespace_arg() == self.namespace @@ -614,23 +614,26 @@ def test_get_offline_features( fstore.get_offline_features( fv, with_indexes=True, - entity_timestamp_column="timestamp", + timestamp_for_filtering="timestamp", engine="remote-spark", run_config=RunConfig(local=False, function=runtime, watch=False), target=ParquetTarget(), ) + self.project = "default" + self._create_project(client) + resp = fstore.get_offline_features( fv, with_indexes=True, - entity_timestamp_column="timestamp", + timestamp_for_filtering="timestamp", engine="spark", # setting watch=False, because we don't want to wait for the job to complete when running in API run_config=RunConfig(local=False, function=runtime, watch=False), target=ParquetTarget(), ) runspec = resp.run.spec.to_dict() - assert runspec == { + expected_runspec = { "parameters": { "vector_uri": "store://feature-vectors/default/my-vector", "target": { @@ -640,24 +643,35 @@ def test_get_offline_features( "max_events": 10000, "flush_after_seconds": 900, }, - "timestamp_column": "timestamp", + "entity_timestamp_column": None, "drop_columns": None, "with_indexes": True, "query": None, - "join_type": "inner", "order_by": None, + "start_time": None, + "end_time": None, + 
"timestamp_for_filtering": "timestamp", "engine_args": None, }, "outputs": [], "output_path": "v3io:///mypath", - "function": "None/my-vector-merger@e67bf7add40a6bafa25e19a1b80f3d4cc3789eff", "secret_sources": [], + "function": "None/my-vector-merger@349f744e83e1a71d8b1faf4bbf3723dc0625daed", "data_stores": [], "handler": "merge_handler", } + assert ( + deepdiff.DeepDiff( + runspec, + expected_runspec, + # excluding function attribute as it contains hash of the object, excluding this path because any change + # in the structure of the run will require to update the function hash + exclude_paths=["root['function']"], + ) + == {} + ) self.name = "my-vector-merger" - self.project = "default" expected_code = _default_merger_handler.replace( "{{{engine}}}", "SparkFeatureMerger" @@ -697,6 +711,7 @@ def test_run_with_load_source_on_run( # generate runtime and set source code to load on run runtime: mlrun.runtimes.Spark3Runtime = self._generate_runtime() runtime.metadata.name = "test-spark-runtime" + runtime.metadata.project = self.project runtime.spec.build.source = "git://github.com/mock/repo" runtime.spec.build.load_source_on_run = True # expect pre-condition error, not supported diff --git a/tests/api/test_api_states.py b/tests/api/test_api_states.py index e7d11a8909ff..e7e8e9d01877 100644 --- a/tests/api/test_api_states.py +++ b/tests/api/test_api_states.py @@ -16,50 +16,69 @@ import unittest.mock import fastapi.testclient +import pytest import sqlalchemy.orm import mlrun.api.initial_data -import mlrun.api.schemas import mlrun.api.utils.auth.verifier import mlrun.api.utils.db.alembic import mlrun.api.utils.db.backup import mlrun.api.utils.db.sqlite_migration +import mlrun.common.schemas def test_offline_state( db: sqlalchemy.orm.Session, client: fastapi.testclient.TestClient ) -> None: - mlrun.mlconf.httpdb.state = mlrun.api.schemas.APIStates.offline + mlrun.mlconf.httpdb.state = mlrun.common.schemas.APIStates.offline response = client.get("healthz") - assert 
response.status_code == http.HTTPStatus.OK.value + assert response.status_code == http.HTTPStatus.SERVICE_UNAVAILABLE.value response = client.get("projects") assert response.status_code == http.HTTPStatus.PRECONDITION_FAILED.value assert "API is in offline state" in response.text -def test_migrations_states( - db: sqlalchemy.orm.Session, client: fastapi.testclient.TestClient +@pytest.mark.parametrize( + "state, expected_healthz_status_code", + [ + ( + mlrun.common.schemas.APIStates.waiting_for_migrations, + http.HTTPStatus.OK.value, + ), + ( + mlrun.common.schemas.APIStates.migrations_in_progress, + http.HTTPStatus.OK.value, + ), + (mlrun.common.schemas.APIStates.migrations_failed, http.HTTPStatus.OK.value), + ( + mlrun.common.schemas.APIStates.waiting_for_chief, + http.HTTPStatus.SERVICE_UNAVAILABLE.value, + ), + ], +) +def test_api_states( + db: sqlalchemy.orm.Session, + client: fastapi.testclient.TestClient, + state, + expected_healthz_status_code, ) -> None: - expected_message_map = { - mlrun.api.schemas.APIStates.waiting_for_migrations: "API is waiting for migrations to be triggered", - mlrun.api.schemas.APIStates.migrations_in_progress: "Migrations are in progress", - mlrun.api.schemas.APIStates.migrations_failed: "Migrations failed", - } - for state, expected_message in expected_message_map.items(): - mlrun.mlconf.httpdb.state = state - response = client.get("healthz") - assert response.status_code == http.HTTPStatus.OK.value + mlrun.mlconf.httpdb.state = state + response = client.get("healthz") + assert response.status_code == expected_healthz_status_code - response = client.get("projects/some-project/background-tasks/some-task") - assert response.status_code == http.HTTPStatus.NOT_FOUND.value + response = client.get("projects/some-project/background-tasks/some-task") + assert response.status_code == http.HTTPStatus.NOT_FOUND.value - response = client.get("client-spec") - assert response.status_code == http.HTTPStatus.OK.value + response = 
client.get("client-spec") + assert response.status_code == http.HTTPStatus.OK.value - response = client.get("projects") - assert response.status_code == http.HTTPStatus.PRECONDITION_FAILED.value - assert expected_message in response.text + response = client.get("projects") + expected_message = mlrun.common.schemas.APIStates.description(state) + assert response.status_code == http.HTTPStatus.PRECONDITION_FAILED.value + assert ( + expected_message in response.text + ), f"Expected message: {expected_message}, actual: {response.text}" def test_init_data_migration_required_recognition(monkeypatch) -> None: @@ -176,12 +195,12 @@ def test_init_data_migration_required_recognition(monkeypatch) -> None: ) is_latest_data_version_mock.return_value = not case.get("data_migration", False) - mlrun.mlconf.httpdb.state = mlrun.api.schemas.APIStates.online + mlrun.mlconf.httpdb.state = mlrun.common.schemas.APIStates.online mlrun.api.initial_data.init_data() failure_message = f"Failed in case: {case}" assert ( mlrun.mlconf.httpdb.state - == mlrun.api.schemas.APIStates.waiting_for_migrations + == mlrun.common.schemas.APIStates.waiting_for_migrations ), failure_message # assert the api just changed state and no operation was done assert db_backup_util_mock.call_count == 0, failure_message diff --git a/tests/api/test_collect_runs_logs.py b/tests/api/test_collect_runs_logs.py index 94248fd4309e..c4ebce5eb426 100644 --- a/tests/api/test_collect_runs_logs.py +++ b/tests/api/test_collect_runs_logs.py @@ -324,17 +324,28 @@ async def test_verify_stop_logs_on_startup( log_collector = mlrun.api.utils.clients.log_collector.LogCollectorClient() project_name = "some-project" - run_uids = ["some_uid", "some_uid2", "some_uid3"] - for run_uid in run_uids: + + # iterate over some runs, for each run assign different state + run_uids_to_state = [ + ("some_uid", mlrun.runtimes.constants.RunStates.completed), + ("some_uid2", mlrun.runtimes.constants.RunStates.unknown), + ("some_uid3", 
mlrun.runtimes.constants.RunStates.completed), + ("some_uid4", mlrun.runtimes.constants.RunStates.completed), + # keep it last, as we later on omit it from the run_uids list + ("some_uid5", mlrun.runtimes.constants.RunStates.running), + ] + for run_uid, state in run_uids_to_state: _create_new_run( db, project_name, uid=run_uid, name=run_uid, kind="job", - state=mlrun.runtimes.constants.RunStates.completed, + state=state, ) + run_uids = [run_uid for run_uid, _ in run_uids_to_state] + # update requested logs field to True mlrun.api.utils.singletons.db.get_db().update_runs_requested_logs( db, run_uids, True @@ -345,7 +356,7 @@ async def test_verify_stop_logs_on_startup( requested_logs_modes=[True], only_uids=False, ) - assert len(runs) == 3 + assert len(runs) == 5 log_collector._call = unittest.mock.AsyncMock(return_value=None) @@ -355,7 +366,10 @@ async def test_verify_stop_logs_on_startup( assert log_collector._call.call_args[0][0] == "StopLogs" stop_log_request = log_collector._call.call_args[0][1] assert stop_log_request.project == project_name - assert len(stop_log_request.runUIDs) == 3 + + # one of the runs is in running state + run_uids = run_uids[: len(run_uids) - 1] + assert len(stop_log_request.runUIDs) == len(run_uids) assert ( deepdiff.DeepDiff( list(stop_log_request.runUIDs), @@ -375,7 +389,7 @@ async def test_verify_stop_logs_on_startup( requested_logs_modes=[True], only_uids=False, ) - assert len(runs) == 2 + assert len(runs) == 4 await mlrun.api.main._verify_log_collection_stopped_on_startup() @@ -383,7 +397,7 @@ async def test_verify_stop_logs_on_startup( assert log_collector._call.call_args[0][0] == "StopLogs" stop_log_request = log_collector._call.call_args[0][1] assert stop_log_request.project == project_name - assert len(stop_log_request.runUIDs) == 2 + assert len(stop_log_request.runUIDs) == 3 assert ( deepdiff.DeepDiff( list(stop_log_request.runUIDs), diff --git a/tests/api/test_initial_data.py b/tests/api/test_initial_data.py index 
72884da5352c..74f332df8846 100644 --- a/tests/api/test_initial_data.py +++ b/tests/api/test_initial_data.py @@ -24,8 +24,8 @@ import mlrun.api.db.sqldb.db import mlrun.api.db.sqldb.session import mlrun.api.initial_data -import mlrun.api.schemas import mlrun.api.utils.singletons.db +import mlrun.common.schemas def test_add_data_version_empty_db(): @@ -54,8 +54,8 @@ def test_add_data_version_non_empty_db(): # fill db db.create_project( db_session, - mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name="project-name"), + mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name="project-name"), ), ) mlrun.api.initial_data._add_initial_data(db_session) @@ -69,25 +69,43 @@ def test_perform_data_migrations_from_zero_version(): # set version to 0 db.create_data_version(db_session, "0") + # keep a reference to the original functions, so we can restore them later original_perform_version_1_data_migrations = ( mlrun.api.initial_data._perform_version_1_data_migrations ) mlrun.api.initial_data._perform_version_1_data_migrations = unittest.mock.Mock() + original_perform_version_2_data_migrations = ( + mlrun.api.initial_data._perform_version_2_data_migrations + ) + mlrun.api.initial_data._perform_version_2_data_migrations = unittest.mock.Mock() + original_perform_version_3_data_migrations = ( + mlrun.api.initial_data._perform_version_3_data_migrations + ) + mlrun.api.initial_data._perform_version_3_data_migrations = unittest.mock.Mock() + # perform migrations mlrun.api.initial_data._perform_data_migrations(db_session) - mlrun.api.initial_data._perform_version_1_data_migrations.assert_called_once() - - # calling again should trigger migrations again + # calling again should not trigger migrations again, since we're already at the latest version mlrun.api.initial_data._perform_data_migrations(db_session) mlrun.api.initial_data._perform_version_1_data_migrations.assert_called_once() + 
mlrun.api.initial_data._perform_version_2_data_migrations.assert_called_once() + mlrun.api.initial_data._perform_version_3_data_migrations.assert_called_once() + + assert db.get_current_data_version(db_session, raise_on_not_found=True) == str( + mlrun.api.initial_data.latest_data_version + ) + # restore original functions mlrun.api.initial_data._perform_version_1_data_migrations = ( original_perform_version_1_data_migrations ) - assert db.get_current_data_version(db_session, raise_on_not_found=True) == str( - mlrun.api.initial_data.latest_data_version + mlrun.api.initial_data._perform_version_2_data_migrations = ( + original_perform_version_2_data_migrations + ) + mlrun.api.initial_data._perform_version_3_data_migrations = ( + original_perform_version_3_data_migrations ) @@ -122,8 +140,8 @@ def test_resolve_current_data_version_before_and_after_projects(table_exists, db # fill db db.create_project( db_session, - mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name="project-name"), + mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name="project-name"), ), ) assert mlrun.api.initial_data._resolve_current_data_version(db, db_session) == 1 @@ -135,11 +153,10 @@ def _initialize_db_without_migrations() -> typing.Tuple[ ]: dsn = "sqlite:///:memory:?check_same_thread=false" mlrun.mlconf.httpdb.dsn = dsn - mlrun.api.db.sqldb.session._init_engine(dsn) - + mlrun.api.db.sqldb.session._init_engine(dsn=dsn) mlrun.api.utils.singletons.db.initialize_db() - db_session = mlrun.api.db.sqldb.session.create_session() + db_session = mlrun.api.db.sqldb.session.create_session(dsn=dsn) db = mlrun.api.db.sqldb.db.SQLDB(dsn) db.initialize(db_session) - mlrun.api.db.init_db.init_db(db_session) + mlrun.api.db.init_db.init_db() return db, db_session diff --git a/tests/api/utils/auth/providers/test_opa.py b/tests/api/utils/auth/providers/test_opa.py index 84087ad62688..185c86eeeae3 100644 --- a/tests/api/utils/auth/providers/test_opa.py +++ 
b/tests/api/utils/auth/providers/test_opa.py @@ -19,8 +19,8 @@ import deepdiff import pytest -import mlrun.api.schemas import mlrun.api.utils.auth.providers.opa +import mlrun.common.schemas import mlrun.config import mlrun.errors @@ -74,8 +74,8 @@ async def test_query_permissions_success( opa_provider: mlrun.api.utils.auth.providers.opa.Provider, ): resource = "/projects/project-name/functions/function-name" - action = mlrun.api.schemas.AuthorizationAction.create - auth_info = mlrun.api.schemas.AuthInfo( + action = mlrun.common.schemas.AuthorizationAction.create + auth_info = mlrun.common.schemas.AuthInfo( user_id="user-id", user_group_ids=["user-group-id-1", "user-group-id-2"] ) @@ -128,8 +128,8 @@ async def test_filter_by_permission( allowed_opa_resources = [ resource["opa_resource"] for resource in expected_allowed_resources ] - action = mlrun.api.schemas.AuthorizationAction.create - auth_info = mlrun.api.schemas.AuthInfo( + action = mlrun.common.schemas.AuthorizationAction.create + auth_info = mlrun.common.schemas.AuthInfo( user_id="user-id", user_group_ids=["user-group-id-1", "user-group-id-2"] ) @@ -174,8 +174,8 @@ async def test_query_permissions_failure( requests_mock: aioresponses.aioresponses, ): resource = "/projects/project-name/functions/function-name" - action = mlrun.api.schemas.AuthorizationAction.create - auth_info = mlrun.api.schemas.AuthInfo( + action = mlrun.common.schemas.AuthorizationAction.create + auth_info = mlrun.common.schemas.AuthInfo( user_id="user-id", user_group_ids=["user-group-id-1", "user-group-id-2"] ) @@ -211,7 +211,7 @@ async def test_query_permissions_use_cache( permission_query_path: str, opa_provider: mlrun.api.utils.auth.providers.opa.Provider, ): - auth_info = mlrun.api.schemas.AuthInfo(user_id="user-id") + auth_info = mlrun.common.schemas.AuthInfo(user_id="user-id") project_name = "project-name" opa_provider.add_allowed_project_for_owner(project_name, auth_info) @@ -219,7 +219,7 @@ async def 
test_query_permissions_use_cache( assert ( await opa_provider.query_permissions( f"/projects/{project_name}/resource", - mlrun.api.schemas.AuthorizationAction.create, + mlrun.common.schemas.AuthorizationAction.create, auth_info, ) is True @@ -232,7 +232,7 @@ def test_allowed_project_owners_cache( permission_query_path: str, opa_provider: mlrun.api.utils.auth.providers.opa.Provider, ): - auth_info = mlrun.api.schemas.AuthInfo(user_id="user-id") + auth_info = mlrun.common.schemas.AuthInfo(user_id="user-id") project_name = "project-name" opa_provider.add_allowed_project_for_owner(project_name, auth_info) # ensure nothing is wrong with adding the same project twice @@ -252,7 +252,7 @@ def test_allowed_project_owners_cache( assert ( opa_provider._check_allowed_project_owners_cache( f"/projects/{project_name}/resource", - mlrun.api.schemas.AuthInfo(user_id="other-user-id"), + mlrun.common.schemas.AuthInfo(user_id="other-user-id"), ) is False ) @@ -263,7 +263,7 @@ def test_allowed_project_owners_cache_ttl_refresh( permission_query_path: str, opa_provider: mlrun.api.utils.auth.providers.opa.Provider, ): - auth_info = mlrun.api.schemas.AuthInfo(user_id="user-id") + auth_info = mlrun.common.schemas.AuthInfo(user_id="user-id") opa_provider._allowed_project_owners_cache_ttl_seconds = 1 project_name = "project-name" opa_provider.add_allowed_project_for_owner(project_name, auth_info) @@ -291,8 +291,8 @@ def test_allowed_project_owners_cache_clean_expired( permission_query_path: str, opa_provider: mlrun.api.utils.auth.providers.opa.Provider, ): - auth_info = mlrun.api.schemas.AuthInfo(user_id="user-id") - auth_info_2 = mlrun.api.schemas.AuthInfo(user_id="user-id-2") + auth_info = mlrun.common.schemas.AuthInfo(user_id="user-id") + auth_info_2 = mlrun.common.schemas.AuthInfo(user_id="user-id-2") opa_provider._allowed_project_owners_cache_ttl_seconds = 2 project_name = "project-name" project_name_2 = "project-name-2" diff --git a/tests/api/utils/clients/test_chief.py 
b/tests/api/utils/clients/test_chief.py index 8e9d64171632..fba3c932fa5e 100644 --- a/tests/api/utils/clients/test_chief.py +++ b/tests/api/utils/clients/test_chief.py @@ -25,8 +25,8 @@ from aiohttp import ClientConnectorError from aiohttp.test_utils import TestClient, TestServer -import mlrun.api.schemas import mlrun.api.utils.clients.chief +import mlrun.common.schemas import mlrun.config import mlrun.errors from tests.common_fixtures import aioresponses_mock @@ -72,10 +72,12 @@ async def test_get_background_task_from_chief_success( assert response.status_code == http.HTTPStatus.OK background_task = _transform_response_to_background_task(response) assert background_task.metadata.name == task_name - assert background_task.status.state == mlrun.api.schemas.BackgroundTaskState.running + assert ( + background_task.status.state == mlrun.common.schemas.BackgroundTaskState.running + ) assert background_task.metadata.created == background_schema.metadata.created - background_schema.status.state = mlrun.api.schemas.BackgroundTaskState.succeeded + background_schema.status.state = mlrun.common.schemas.BackgroundTaskState.succeeded background_schema.metadata.updated = datetime.datetime.utcnow() response_body = fastapi.encoders.jsonable_encoder(background_schema) aioresponses_mock.get( @@ -86,7 +88,8 @@ async def test_get_background_task_from_chief_success( background_task = _transform_response_to_background_task(response) assert background_task.metadata.name == task_name assert ( - background_task.status.state == mlrun.api.schemas.BackgroundTaskState.succeeded + background_task.status.state + == mlrun.common.schemas.BackgroundTaskState.succeeded ) assert background_task.metadata.created == background_schema.metadata.created assert background_task.metadata.updated == background_schema.metadata.updated @@ -159,10 +162,12 @@ async def test_trigger_migration_succeeded( assert response.status_code == http.HTTPStatus.ACCEPTED background_task = 
_transform_response_to_background_task(response) assert background_task.metadata.name == task_name - assert background_task.status.state == mlrun.api.schemas.BackgroundTaskState.running + assert ( + background_task.status.state == mlrun.common.schemas.BackgroundTaskState.running + ) assert background_task.metadata.created == background_schema.metadata.created - background_schema.status.state = mlrun.api.schemas.BackgroundTaskState.succeeded + background_schema.status.state = mlrun.common.schemas.BackgroundTaskState.succeeded background_schema.metadata.updated = datetime.datetime.utcnow() response_body = fastapi.encoders.jsonable_encoder(background_schema) aioresponses_mock.post( @@ -175,7 +180,8 @@ async def test_trigger_migration_succeeded( background_task = _transform_response_to_background_task(response) assert background_task.metadata.name == task_name assert ( - background_task.status.state == mlrun.api.schemas.BackgroundTaskState.succeeded + background_task.status.state + == mlrun.common.schemas.BackgroundTaskState.succeeded ) assert background_task.metadata.created == background_schema.metadata.created assert background_task.metadata.updated == background_schema.metadata.updated @@ -229,12 +235,14 @@ async def test_trigger_migrations_chief_restarted_while_executing_migrations( assert response.status_code == http.HTTPStatus.ACCEPTED background_task = _transform_response_to_background_task(response) assert background_task.metadata.name == task_name - assert background_task.status.state == mlrun.api.schemas.BackgroundTaskState.running + assert ( + background_task.status.state == mlrun.common.schemas.BackgroundTaskState.running + ) assert background_task.metadata.created == background_schema.metadata.created # in internal background tasks, failed state is only when the background task doesn't exists in memory, # which means the api was restarted - background_schema.status.state = mlrun.api.schemas.BackgroundTaskState.failed + background_schema.status.state = 
mlrun.common.schemas.BackgroundTaskState.failed response_body = fastapi.encoders.jsonable_encoder(background_schema) aioresponses_mock.get( f"{api_url}/api/v1/background-tasks/{task_name}", payload=response_body @@ -243,29 +251,31 @@ async def test_trigger_migrations_chief_restarted_while_executing_migrations( assert response.status_code == http.HTTPStatus.OK background_task = _transform_response_to_background_task(response) assert background_task.metadata.name == task_name - assert background_task.status.state == mlrun.api.schemas.BackgroundTaskState.failed + assert ( + background_task.status.state == mlrun.common.schemas.BackgroundTaskState.failed + ) assert background_task.metadata.created == background_schema.metadata.created def _transform_response_to_background_task(response: fastapi.Response): decoded_body = response.body.decode("utf-8") body_dict = json.loads(decoded_body) - return mlrun.api.schemas.BackgroundTask(**body_dict) + return mlrun.common.schemas.BackgroundTask(**body_dict) def _generate_background_task( background_task_name, - state: mlrun.api.schemas.BackgroundTaskState = mlrun.api.schemas.BackgroundTaskState.running, -) -> mlrun.api.schemas.BackgroundTask: + state: mlrun.common.schemas.BackgroundTaskState = mlrun.common.schemas.BackgroundTaskState.running, +) -> mlrun.common.schemas.BackgroundTask: now = datetime.datetime.utcnow() - return mlrun.api.schemas.BackgroundTask( - metadata=mlrun.api.schemas.BackgroundTaskMetadata( + return mlrun.common.schemas.BackgroundTask( + metadata=mlrun.common.schemas.BackgroundTaskMetadata( name=background_task_name, created=now, updated=now, ), - status=mlrun.api.schemas.BackgroundTaskStatus(state=state.value), - spec=mlrun.api.schemas.BackgroundTaskSpec(), + status=mlrun.common.schemas.BackgroundTaskStatus(state=state.value), + spec=mlrun.common.schemas.BackgroundTaskSpec(), ) diff --git a/tests/api/utils/clients/test_iguazio.py b/tests/api/utils/clients/test_iguazio.py index 3fbe2d31ef17..b03821680981 
100644 --- a/tests/api/utils/clients/test_iguazio.py +++ b/tests/api/utils/clients/test_iguazio.py @@ -27,41 +27,14 @@ from aioresponses import CallbackResult from requests.cookies import cookiejar_from_dict -import mlrun.api.schemas import mlrun.api.utils.clients.iguazio +import mlrun.common.schemas import mlrun.config import mlrun.errors from mlrun.api.utils.asyncio import maybe_coroutine from tests.common_fixtures import aioresponses_mock -@pytest.fixture() -async def api_url() -> str: - api_url = "http://iguazio-api-url:8080" - mlrun.config.config._iguazio_api_url = api_url - return api_url - - -@pytest.fixture() -async def iguazio_client( - api_url: str, - request, -) -> mlrun.api.utils.clients.iguazio.Client: - if request.param == "async": - client = mlrun.api.utils.clients.iguazio.AsyncClient() - else: - client = mlrun.api.utils.clients.iguazio.Client() - - # force running init again so the configured api url will be used - client.__init__() - client._wait_for_job_completion_retry_interval = 0 - client._wait_for_project_terminal_state_retry_interval = 0 - - # inject the request param into client, so we can use it in tests - setattr(client, "mode", request.param) - return client - - def patch_restful_request( is_client_sync: bool, requests_mock: requests_mock_package.Mocker, @@ -441,7 +414,6 @@ def verify_list(request, context): "filter[updated_at]": [ f"[$gt]{updated_after.isoformat().split('+')[0]}Z".lower() ], - "include": ["owner"], "page[size]": [ str( mlrun.mlconf.httpdb.projects.iguazio_list_projects_default_page_size @@ -457,6 +429,15 @@ def verify_list(request, context): f"{api_url}/api/projects", json=verify_list, ) + + requests_mock.get( + f"{api_url}/api/projects/__name__/{project.metadata.name}", + json={ + "data": _build_project_response( + iguazio_client, project, with_mlrun_project=True + ) + }, + ) await maybe_coroutine( iguazio_client.list_projects( session, @@ -489,22 +470,35 @@ async def test_list_project( "annotations": 
{"annotation-key2": "annotation-value2"}, }, ] + project_objects = [ + _generate_project( + mock_project["name"], + mock_project.get("description", ""), + mock_project.get("labels", {}), + mock_project.get("annotations", {}), + owner=mock_project.get("owner", None), + ) + for mock_project in mock_projects + ] response_body = { "data": [ _build_project_response( iguazio_client, - _generate_project( - mock_project["name"], - mock_project.get("description", ""), - mock_project.get("labels", {}), - mock_project.get("annotations", {}), - owner=mock_project.get("owner", None), - ), + mock_project, ) - for mock_project in mock_projects + for mock_project in project_objects ] } requests_mock.get(f"{api_url}/api/projects", json=response_body) + for mock_project in project_objects: + requests_mock.get( + f"{api_url}/api/projects/__name__/{mock_project.metadata.name}", + json={ + "data": _build_project_response( + iguazio_client, mock_project, with_mlrun_project=True + ) + }, + ) projects, latest_updated_at = await maybe_coroutine( iguazio_client.list_projects(None) ) @@ -641,8 +635,8 @@ async def test_create_project_minimal_project( iguazio_client: mlrun.api.utils.clients.iguazio.Client, requests_mock: requests_mock_package.Mocker, ): - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata( + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata( name="some-name", ), ) @@ -827,7 +821,7 @@ async def test_format_as_leader_project( ) assert ( deepdiff.DeepDiff( - _build_project_response(iguazio_client, project), + _build_project_response(iguazio_client, project, with_mlrun_project=True), iguazio_project.data, ignore_order=True, exclude_paths=[ @@ -856,7 +850,7 @@ def _generate_session_verification_response_headers( def _assert_auth_info_from_session_verification_mock_response_headers( - auth_info: mlrun.api.schemas.AuthInfo, response_headers: dict + auth_info: mlrun.common.schemas.AuthInfo, response_headers: dict ): 
_assert_auth_info( auth_info, @@ -869,7 +863,7 @@ def _assert_auth_info_from_session_verification_mock_response_headers( def _assert_auth_info( - auth_info: mlrun.api.schemas.AuthInfo, + auth_info: mlrun.common.schemas.AuthInfo, username: str, session: str, data_session: str, @@ -888,7 +882,7 @@ async def _create_project_and_assert( api_url: str, iguazio_client: mlrun.api.utils.clients.iguazio.Client, requests_mock: requests_mock_package.Mocker, - project: mlrun.api.schemas.Project, + project: mlrun.common.schemas.Project, ): session = "1234" job_id = "1d4c9d25-9c5c-4a34-b052-c1d3665fec5e" @@ -920,7 +914,7 @@ def _verify_deletion(project_name, session, job_id, request, context): assert request.json()["data"]["attributes"]["name"] == project_name assert ( request.headers["igz-project-deletion-strategy"] - == mlrun.api.schemas.DeletionStrategy.default().to_iguazio_deletion_strategy() + == mlrun.common.schemas.DeletionStrategy.default().to_iguazio_deletion_strategy() ) _verify_project_request_headers(request.headers, session) context.status_code = http.HTTPStatus.ACCEPTED.value @@ -933,7 +927,7 @@ def _verify_creation(iguazio_client, project, session, job_id, request, context) _verify_project_request_headers(request.headers, session) return { "data": _build_project_response( - iguazio_client, project, job_id, mlrun.api.schemas.ProjectState.creating + iguazio_client, project, job_id, mlrun.common.schemas.ProjectState.creating ) } @@ -961,7 +955,7 @@ def _verify_request_cookie(headers: dict, session: str): def _verify_project_request_headers(headers: dict, session: str): _verify_request_cookie(headers, session) - assert headers[mlrun.api.schemas.HeaderNames.projects_role] == "mlrun" + assert headers[mlrun.common.schemas.HeaderNames.projects_role] == "mlrun" def _mock_job_progress( @@ -1006,7 +1000,7 @@ def _generate_project( annotations=None, created=None, owner="project-owner", -) -> mlrun.api.schemas.Project: +) -> mlrun.common.schemas.Project: if labels is None: 
labels = { "some-label": "some-label-value", @@ -1015,21 +1009,21 @@ def _generate_project( annotations = { "some-annotation": "some-annotation-value", } - return mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata( + return mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata( name=name, created=created or datetime.datetime.utcnow(), labels=labels, annotations=annotations, some_extra_field="some value", ), - spec=mlrun.api.schemas.ProjectSpec( + spec=mlrun.common.schemas.ProjectSpec( description=description, - desired_state=mlrun.api.schemas.ProjectState.online, + desired_state=mlrun.common.schemas.ProjectState.online, owner=owner, some_extra_field="some value", ), - status=mlrun.api.schemas.ProjectStatus( + status=mlrun.common.schemas.ProjectStatus( some_extra_field="some value", ), ) @@ -1037,10 +1031,11 @@ def _generate_project( def _build_project_response( iguazio_client: mlrun.api.utils.clients.iguazio.Client, - project: mlrun.api.schemas.Project, + project: mlrun.common.schemas.Project, job_id: typing.Optional[str] = None, - operational_status: typing.Optional[mlrun.api.schemas.ProjectState] = None, + operational_status: typing.Optional[mlrun.common.schemas.ProjectState] = None, owner_access_key: typing.Optional[str] = None, + with_mlrun_project: bool = False, ): body = { "type": "project", @@ -1051,12 +1046,15 @@ def _build_project_response( else datetime.datetime.utcnow().isoformat(), "updated_at": datetime.datetime.utcnow().isoformat(), "admin_status": project.spec.desired_state - or mlrun.api.schemas.ProjectState.online, - "mlrun_project": iguazio_client._transform_mlrun_project_to_iguazio_mlrun_project_attribute( - project - ), + or mlrun.common.schemas.ProjectState.online, }, } + if with_mlrun_project: + body["attributes"][ + "mlrun_project" + ] = iguazio_client._transform_mlrun_project_to_iguazio_mlrun_project_attribute( + project + ) if project.spec.description: body["attributes"]["description"] = 
project.spec.description if project.spec.owner: @@ -1090,7 +1088,7 @@ def _build_project_response( def _assert_project_creation( iguazio_client: mlrun.api.utils.clients.iguazio.Client, request_body: dict, - project: mlrun.api.schemas.Project, + project: mlrun.common.schemas.Project, ): assert request_body["data"]["attributes"]["name"] == project.metadata.name assert request_body["data"]["attributes"]["description"] == project.spec.description diff --git a/tests/api/utils/clients/test_log_collector.py b/tests/api/utils/clients/test_log_collector.py index 772a0842d93a..db4ea4919dcd 100644 --- a/tests/api/utils/clients/test_log_collector.py +++ b/tests/api/utils/clients/test_log_collector.py @@ -21,8 +21,8 @@ import sqlalchemy.orm.session import mlrun -import mlrun.api.schemas import mlrun.api.utils.clients.log_collector +import mlrun.common.schemas class BaseLogCollectorResponse: @@ -67,7 +67,7 @@ def __init__(self, success, error, has_logs): mlrun.mlconf.log_collector.address = "http://localhost:8080" -mlrun.mlconf.log_collector.mode = mlrun.api.schemas.LogsCollectorMode.sidecar +mlrun.mlconf.log_collector.mode = mlrun.common.schemas.LogsCollectorMode.sidecar class TestLogCollector: @@ -150,6 +150,43 @@ async def test_get_logs( async for log in log_stream: assert log == b"" + @pytest.mark.asyncio + async def test_get_log_with_retryable_error( + self, db: sqlalchemy.orm.session.Session, client: fastapi.testclient.TestClient + ): + run_uid = "123" + project_name = "some-project" + log_collector = mlrun.api.utils.clients.log_collector.LogCollectorClient() + + # mock responses for HasLogs to return a retryable error + log_collector._call = unittest.mock.AsyncMock( + return_value=HasLogsResponse( + False, + "readdirent /var/mlrun/logs/blabla: resource temporarily unavailable", + True, + ) + ) + + log_stream = log_collector.get_logs( + run_uid=run_uid, project=project_name, raise_on_error=False + ) + async for log in log_stream: + assert log == b"" + + # mock responses 
for HasLogs to return a retryable error + log_collector._call = unittest.mock.AsyncMock( + return_value=HasLogsResponse( + False, + "I'm an error that should not be retried", + True, + ) + ) + with pytest.raises(mlrun.errors.MLRunInternalServerError): + async for log in log_collector.get_logs( + run_uid=run_uid, project=project_name + ): + assert log == b"" # should not get here + @pytest.mark.asyncio async def test_stop_logs( self, db: sqlalchemy.orm.session.Session, client: fastapi.testclient.TestClient diff --git a/tests/api/utils/clients/test_nuclio.py b/tests/api/utils/clients/test_nuclio.py index 0ab235d5242b..3dd763a8b289 100644 --- a/tests/api/utils/clients/test_nuclio.py +++ b/tests/api/utils/clients/test_nuclio.py @@ -18,8 +18,8 @@ import pytest import requests_mock as requests_mock_package -import mlrun.api.schemas import mlrun.api.utils.clients.nuclio +import mlrun.common.schemas import mlrun.config import mlrun.errors @@ -179,13 +179,13 @@ def verify_creation(request, context): requests_mock.post(f"{api_url}/api/projects", json=verify_creation) nuclio_client.create_project( None, - mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata( + mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata( name=project_name, labels=project_labels, annotations=project_annotations, ), - spec=mlrun.api.schemas.ProjectSpec(description=project_description), + spec=mlrun.common.schemas.ProjectSpec(description=project_description), ), ) @@ -230,13 +230,13 @@ def verify_store_creation(request, context): nuclio_client.store_project( None, project_name, - mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata( + mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata( name=project_name, labels=project_labels, annotations=project_annotations, ), - spec=mlrun.api.schemas.ProjectSpec(description=project_description), + spec=mlrun.common.schemas.ProjectSpec(description=project_description), ), ) @@ 
-281,13 +281,13 @@ def verify_store_update(request, context): nuclio_client.store_project( None, project_name, - mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata( + mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata( name=project_name, labels=project_labels, annotations=project_annotations, ), - spec=mlrun.api.schemas.ProjectSpec(description=project_description), + spec=mlrun.common.schemas.ProjectSpec(description=project_description), ), ) @@ -399,7 +399,7 @@ def verify_deletion(request, context): ) assert ( request.headers["x-nuclio-delete-project-strategy"] - == mlrun.api.schemas.DeletionStrategy.default().to_nuclio_deletion_strategy() + == mlrun.common.schemas.DeletionStrategy.default().to_nuclio_deletion_strategy() ) context.status_code = http.HTTPStatus.NO_CONTENT.value diff --git a/tests/api/utils/events/__init__.py b/tests/api/utils/events/__init__.py new file mode 100644 index 000000000000..33c5b3d3bd7c --- /dev/null +++ b/tests/api/utils/events/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# diff --git a/tests/api/utils/events/test_events_client.py b/tests/api/utils/events/test_events_client.py new file mode 100644 index 000000000000..56e1b680fb8e --- /dev/null +++ b/tests/api/utils/events/test_events_client.py @@ -0,0 +1,117 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest.mock + +import fastapi.testclient +import pytest +import sqlalchemy.orm + +import mlrun.api.crud +import mlrun.api.utils.clients.iguazio +import mlrun.api.utils.events.events_factory +import mlrun.common.schemas +import tests.api.conftest + + +class TestEventClient: + @pytest.mark.parametrize( + "iguazio_version", + [ + "3.5.4", + "3.5.3", + None, + ], + ) + def test_create_project_auth_secret( + self, + monkeypatch, + db: sqlalchemy.orm.Session, + client: fastapi.testclient.TestClient, + k8s_secrets_mock: tests.api.conftest.K8sSecretsMock, + iguazio_version: str, + ): + self._initialize_and_mock_client(monkeypatch, iguazio_version) + + username = "some-username" + access_key = "some-access-key" + mlrun.api.crud.Secrets().store_auth_secret( + mlrun.common.schemas.AuthSecretData( + provider=mlrun.common.schemas.SecretProviderName.kubernetes, + username=username, + access_key=access_key, + ) + ) + self._assert_client_was_called(iguazio_version) + + @pytest.mark.parametrize( + "iguazio_version", + [ + "3.5.4", + "3.5.3", + None, + ], + ) + def test_create_project_secret( + self, + monkeypatch, + db: sqlalchemy.orm.Session, + client: 
fastapi.testclient.TestClient, + k8s_secrets_mock: tests.api.conftest.K8sSecretsMock, + iguazio_version: str, + ): + self._initialize_and_mock_client(monkeypatch, iguazio_version) + + project = "project-name" + valid_secret_key = "valid-key" + valid_secret_value = "some-value-5" + provider = mlrun.common.schemas.SecretProviderName.kubernetes + key_map_secret_key = ( + mlrun.api.crud.Secrets().generate_client_key_map_project_secret_key( + mlrun.api.crud.SecretsClientType.schedules + ) + ) + mlrun.api.crud.Secrets().store_project_secrets( + project, + mlrun.common.schemas.SecretsData( + provider=provider, secrets={valid_secret_key: valid_secret_value} + ), + allow_internal_secrets=True, + key_map_secret_key=key_map_secret_key, + ) + + self._assert_client_was_called(iguazio_version) + + def _initialize_and_mock_client(self, monkeypatch, iguazio_version: str): + mlrun.mlconf.events.mode = mlrun.common.schemas.EventsModes.enabled.value + self._initialize_client(iguazio_version) + self.client.emit = unittest.mock.MagicMock() + monkeypatch.setattr( + mlrun.api.utils.events.events_factory.EventsFactory, + "get_events_client", + lambda *args, **kwargs: self.client, + ) + + def _initialize_client(self, version: str = None): + mlrun.mlconf.igz_version = version + self.client = ( + mlrun.api.utils.events.events_factory.EventsFactory.get_events_client() + ) + + def _assert_client_was_called(self, iguazio_version: str): + self.client.emit.assert_called_once() + if iguazio_version: + assert self.client.emit.call_args[0][0].description + else: + assert self.client.emit.call_args[0][0] is None diff --git a/tests/api/utils/events/test_events_factory.py b/tests/api/utils/events/test_events_factory.py new file mode 100644 index 000000000000..41a0fc4dfc7a --- /dev/null +++ b/tests/api/utils/events/test_events_factory.py @@ -0,0 +1,73 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with 
the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pytest + +import mlrun.api.utils.events.base +import mlrun.api.utils.events.events_factory +import mlrun.api.utils.events.iguazio +import mlrun.api.utils.events.nop +import mlrun.common.schemas + + +@pytest.mark.parametrize( + "events_mode,kind,igz_version,expected_error,expected_instance", + [ + ( + mlrun.common.schemas.EventsModes.disabled, + None, + None, + None, + mlrun.api.utils.events.nop.NopClient, + ), + ( + mlrun.common.schemas.EventsModes.enabled, + None, + None, + None, + mlrun.api.utils.events.nop.NopClient, + ), + ( + mlrun.common.schemas.EventsModes.enabled, + mlrun.common.schemas.EventClientKinds.iguazio, + None, + mlrun.errors.MLRunInvalidArgumentError, + None, + ), + ( + mlrun.common.schemas.EventsModes.enabled, + mlrun.common.schemas.EventClientKinds.iguazio, + "3.5.3", + None, + mlrun.api.utils.events.iguazio.Client, + ), + ], +) +def test_get_events_client( + events_mode: mlrun.common.schemas.EventsModes, + kind: mlrun.common.schemas.EventClientKinds, + igz_version: str, + expected_error: mlrun.errors.MLRunBaseError, + expected_instance: mlrun.api.utils.events.base.BaseEventClient, +): + mlrun.mlconf.events.mode = events_mode.value + mlrun.mlconf.igz_version = igz_version + if expected_error: + with pytest.raises(expected_error): + mlrun.api.utils.events.events_factory.EventsFactory.get_events_client(kind) + else: + instance = ( + mlrun.api.utils.events.events_factory.EventsFactory.get_events_client(kind) + ) + assert isinstance(instance, expected_instance) diff --git 
a/tests/api/utils/projects/test_follower_member.py b/tests/api/utils/projects/test_follower_member.py index 26e13bd3cf15..a92a8cfa253c 100644 --- a/tests/api/utils/projects/test_follower_member.py +++ b/tests/api/utils/projects/test_follower_member.py @@ -22,11 +22,11 @@ import sqlalchemy.orm import mlrun.api.crud -import mlrun.api.schemas import mlrun.api.utils.projects.follower import mlrun.api.utils.projects.remotes.leader import mlrun.api.utils.singletons.db import mlrun.api.utils.singletons.project_member +import mlrun.common.schemas import mlrun.config import mlrun.errors import tests.api.conftest @@ -63,25 +63,25 @@ def test_sync_projects( ): project_nothing_changed = _generate_project(name="project-nothing-changed") project_in_creation = _generate_project( - name="project-in-creation", state=mlrun.api.schemas.ProjectState.creating + name="project-in-creation", state=mlrun.common.schemas.ProjectState.creating ) project_in_deletion = _generate_project( - name="project-in-deletion", state=mlrun.api.schemas.ProjectState.deleting + name="project-in-deletion", state=mlrun.common.schemas.ProjectState.deleting ) project_will_be_in_deleting = _generate_project( name="project-will-be-in-deleting", - state=mlrun.api.schemas.ProjectState.creating, + state=mlrun.common.schemas.ProjectState.creating, ) project_moved_to_deletion = _generate_project( name=project_will_be_in_deleting.metadata.name, - state=mlrun.api.schemas.ProjectState.deleting, + state=mlrun.common.schemas.ProjectState.deleting, ) project_will_be_offline = _generate_project( - name="project-will-be-offline", state=mlrun.api.schemas.ProjectState.online + name="project-will-be-offline", state=mlrun.common.schemas.ProjectState.online ) project_offline = _generate_project( name=project_will_be_offline.metadata.name, - state=mlrun.api.schemas.ProjectState.offline, + state=mlrun.common.schemas.ProjectState.offline, ) project_only_in_db = _generate_project(name="only-in-db") for _project in [ @@ -197,7 +197,7 @@ 
def test_patch_project( db, project.metadata.name, {"spec": {"description": patched_description}} ) expected_patched_project = _generate_project(description=patched_description) - expected_patched_project.status.state = mlrun.api.schemas.ProjectState.online + expected_patched_project.status.state = mlrun.common.schemas.ProjectState.online _assert_projects_equal(expected_patched_project, patched_project) _assert_project_in_follower(db, projects_follower, expected_patched_project) @@ -274,8 +274,8 @@ def test_list_project( project = _generate_project(name="name-1", owner=owner) archived_project = _generate_project( name="name-2", - desired_state=mlrun.api.schemas.ProjectDesiredState.archived, - state=mlrun.api.schemas.ProjectState.archived, + desired_state=mlrun.common.schemas.ProjectDesiredState.archived, + state=mlrun.common.schemas.ProjectState.archived, owner=owner, ) label_key = "key" @@ -283,8 +283,8 @@ def test_list_project( labeled_project = _generate_project(name="name-3", labels={label_key: label_value}) archived_and_labeled_project = _generate_project( name="name-4", - desired_state=mlrun.api.schemas.ProjectDesiredState.archived, - state=mlrun.api.schemas.ProjectState.archived, + desired_state=mlrun.common.schemas.ProjectDesiredState.archived, + state=mlrun.common.schemas.ProjectState.archived, labels={label_key: label_value}, ) all_projects = { @@ -309,7 +309,7 @@ def test_list_project( db, projects_follower, [archived_project, archived_and_labeled_project], - state=mlrun.api.schemas.ProjectState.archived, + state=mlrun.common.schemas.ProjectState.archived, ) # list by owner @@ -373,7 +373,7 @@ def test_list_project( db, projects_follower, [archived_and_labeled_project], - state=mlrun.api.schemas.ProjectState.archived, + state=mlrun.common.schemas.ProjectState.archived, labels=[f"{label_key}={label_value}", label_key], ) @@ -385,7 +385,7 @@ async def test_list_project_summaries( nop_leader: mlrun.api.utils.projects.remotes.leader.Member, ): project = 
_generate_project(name="name-1") - project_summary = mlrun.api.schemas.ProjectSummary( + project_summary = mlrun.common.schemas.ProjectSummary( name=project.metadata.name, files_count=4, feature_sets_count=5, @@ -423,7 +423,7 @@ async def test_list_project_summaries_fails_to_list_pipeline_runs( project_name = "project-name" _generate_project(name=project_name) mlrun.api.utils.singletons.db.get_db().list_projects = unittest.mock.Mock( - return_value=mlrun.api.schemas.ProjectsOutput(projects=[project_name]) + return_value=mlrun.common.schemas.ProjectsOutput(projects=[project_name]) ) mlrun.api.crud.projects.Projects()._list_pipelines = unittest.mock.Mock( side_effect=mlrun.errors.MLRunNotFoundError("not found") @@ -446,12 +446,12 @@ def test_list_project_leader_format( ): project = _generate_project(name="name-1") mlrun.api.utils.singletons.db.get_db().list_projects = unittest.mock.Mock( - return_value=mlrun.api.schemas.ProjectsOutput(projects=[project]) + return_value=mlrun.common.schemas.ProjectsOutput(projects=[project]) ) projects = projects_follower.list_projects( db, - format_=mlrun.api.schemas.ProjectsFormat.leader, - projects_role=mlrun.api.schemas.ProjectsRole.nop, + format_=mlrun.common.schemas.ProjectsFormat.leader, + projects_role=mlrun.common.schemas.ProjectsRole.nop, ) assert ( deepdiff.DeepDiff( @@ -466,7 +466,7 @@ def test_list_project_leader_format( def _assert_list_projects( db_session: sqlalchemy.orm.Session, projects_follower: mlrun.api.utils.projects.follower.Member, - expected_projects: typing.List[mlrun.api.schemas.Project], + expected_projects: typing.List[mlrun.common.schemas.Project], **kwargs, ): projects = projects_follower.list_projects(db_session, **kwargs) @@ -479,7 +479,7 @@ def _assert_list_projects( # assert again - with name only format projects = projects_follower.list_projects( - db_session, format_=mlrun.api.schemas.ProjectsFormat.name_only, **kwargs + db_session, format_=mlrun.common.schemas.ProjectsFormat.name_only, **kwargs ) 
assert len(projects.projects) == len(expected_projects) assert ( @@ -495,19 +495,19 @@ def _assert_list_projects( def _generate_project( name="project-name", description="some description", - desired_state=mlrun.api.schemas.ProjectDesiredState.online, - state=mlrun.api.schemas.ProjectState.online, + desired_state=mlrun.common.schemas.ProjectDesiredState.online, + state=mlrun.common.schemas.ProjectState.online, labels: typing.Optional[dict] = None, owner="some-owner", ): - return mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=name, labels=labels), - spec=mlrun.api.schemas.ProjectSpec( + return mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=name, labels=labels), + spec=mlrun.common.schemas.ProjectSpec( description=description, desired_state=desired_state, owner=owner, ), - status=mlrun.api.schemas.ProjectStatus( + status=mlrun.common.schemas.ProjectStatus( state=state, ), ) @@ -523,9 +523,9 @@ def _assert_projects_equal(project_1, project_2): ) == {} ) - assert mlrun.api.schemas.ProjectState( + assert mlrun.common.schemas.ProjectState( project_1.status.state - ) == mlrun.api.schemas.ProjectState(project_2.status.state) + ) == mlrun.common.schemas.ProjectState(project_2.status.state) def _assert_project_not_in_follower( @@ -540,7 +540,7 @@ def _assert_project_not_in_follower( def _assert_project_in_follower( db_session: sqlalchemy.orm.Session, projects_follower: mlrun.api.utils.projects.follower.Member, - project: mlrun.api.schemas.Project, + project: mlrun.common.schemas.Project, ): follower_project = projects_follower.get_project(db_session, project.metadata.name) _assert_projects_equal(project, follower_project) diff --git a/tests/api/utils/projects/test_leader_member.py b/tests/api/utils/projects/test_leader_member.py index f26b3655fd02..1973ad2a7629 100644 --- a/tests/api/utils/projects/test_leader_member.py +++ b/tests/api/utils/projects/test_leader_member.py @@ -18,10 +18,10 @@ import pytest 
import sqlalchemy.orm -import mlrun.api.schemas import mlrun.api.utils.projects.leader import mlrun.api.utils.projects.remotes.follower import mlrun.api.utils.singletons.project_member +import mlrun.common.schemas import mlrun.config import mlrun.errors from mlrun.utils import logger @@ -72,9 +72,9 @@ def test_projects_sync_follower_project_adoption( ): project_name = "project-name" project_description = "some description" - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), - spec=mlrun.api.schemas.ProjectSpec(description=project_description), + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), + spec=mlrun.common.schemas.ProjectSpec(description=project_description), ) nop_follower.create_project( None, @@ -105,9 +105,9 @@ def test_projects_sync_mid_deletion( """ project_name = "project-name" project_description = "some description" - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), - spec=mlrun.api.schemas.ProjectSpec(description=project_description), + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), + spec=mlrun.common.schemas.ProjectSpec(description=project_description), ) projects_leader.create_project(db, project) _assert_project_in_followers( @@ -141,17 +141,17 @@ def test_projects_sync_leader_project_syncing( ): project_name = "project-name" project_description = "some description" - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), - spec=mlrun.api.schemas.ProjectSpec(description=project_description), + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), + spec=mlrun.common.schemas.ProjectSpec(description=project_description), ) enriched_project = project.copy(deep=True) # simulate project enrichment enriched_project.status.state = 
enriched_project.spec.desired_state leader_follower.create_project(None, enriched_project) invalid_project_name = "invalid_name" - invalid_project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=invalid_project_name), + invalid_project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=invalid_project_name), ) leader_follower.create_project( None, @@ -180,17 +180,19 @@ def test_projects_sync_multiple_follower_project_adoption( ): second_follower_project_name = "project-name-2" second_follower_project_description = "some description 2" - second_follower_project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=second_follower_project_name), - spec=mlrun.api.schemas.ProjectSpec( + second_follower_project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata( + name=second_follower_project_name + ), + spec=mlrun.common.schemas.ProjectSpec( description=second_follower_project_description ), ) both_followers_project_name = "project-name" both_followers_project_description = "some description" - both_followers_project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=both_followers_project_name), - spec=mlrun.api.schemas.ProjectSpec( + both_followers_project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=both_followers_project_name), + spec=mlrun.common.schemas.ProjectSpec( description=both_followers_project_description ), ) @@ -238,11 +240,11 @@ def test_create_project( ): project_name = "project-name" project_description = "some description" - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), - spec=mlrun.api.schemas.ProjectSpec( + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), + spec=mlrun.common.schemas.ProjectSpec( description=project_description, - 
desired_state=mlrun.api.schemas.ProjectState.archived, + desired_state=mlrun.common.schemas.ProjectState.archived, ), ) projects_leader.create_project( @@ -291,8 +293,8 @@ def test_create_and_store_project_failure_invalid_name( ] for case in cases: project_name = case["name"] - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), ) if case["valid"]: projects_leader.create_project( @@ -334,8 +336,8 @@ def test_ensure_project( project_name, ) - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), ) projects_leader.create_project( None, @@ -362,9 +364,9 @@ def test_store_project_creation( ): project_name = "project-name" project_description = "some description" - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), - spec=mlrun.api.schemas.ProjectSpec(description=project_description), + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), + spec=mlrun.common.schemas.ProjectSpec(description=project_description), ) _assert_no_projects_in_followers([leader_follower, nop_follower]) @@ -384,9 +386,9 @@ def test_store_project_update( ): project_name = "project-name" project_description = "some description" - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), - spec=mlrun.api.schemas.ProjectSpec(description=project_description), + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), + spec=mlrun.common.schemas.ProjectSpec(description=project_description), ) projects_leader.create_project( None, @@ -395,10 +397,10 @@ def test_store_project_update( 
_assert_project_in_followers([leader_follower, nop_follower], project) # removing description from the projects and changing desired state - updated_project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), - spec=mlrun.api.schemas.ProjectSpec( - desired_state=mlrun.api.schemas.ProjectState.archived + updated_project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), + spec=mlrun.common.schemas.ProjectSpec( + desired_state=mlrun.common.schemas.ProjectState.archived ), ) @@ -417,8 +419,8 @@ def test_patch_project( leader_follower: mlrun.api.utils.projects.remotes.follower.Member, ): project_name = "project-name" - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), ) projects_leader.create_project( None, @@ -430,7 +432,7 @@ def test_patch_project( # Adding description to the project and changing state project_description = "some description" - project_desired_state = mlrun.api.schemas.ProjectState.archived + project_desired_state = mlrun.common.schemas.ProjectState.archived projects_leader.patch_project( None, project_name, @@ -453,8 +455,8 @@ def test_store_and_patch_project_failure_conflict_body_path_name( leader_follower: mlrun.api.utils.projects.remotes.follower.Member, ): project_name = "project-name" - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), ) projects_leader.create_project( None, @@ -466,8 +468,8 @@ def test_store_and_patch_project_failure_conflict_body_path_name( projects_leader.store_project( None, project_name, - mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name="different-name"), + mlrun.common.schemas.Project( 
+ metadata=mlrun.common.schemas.ProjectMetadata(name="different-name"), ), ) with pytest.raises(mlrun.errors.MLRunConflictError): @@ -486,8 +488,8 @@ def test_delete_project( leader_follower: mlrun.api.utils.projects.remotes.follower.Member, ): project_name = "project-name" - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), ) projects_leader.create_project( None, @@ -509,8 +511,8 @@ def mock_failed_delete(*args, **kwargs): raise RuntimeError() project_name = "project-name" - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), ) projects_leader.create_project( None, @@ -534,8 +536,8 @@ def test_list_projects( leader_follower: mlrun.api.utils.projects.remotes.follower.Member, ): project_name = "project-name" - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), ) projects_leader.create_project( None, @@ -546,8 +548,8 @@ def test_list_projects( # add some project to follower nop_follower.create_project( None, - mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name="some-other-project"), + mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name="some-other-project"), ), ) @@ -565,9 +567,9 @@ def test_get_project( ): project_name = "project-name" project_description = "some description" - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=project_name), - spec=mlrun.api.schemas.ProjectSpec(description=project_description), + project = mlrun.common.schemas.Project( + 
metadata=mlrun.common.schemas.ProjectMetadata(name=project_name), + spec=mlrun.common.schemas.ProjectSpec(description=project_description), ) projects_leader.create_project( None, @@ -599,7 +601,7 @@ def _assert_no_projects_in_followers(followers): def _assert_project_in_followers( - followers, project: mlrun.api.schemas.Project, enriched=True + followers, project: mlrun.common.schemas.Project, enriched=True ): for follower in followers: assert ( diff --git a/tests/api/utils/singletons/__init__.py b/tests/api/utils/singletons/__init__.py new file mode 100644 index 000000000000..b3085be1eb56 --- /dev/null +++ b/tests/api/utils/singletons/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# diff --git a/tests/test_k8s_utils.py b/tests/api/utils/singletons/test_k8s_utils.py similarity index 92% rename from tests/test_k8s_utils.py rename to tests/api/utils/singletons/test_k8s_utils.py index 0907de0f9513..364462ed3ede 100644 --- a/tests/test_k8s_utils.py +++ b/tests/api/utils/singletons/test_k8s_utils.py @@ -16,7 +16,7 @@ import pytest -import mlrun.k8s_utils +import mlrun.api.utils.singletons.k8s import mlrun.runtimes @@ -44,7 +44,7 @@ def test_get_logger_pods_label_selector( if extra_selector: selector += f",{extra_selector}" - k8s_helper = mlrun.k8s_utils.K8sHelper(namespace, silent=True) + k8s_helper = mlrun.api.utils.singletons.k8s.K8sHelper(namespace, silent=True) k8s_helper.list_pods = unittest.mock.MagicMock() k8s_helper.get_logger_pods(project, uid, run_type) diff --git a/tests/test_builder.py b/tests/api/utils/test_builder.py similarity index 70% rename from tests/test_builder.py rename to tests/api/utils/test_builder.py index 324e9390217c..90c697a05311 100644 --- a/tests/test_builder.py +++ b/tests/api/utils/test_builder.py @@ -14,16 +14,20 @@ # import base64 import json +import os import re import unittest.mock +from contextlib import nullcontext as does_not_raise import deepdiff import pytest import mlrun -import mlrun.api.schemas +import mlrun.api.api.utils +import mlrun.api.utils.builder import mlrun.api.utils.singletons.k8s -import mlrun.builder +import mlrun.common.constants +import mlrun.common.schemas import mlrun.k8s_utils import mlrun.utils.version from mlrun.config import config @@ -34,8 +38,8 @@ def test_build_runtime_use_base_image_when_no_build(): base_image = "mlrun/ml-models" fn.build_config(base_image=base_image) assert fn.spec.image == "" - ready = mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), + ready = mlrun.api.utils.builder.build_runtime( + mlrun.common.schemas.AuthInfo(), fn, ) assert ready is True @@ -48,8 +52,8 @@ def test_build_runtime_use_image_when_no_build(): "some-function", "some-project", 
"some-tag", image=image, kind="job" ) assert fn.spec.image == image - ready = mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), + ready = mlrun.api.utils.builder.build_runtime( + mlrun.common.schemas.AuthInfo(), fn, with_mlrun=False, ) @@ -57,33 +61,20 @@ def test_build_runtime_use_image_when_no_build(): assert fn.spec.image == image -def test_build_config_with_multiple_commands(): - image = "mlrun/ml-models" - fn = mlrun.new_function( - "some-function", "some-project", "some-tag", image=image, kind="job" - ) - fn.build_config(commands=["pip install pandas", "pip install numpy"]) - assert len(fn.spec.build.commands) == 2 - - fn.build_config(commands=["pip install pandas"]) - assert len(fn.spec.build.commands) == 2 - - -def test_build_config_preserve_order(): - function = mlrun.new_function("some-function", kind="job") - # run a lot of times as order change - commands = [] - for index in range(10): - commands.append(str(index)) - # when using un-stable (doesn't preserve order) methods to make a list unique (like list(set(x))) it's random - # whether the order will be preserved, therefore run in a loop - for _ in range(100): - function.spec.build.commands = [] - function.build_config(commands=commands) - assert function.spec.build.commands == commands - - -def test_build_runtime_insecure_registries(monkeypatch): +@pytest.mark.parametrize( + "pull_mode,push_mode,secret,flags_expected", + [ + ("auto", "auto", "", True), + ("auto", "auto", "some-secret-name", False), + ("enabled", "enabled", "some-secret-name", True), + ("enabled", "enabled", "", True), + ("disabled", "disabled", "some-secret-name", False), + ("disabled", "disabled", "", False), + ], +) +def test_build_runtime_insecure_registries( + monkeypatch, pull_mode, push_mode, secret, flags_expected +): _patch_k8s_helper(monkeypatch) mlrun.mlconf.httpdb.builder.docker_registry = "registry.hub.docker.com/username" function = mlrun.new_function( @@ -96,62 +87,24 @@ def 
test_build_runtime_insecure_registries(monkeypatch): ) insecure_flags = {"--insecure", "--insecure-pull"} - for case in [ - { - "pull_mode": "auto", - "push_mode": "auto", - "secret": "", - "flags_expected": True, - }, - { - "pull_mode": "auto", - "push_mode": "auto", - "secret": "some-secret-name", - "flags_expected": False, - }, - { - "pull_mode": "enabled", - "push_mode": "enabled", - "secret": "some-secret-name", - "flags_expected": True, - }, - { - "pull_mode": "enabled", - "push_mode": "enabled", - "secret": "", - "flags_expected": True, - }, - { - "pull_mode": "disabled", - "push_mode": "disabled", - "secret": "some-secret-name", - "flags_expected": False, - }, - { - "pull_mode": "disabled", - "push_mode": "disabled", - "secret": "", - "flags_expected": False, - }, - ]: - mlrun.mlconf.httpdb.builder.insecure_pull_registry_mode = case["pull_mode"] - mlrun.mlconf.httpdb.builder.insecure_push_registry_mode = case["push_mode"] - mlrun.mlconf.httpdb.builder.docker_registry_secret = case["secret"] - mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), - function, - ) - assert ( - insecure_flags.issubset( - set( - mlrun.builder.get_k8s_helper() - .create_pod.call_args[0][0] - .pod.spec.containers[0] - .args - ) + mlrun.mlconf.httpdb.builder.insecure_pull_registry_mode = pull_mode + mlrun.mlconf.httpdb.builder.insecure_push_registry_mode = push_mode + mlrun.mlconf.httpdb.builder.docker_registry_secret = secret + mlrun.api.utils.builder.build_runtime( + mlrun.common.schemas.AuthInfo(), + function, + ) + assert ( + insecure_flags.issubset( + set( + mlrun.api.utils.singletons.k8s.get_k8s_helper() + .create_pod.call_args[0][0] + .pod.spec.containers[0] + .args ) - == case["flags_expected"] ) + == flags_expected + ) def test_build_runtime_target_image(monkeypatch): @@ -175,8 +128,8 @@ def test_build_runtime_target_image(monkeypatch): ) ) - mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), + mlrun.api.utils.builder.build_runtime( + 
mlrun.common.schemas.AuthInfo(), function, ) @@ -188,8 +141,8 @@ def test_build_runtime_target_image(monkeypatch): function.spec.build.image = ( f"{registry}/{image_name_prefix}-some-addition:{function.metadata.tag}" ) - mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), + mlrun.api.utils.builder.build_runtime( + mlrun.common.schemas.AuthInfo(), function, ) target_image = _get_target_image_from_create_pod_mock() @@ -198,11 +151,11 @@ def test_build_runtime_target_image(monkeypatch): # assert the same with the registry enrich prefix # assert we can override the target image as long as we stick to the prefix function.spec.build.image = ( - f"{mlrun.builder.IMAGE_NAME_ENRICH_REGISTRY_PREFIX}username" + f"{mlrun.common.constants.IMAGE_NAME_ENRICH_REGISTRY_PREFIX}username" f"/{image_name_prefix}-some-addition:{function.metadata.tag}" ) - mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), + mlrun.api.utils.builder.build_runtime( + mlrun.common.schemas.AuthInfo(), function, ) target_image = _get_target_image_from_create_pod_mock() @@ -213,13 +166,13 @@ def test_build_runtime_target_image(monkeypatch): # assert it raises if we don't stick to the prefix for invalid_image in [ - f"{mlrun.builder.IMAGE_NAME_ENRICH_REGISTRY_PREFIX}username/without-prefix:{function.metadata.tag}", + f"{mlrun.common.constants.IMAGE_NAME_ENRICH_REGISTRY_PREFIX}username/without-prefix:{function.metadata.tag}", f"{registry}/without-prefix:{function.metadata.tag}", ]: function.spec.build.image = invalid_image with pytest.raises(mlrun.errors.MLRunInvalidArgumentError): - mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), + mlrun.api.utils.builder.build_runtime( + mlrun.common.schemas.AuthInfo(), function, ) @@ -228,8 +181,8 @@ def test_build_runtime_target_image(monkeypatch): f"registry.hub.docker.com/some-other-username/image-not-by-prefix" f":{function.metadata.tag}" ) - mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), + 
mlrun.api.utils.builder.build_runtime( + mlrun.common.schemas.AuthInfo(), function, ) target_image = _get_target_image_from_create_pod_mock() @@ -254,8 +207,8 @@ def test_build_runtime_use_default_node_selector(monkeypatch): kind="job", requirements=["some-package"], ) - mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), + mlrun.api.utils.builder.build_runtime( + mlrun.common.schemas.AuthInfo(), function, ) assert ( @@ -287,8 +240,8 @@ def test_function_build_with_attributes_from_spec(monkeypatch): function.spec.node_name = node_name function.spec.node_selector = node_selector function.spec.priority_class_name = priority_class_name - mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), + mlrun.api.utils.builder.build_runtime( + mlrun.common.schemas.AuthInfo(), function, ) assert ( @@ -324,8 +277,8 @@ def test_function_build_with_default_requests(monkeypatch): kind="job", requirements=["some-package"], ) - mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), + mlrun.api.utils.builder.build_runtime( + mlrun.common.schemas.AuthInfo(), function, ) expected_resources = {"requests": {}} @@ -345,8 +298,8 @@ def test_function_build_with_default_requests(monkeypatch): } expected_resources = {"requests": {"cpu": "25m", "memory": "1m"}} - mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), + mlrun.api.utils.builder.build_runtime( + mlrun.common.schemas.AuthInfo(), function, ) assert ( @@ -372,8 +325,8 @@ def test_function_build_with_default_requests(monkeypatch): } expected_resources = {"requests": {"cpu": "25m", "memory": "1m"}} - mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), + mlrun.api.utils.builder.build_runtime( + mlrun.common.schemas.AuthInfo(), function, ) assert ( @@ -386,16 +339,14 @@ def test_function_build_with_default_requests(monkeypatch): ) -def test_resolve_mlrun_install_command(): - pip_command = "python -m pip install" +def test_resolve_mlrun_install_command_version(): cases = [ { "test_description": "when 
mlrun_version_specifier configured, expected to install mlrun_version_specifier", "mlrun_version_specifier": "mlrun[complete] @ git+https://github.com/mlrun/mlrun@v0.10.0", "client_version": "0.9.3", "server_mlrun_version_specifier": None, - "expected_mlrun_install_command": f"{pip_command} " - f'"mlrun[complete] @ git+https://github.com/mlrun/mlrun@v0.10.0"', + "expected_mlrun_install_command_version": "mlrun[complete] @ git+https://github.com/mlrun/mlrun@v0.10.0", }, { "test_description": "when mlrun_version_specifier is not configured and the server_mlrun_version_specifier" @@ -404,7 +355,7 @@ def test_resolve_mlrun_install_command(): "mlrun_version_specifier": None, "client_version": "0.9.3", "server_mlrun_version_specifier": "mlrun[complete]==0.10.0-server-version", - "expected_mlrun_install_command": f'{pip_command} "mlrun[complete]==0.10.0-server-version"', + "expected_mlrun_install_command_version": "mlrun[complete]==0.10.0-server-version", }, { "test_description": "when client_version is specified and stable and mlrun_version_specifier and" @@ -413,7 +364,7 @@ def test_resolve_mlrun_install_command(): "mlrun_version_specifier": None, "client_version": "0.9.3", "server_mlrun_version_specifier": None, - "expected_mlrun_install_command": f'{pip_command} "mlrun[complete]==0.9.3"', + "expected_mlrun_install_command_version": "mlrun[complete]==0.9.3", }, { "test_description": "when client_version is specified and unstable and mlrun_version_specifier and" @@ -422,8 +373,8 @@ def test_resolve_mlrun_install_command(): "mlrun_version_specifier": None, "client_version": "unstable", "server_mlrun_version_specifier": None, - "expected_mlrun_install_command": f'{pip_command} "mlrun[complete] @ git+' - f'https://github.com/mlrun/mlrun@development"', + "expected_mlrun_install_command_version": "mlrun[complete] @ " + "git+https://github.com/mlrun/mlrun@development", }, { "test_description": "when only the config.version is configured and unstable," @@ -432,8 +383,8 @@ def 
test_resolve_mlrun_install_command(): "client_version": None, "server_mlrun_version_specifier": None, "version": "unstable", - "expected_mlrun_install_command": f'{pip_command} "mlrun[complete] @ git+' - f'https://github.com/mlrun/mlrun@development"', + "expected_mlrun_install_command_version": "mlrun[complete] @ " + "git+https://github.com/mlrun/mlrun@development", }, { "test_description": "when only the config.version is configured and stable," @@ -442,7 +393,7 @@ def test_resolve_mlrun_install_command(): "client_version": None, "server_mlrun_version_specifier": None, "version": "0.9.2", - "expected_mlrun_install_command": f'{pip_command} "mlrun[complete]==0.9.2"', + "expected_mlrun_install_command_version": "mlrun[complete]==0.9.2", }, ] for case in cases: @@ -457,9 +408,9 @@ def test_resolve_mlrun_install_command(): mlrun_version_specifier = case.get("mlrun_version_specifier") client_version = case.get("client_version") - expected_result = case.get("expected_mlrun_install_command") + expected_result = case.get("expected_mlrun_install_command_version") - result = mlrun.builder.resolve_mlrun_install_command( + result = mlrun.api.utils.builder.resolve_mlrun_install_command_version( mlrun_version_specifier, client_version ) assert ( @@ -472,22 +423,39 @@ def test_build_runtime_ecr_with_ec2_iam_policy(monkeypatch): mlrun.mlconf.httpdb.builder.docker_registry = ( "aws_account_id.dkr.ecr.region.amazonaws.com" ) - function = mlrun.new_function( - "some-function", - "some-project", - "some-tag", - image="mlrun/mlrun", + project = mlrun.new_project("some-project") + project.set_secrets( + secrets={ + "AWS_ACCESS_KEY_ID": "test-a", + "AWS_SECRET_ACCESS_KEY": "test-b", + } + ) + function = project.set_function( + "hub://describe", + name="some-function", kind="job", - requirements=["some-package"], ) - mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), + mlrun.api.utils.builder.build_runtime( + mlrun.common.schemas.AuthInfo(), function, ) pod_spec = 
_create_pod_mock_pod_spec() assert {"name": "AWS_SDK_LOAD_CONFIG", "value": "true", "value_from": None} in [ env.to_dict() for env in pod_spec.containers[0].env ] + + # ensure both envvars are set without values so they wont interfere with the iam policy + for env_name in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"]: + assert {"name": env_name, "value": "", "value_from": None} in [ + env.to_dict() for env in pod_spec.containers[0].env + ] + + # 1 for the AWS_SDK_LOAD_CONFIG=true + # 2 for the AWS_ACCESS_KEY_ID="" and AWS_SECRET_ACCESS_KEY="" + # 1 for the project secret + # == 4 + assert len(pod_spec.containers[0].env) == 4, "expected 4 env items" + assert len(pod_spec.init_containers) == 2 for init_container in pod_spec.init_containers: if init_container.name == "create-repos": @@ -530,8 +498,8 @@ def test_build_runtime_resolve_ecr_registry(monkeypatch): if case.get("tag"): image += f":{case.get('tag')}" function.spec.build.image = image - mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), + mlrun.api.utils.builder.build_runtime( + mlrun.common.schemas.AuthInfo(), function, ) pod_spec = _create_pod_mock_pod_spec() @@ -562,8 +530,8 @@ def test_build_runtime_ecr_with_aws_secret(monkeypatch): kind="job", requirements=["some-package"], ) - mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), + mlrun.api.utils.builder.build_runtime( + mlrun.common.schemas.AuthInfo(), function, ) pod_spec = _create_pod_mock_pod_spec() @@ -620,8 +588,8 @@ def test_build_runtime_ecr_with_repository(monkeypatch): kind="job", requirements=["some-package"], ) - mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), + mlrun.api.utils.builder.build_runtime( + mlrun.common.schemas.AuthInfo(), function, ) pod_spec = _create_pod_mock_pod_spec() @@ -683,7 +651,7 @@ def test_resolve_image_dest(image_target, registry, default_repository, expected config.httpdb.builder.docker_registry = default_repository config.httpdb.builder.docker_registry_secret = docker_registry_secret 
- image_target, _ = mlrun.builder._resolve_image_target_and_registry_secret( + image_target, _ = mlrun.api.utils.builder.resolve_image_target_and_registry_secret( image_target, registry ) assert image_target == expected_dest @@ -757,7 +725,7 @@ def test_resolve_registry_secret( config.httpdb.builder.docker_registry = docker_registry config.httpdb.builder.docker_registry_secret = default_secret_name - _, secret_name = mlrun.builder._resolve_image_target_and_registry_secret( + _, secret_name = mlrun.api.utils.builder.resolve_image_target_and_registry_secret( image_target, registry, secret_name ) assert secret_name == expected_secret_name @@ -778,8 +746,8 @@ def test_kaniko_pod_spec_default_service_account_enrichment(monkeypatch): image="mlrun/mlrun", kind="job", ) - mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), + mlrun.api.utils.builder.build_runtime( + mlrun.common.schemas.AuthInfo(), function, ) pod_spec = _create_pod_mock_pod_spec() @@ -802,8 +770,8 @@ def test_kaniko_pod_spec_user_service_account_enrichment(monkeypatch): ) service_account = "my-actual-sa" function.spec.service_account = service_account - mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), + mlrun.api.utils.builder.build_runtime( + mlrun.common.schemas.AuthInfo(), function, ) pod_spec = _create_pod_mock_pod_spec() @@ -811,17 +779,18 @@ def test_kaniko_pod_spec_user_service_account_enrichment(monkeypatch): @pytest.mark.parametrize( - "workdir,expected_workdir", + "clone_target_dir,expected_workdir", [ (None, r"WORKDIR .*\/tmp.*\/mlrun"), ("", r"WORKDIR .*\/tmp.*\/mlrun"), ("./path/to/code", r"WORKDIR .*\/tmp.*\/mlrun\/path\/to\/code"), + ("rel_path", r"WORKDIR .*\/tmp.*\/mlrun\/rel_path"), ("/some/workdir", r"WORKDIR \/some\/workdir"), ], ) -def test_builder_workdir(monkeypatch, workdir, expected_workdir): +def test_builder_workdir(monkeypatch, clone_target_dir, expected_workdir): _patch_k8s_helper(monkeypatch) - mlrun.builder.make_kaniko_pod = unittest.mock.MagicMock() + 
mlrun.api.utils.builder.make_kaniko_pod = unittest.mock.MagicMock() docker_registry = "default.docker.registry/default-repository" config.httpdb.builder.docker_registry = docker_registry @@ -832,17 +801,150 @@ def test_builder_workdir(monkeypatch, workdir, expected_workdir): image="mlrun/mlrun", kind="job", ) - if workdir is not None: - function.spec.workdir = workdir - function.spec.build.source = "some-source.tgz" - mlrun.builder.build_runtime( - mlrun.api.schemas.AuthInfo(), + if clone_target_dir is not None: + function.spec.clone_target_dir = clone_target_dir + function.spec.build.source = "/path/some-source.tgz" + mlrun.api.utils.builder.build_runtime( + mlrun.common.schemas.AuthInfo(), function, ) - dockerfile = mlrun.builder.make_kaniko_pod.call_args[1]["dockertext"] + dockerfile = mlrun.api.utils.builder.make_kaniko_pod.call_args[1]["dockertext"] dockerfile_lines = dockerfile.splitlines() expected_workdir_re = re.compile(expected_workdir) - assert expected_workdir_re.match(dockerfile_lines[2]) + assert expected_workdir_re.match(dockerfile_lines[1]) + + +@pytest.mark.parametrize( + "source,expectation", + [ + ("v3io://path/some-source.tar.gz", does_not_raise()), + ("/path/some-source.tar.gz", does_not_raise()), + ("/path/some-source.zip", does_not_raise()), + ( + "./relative/some-source", + pytest.raises(mlrun.errors.MLRunInvalidArgumentError), + ), + ("./", pytest.raises(mlrun.errors.MLRunInvalidArgumentError)), + ], +) +def test_builder_source(monkeypatch, source, expectation): + _patch_k8s_helper(monkeypatch) + mlrun.api.utils.builder.make_kaniko_pod = unittest.mock.MagicMock() + docker_registry = "default.docker.registry/default-repository" + config.httpdb.builder.docker_registry = docker_registry + + function = mlrun.new_function( + "some-function", + "some-project", + "some-tag", + image="mlrun/mlrun", + kind="job", + ) + + with expectation: + function.spec.build.source = source + mlrun.api.utils.builder.build_runtime( + 
mlrun.common.schemas.AuthInfo(), + function, + ) + + dockerfile = mlrun.api.utils.builder.make_kaniko_pod.call_args[1]["dockertext"] + dockerfile_lines = dockerfile.splitlines() + + expected_source = source + if "://" in source: + _, expected_source = os.path.split(source) + + if source.endswith(".zip"): + expected_output_re = re.compile( + rf"COPY {expected_source} .*/tmp.*/mlrun/source" + ) + expected_line_index = 4 + + else: + expected_output_re = re.compile(rf"ADD {expected_source} .*/tmp.*/mlrun") + expected_line_index = 2 + + assert expected_output_re.match(dockerfile_lines[expected_line_index].strip()) + + +@pytest.mark.parametrize( + "requirements, commands, with_mlrun, mlrun_version_specifier, client_version, expected_commands, " + "expected_requirements_list, expected_requirements_path", + [ + ([], [], False, None, None, [], [], ""), + ( + [], + [], + True, + None, + None, + [ + f"python -m pip install --upgrade pip{mlrun.config.config.httpdb.builder.pip_version}" + ], + ["mlrun[complete] @ git+https://github.com/mlrun/mlrun@development"], + "/empty/requirements.txt", + ), + ( + [], + ["some command"], + True, + "mlrun~=1.4", + None, + [ + "some command", + f"python -m pip install --upgrade pip{mlrun.config.config.httpdb.builder.pip_version}", + ], + ["mlrun~=1.4"], + "/empty/requirements.txt", + ), + ( + [], + [], + True, + "", + "1.4.0", + [ + f"python -m pip install --upgrade pip{mlrun.config.config.httpdb.builder.pip_version}" + ], + ["mlrun[complete]==1.4.0"], + "/empty/requirements.txt", + ), + ( + ["pandas"], + [], + True, + "", + "1.4.0", + [ + f"python -m pip install --upgrade pip{mlrun.config.config.httpdb.builder.pip_version}" + ], + ["mlrun[complete]==1.4.0", "pandas"], + "/empty/requirements.txt", + ), + (["pandas"], [], False, "", "1.4.0", [], ["pandas"], "/empty/requirements.txt"), + ], +) +def test_resolve_build_requirements( + requirements, + commands, + with_mlrun, + mlrun_version_specifier, + client_version, + expected_commands, + 
expected_requirements_list, + expected_requirements_path, +): + ( + commands, + requirements_list, + requirements_path, + ) = mlrun.api.utils.builder._resolve_build_requirements( + requirements, commands, with_mlrun, mlrun_version_specifier, client_version + ) + assert commands == expected_commands + assert requirements_list == expected_requirements_list + assert requirements_path == expected_requirements_path def _get_target_image_from_create_pod_mock(): @@ -850,7 +952,11 @@ def _get_target_image_from_create_pod_mock(): def _create_pod_mock_pod_spec(): - return mlrun.builder.get_k8s_helper().create_pod.call_args[0][0].pod.spec + return ( + mlrun.api.utils.singletons.k8s.get_k8s_helper() + .create_pod.call_args[0][0] + .pod.spec + ) def _patch_k8s_helper(monkeypatch): @@ -867,15 +973,9 @@ def _patch_k8s_helper(monkeypatch): get_k8s_helper_mock.get_project_secret_data = unittest.mock.Mock( side_effect=lambda project, keys: {"KEY": "val"} ) - monkeypatch.setattr( - mlrun.builder, "get_k8s_helper", lambda *args, **kwargs: get_k8s_helper_mock - ) - monkeypatch.setattr( - mlrun.k8s_utils, "get_k8s_helper", lambda *args, **kwargs: get_k8s_helper_mock - ) monkeypatch.setattr( mlrun.api.utils.singletons.k8s, - "get_k8s", + "get_k8s_helper", lambda *args, **kwargs: get_k8s_helper_mock, ) diff --git a/tests/api/utils/test_scheduler.py b/tests/api/utils/test_scheduler.py index fd9316d1d3a0..3537a9da3736 100644 --- a/tests/api/utils/test_scheduler.py +++ b/tests/api/utils/test_scheduler.py @@ -32,9 +32,9 @@ import mlrun.api.utils.auth.verifier import mlrun.api.utils.singletons.k8s import mlrun.api.utils.singletons.project_member +import mlrun.common.schemas import mlrun.errors import tests.api.conftest -from mlrun.api import schemas from mlrun.api.utils.scheduler import Scheduler from mlrun.api.utils.singletons.db import get_db from mlrun.config import config @@ -56,7 +56,10 @@ async def scheduler(db: Session) -> typing.Generator: call_counter: int = 0 
-schedule_end_time_margin = 0.5 + +# TODO: The margin will need to rise for each additional CPU-consuming operation added along the flow, +# we need to consider how to decouple in the future +schedule_end_time_margin = 0.7 async def bump_counter(): @@ -75,6 +78,19 @@ async def do_nothing(): pass +def create_project( + db: Session, project_name: str = None +) -> mlrun.common.schemas.Project: + """API tests use sql db, so we need to create the project with its schema""" + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata( + name=project_name or config.default_project + ) + ) + mlrun.api.crud.Projects().create_project(db, project) + return project + + @pytest.mark.asyncio async def test_not_skipping_delayed_schedules(db: Session, scheduler: Scheduler): global call_counter @@ -85,17 +101,17 @@ async def test_not_skipping_delayed_schedules(db: Session, scheduler: Scheduler) number_of_jobs=expected_call_counter, seconds_interval=1 ) # this way we're leaving ourselves one second to create the schedule preventing transient test failure - cron_trigger = schemas.ScheduleCronTrigger( + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger( second="*/1", start_date=start_date, end_date=end_date ) schedule_name = "schedule-name" project = config.default_project scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project, schedule_name, - schemas.ScheduleKinds.local_function, + mlrun.common.schemas.ScheduleKinds.local_function, bump_counter, cron_trigger, ) @@ -116,17 +132,17 @@ async def test_create_schedule(db: Session, scheduler: Scheduler): number_of_jobs=5, seconds_interval=1 ) # this way we're leaving ourselves one second to create the schedule preventing transient test failure - cron_trigger = schemas.ScheduleCronTrigger( + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger( second="*/1", start_date=start_date, end_date=end_date ) schedule_name = "schedule-name" project = 
config.default_project scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project, schedule_name, - schemas.ScheduleKinds.local_function, + mlrun.common.schemas.ScheduleKinds.local_function, bump_counter, cron_trigger, ) @@ -148,10 +164,10 @@ async def test_invoke_schedule( scheduler: Scheduler, k8s_secrets_mock: tests.api.conftest.K8sSecretsMock, ): - cron_trigger = schemas.ScheduleCronTrigger(year=1999) + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(year=1999) schedule_name = "schedule-name" project_name = config.default_project - mlrun.new_project(project_name, save=False) + create_project(db, project_name) scheduled_object = _create_mlrun_function_and_matching_scheduled_object( db, project_name ) @@ -159,22 +175,22 @@ async def test_invoke_schedule( assert len(runs) == 0 scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project_name, schedule_name, - schemas.ScheduleKinds.job, + mlrun.common.schemas.ScheduleKinds.job, scheduled_object, cron_trigger, ) runs = get_db().list_runs(db, project=project_name) assert len(runs) == 0 response_1 = await scheduler.invoke_schedule( - db, mlrun.api.schemas.AuthInfo(), project_name, schedule_name + db, mlrun.common.schemas.AuthInfo(), project_name, schedule_name ) runs = get_db().list_runs(db, project=project_name) assert len(runs) == 1 response_2 = await scheduler.invoke_schedule( - db, mlrun.api.schemas.AuthInfo(), project_name, schedule_name + db, mlrun.common.schemas.AuthInfo(), project_name, schedule_name ) runs = get_db().list_runs(db, project=project_name) assert len(runs) == 2 @@ -208,29 +224,30 @@ async def test_create_schedule_mlrun_function( k8s_secrets_mock: tests.api.conftest.K8sSecretsMock, ): + project_name = config.default_project + create_project(db, project_name) + + scheduled_object = _create_mlrun_function_and_matching_scheduled_object( + db, project_name + ) + runs = get_db().list_runs(db, 
project=project_name) + assert len(runs) == 0 + expected_call_counter = 1 start_date, end_date = _get_start_and_end_time_for_scheduled_trigger( number_of_jobs=expected_call_counter, seconds_interval=1 ) # this way we're leaving ourselves one second to create the schedule preventing transient test failure - cron_trigger = schemas.ScheduleCronTrigger( + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger( second="*/1", start_date=start_date, end_date=end_date ) schedule_name = "schedule-name" - project_name = config.default_project - mlrun.new_project(project_name, save=False) - - scheduled_object = _create_mlrun_function_and_matching_scheduled_object( - db, project_name - ) - runs = get_db().list_runs(db, project=project_name) - assert len(runs) == 0 scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project_name, schedule_name, - schemas.ScheduleKinds.job, + mlrun.common.schemas.ScheduleKinds.job, scheduled_object, cron_trigger, ) @@ -267,13 +284,13 @@ async def test_create_schedule_success_cron_trigger_validation( {"year": "2050"}, ] for index, case in enumerate(cases): - cron_trigger = schemas.ScheduleCronTrigger(**case) + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(**case) scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), "project", f"schedule-name-{index}", - schemas.ScheduleKinds.local_function, + mlrun.common.schemas.ScheduleKinds.local_function, do_nothing, cron_trigger, ) @@ -287,7 +304,7 @@ async def test_schedule_upgrade_from_scheduler_without_credentials_store( ): name = "schedule-name" project_name = config.default_project - mlrun.new_project(project_name, save=False) + create_project(db, project_name) scheduled_object = _create_mlrun_function_and_matching_scheduled_object( db, project_name @@ -297,16 +314,16 @@ async def test_schedule_upgrade_from_scheduler_without_credentials_store( start_date, end_date = 
_get_start_and_end_time_for_scheduled_trigger( number_of_jobs=expected_call_counter, seconds_interval=1 ) - cron_trigger = schemas.ScheduleCronTrigger( + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger( second="*/1", start_date=start_date, end_date=end_date ) # we're before upgrade so create a schedule with empty auth info scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project_name, name, - schemas.ScheduleKinds.job, + mlrun.common.schemas.ScheduleKinds.job, scheduled_object, cron_trigger, ) @@ -323,7 +340,7 @@ async def test_schedule_upgrade_from_scheduler_without_credentials_store( access_key = "some-access_key" mlrun.api.utils.singletons.project_member.get_project_member().get_project_owner = ( unittest.mock.Mock( - return_value=mlrun.api.schemas.ProjectOwner( + return_value=mlrun.common.schemas.ProjectOwner( username=username, access_key=access_key ) ) @@ -361,33 +378,33 @@ async def test_create_schedule_failure_too_frequent_cron_trigger( {"minute": "11,22,33,44,55,59"}, ] for case in cases: - cron_trigger = schemas.ScheduleCronTrigger(**case) + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(**case) with pytest.raises(ValueError) as excinfo: scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), "project", "schedule-name", - schemas.ScheduleKinds.local_function, + mlrun.common.schemas.ScheduleKinds.local_function, do_nothing, cron_trigger, ) - assert "Cron trigger too frequent. no more then one job" in str(excinfo.value) + assert "Cron trigger too frequent. 
no more than one job" in str(excinfo.value) @pytest.mark.asyncio async def test_create_schedule_failure_already_exists( db: Session, scheduler: Scheduler ): - cron_trigger = schemas.ScheduleCronTrigger(year="1999") + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(year="1999") schedule_name = "schedule-name" project = config.default_project scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project, schedule_name, - schemas.ScheduleKinds.local_function, + mlrun.common.schemas.ScheduleKinds.local_function, do_nothing, cron_trigger, ) @@ -398,10 +415,10 @@ async def test_create_schedule_failure_already_exists( ): scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project, schedule_name, - schemas.ScheduleKinds.local_function, + mlrun.common.schemas.ScheduleKinds.local_function, do_nothing, cron_trigger, ) @@ -415,11 +432,11 @@ async def test_validate_cron_trigger_multi_checks(db: Session, scheduler: Schedu If the limit is 10 minutes and the cron trigger configured with minute=0-45 (which means every minute, for the first 45 minutes of every hour), and the check will occur at the 44 minute of some hour, the next run time will be one minute away, but the second next run time after it, will be at the next hour 0 minute. The delta - between the two will be 15 minutes, more then 10 minutes so it will pass validation, although it actually runs + between the two will be 15 minutes, more than 10 minutes so it will pass validation, although it actually runs every minute. 
""" scheduler._min_allowed_interval = "10 minutes" - cron_trigger = schemas.ScheduleCronTrigger(minute="0-45") + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(minute="0-45") now = datetime( year=2020, month=2, @@ -431,20 +448,20 @@ async def test_validate_cron_trigger_multi_checks(db: Session, scheduler: Schedu ) with pytest.raises(ValueError) as excinfo: scheduler._validate_cron_trigger(cron_trigger, now) - assert "Cron trigger too frequent. no more then one job" in str(excinfo.value) + assert "Cron trigger too frequent. no more than one job" in str(excinfo.value) @pytest.mark.asyncio async def test_get_schedule_datetime_fields_timezone(db: Session, scheduler: Scheduler): - cron_trigger = schemas.ScheduleCronTrigger(minute="*/10") + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(minute="*/10") schedule_name = "schedule-name" project = config.default_project scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project, schedule_name, - schemas.ScheduleKinds.local_function, + mlrun.common.schemas.ScheduleKinds.local_function, do_nothing, cron_trigger, ) @@ -464,15 +481,15 @@ async def test_get_schedule(db: Session, scheduler: Scheduler): "label1": "value1", "label2": "value2", } - cron_trigger = schemas.ScheduleCronTrigger(year="1999") + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(year="1999") schedule_name = "schedule-name" project = config.default_project scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project, schedule_name, - schemas.ScheduleKinds.local_function, + mlrun.common.schemas.ScheduleKinds.local_function, do_nothing, cron_trigger, labels_1, @@ -484,7 +501,7 @@ async def test_get_schedule(db: Session, scheduler: Scheduler): schedule, project, schedule_name, - schemas.ScheduleKinds.local_function, + mlrun.common.schemas.ScheduleKinds.local_function, cron_trigger, None, labels_1, @@ -495,14 +512,14 @@ async def 
test_get_schedule(db: Session, scheduler: Scheduler): "label4": "value4", } year = 2050 - cron_trigger_2 = schemas.ScheduleCronTrigger(year=year, timezone="utc") + cron_trigger_2 = mlrun.common.schemas.ScheduleCronTrigger(year=year, timezone="utc") schedule_name_2 = "schedule-name-2" scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project, schedule_name_2, - schemas.ScheduleKinds.local_function, + mlrun.common.schemas.ScheduleKinds.local_function, do_nothing, cron_trigger_2, labels_2, @@ -513,7 +530,7 @@ async def test_get_schedule(db: Session, scheduler: Scheduler): schedule_2, project, schedule_name_2, - schemas.ScheduleKinds.local_function, + mlrun.common.schemas.ScheduleKinds.local_function, cron_trigger_2, year_datetime, labels_2, @@ -525,7 +542,7 @@ async def test_get_schedule(db: Session, scheduler: Scheduler): schedules.schedules[0], project, schedule_name, - schemas.ScheduleKinds.local_function, + mlrun.common.schemas.ScheduleKinds.local_function, cron_trigger, None, labels_1, @@ -534,7 +551,7 @@ async def test_get_schedule(db: Session, scheduler: Scheduler): schedules.schedules[1], project, schedule_name_2, - schemas.ScheduleKinds.local_function, + mlrun.common.schemas.ScheduleKinds.local_function, cron_trigger_2, year_datetime, labels_2, @@ -546,7 +563,7 @@ async def test_get_schedule(db: Session, scheduler: Scheduler): schedules.schedules[0], project, schedule_name_2, - schemas.ScheduleKinds.local_function, + mlrun.common.schemas.ScheduleKinds.local_function, cron_trigger_2, year_datetime, labels_2, @@ -555,15 +572,15 @@ async def test_get_schedule(db: Session, scheduler: Scheduler): @pytest.mark.asyncio async def test_get_schedule_next_run_time_from_db(db: Session, scheduler: Scheduler): - cron_trigger = schemas.ScheduleCronTrigger(minute="*/10") + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(minute="*/10") schedule_name = "schedule-name" project = config.default_project 
scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project, schedule_name, - schemas.ScheduleKinds.local_function, + mlrun.common.schemas.ScheduleKinds.local_function, do_nothing, cron_trigger, ) @@ -572,7 +589,7 @@ async def test_get_schedule_next_run_time_from_db(db: Session, scheduler: Schedu # simulating when running in worker mlrun.mlconf.httpdb.clusterization.role = ( - mlrun.api.schemas.ClusterizationRole.worker + mlrun.common.schemas.ClusterizationRole.worker ) worker_schedule = scheduler.get_schedule(db, project, schedule_name) assert worker_schedule.next_run_time is not None @@ -602,7 +619,7 @@ async def test_list_schedules_name_filter(db: Session, scheduler: Scheduler): {"name": "mluRn", "should_find": False}, ] - cron_trigger = schemas.ScheduleCronTrigger(minute="*/10") + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(minute="*/10") project = config.default_project expected_schedule_names = [] for case in cases: @@ -610,10 +627,10 @@ async def test_list_schedules_name_filter(db: Session, scheduler: Scheduler): should_find = case["should_find"] scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project, name, - schemas.ScheduleKinds.local_function, + mlrun.common.schemas.ScheduleKinds.local_function, do_nothing, cron_trigger, ) @@ -651,15 +668,15 @@ async def test_list_schedules_from_scheduler(db: Session, scheduler: Scheduler): @pytest.mark.asyncio async def test_delete_schedule(db: Session, scheduler: Scheduler): - cron_trigger = schemas.ScheduleCronTrigger(year="1999") + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(year="1999") schedule_name = "schedule-name" project = config.default_project scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project, schedule_name, - schemas.ScheduleKinds.local_function, + mlrun.common.schemas.ScheduleKinds.local_function, do_nothing, cron_trigger, ) @@ 
-715,17 +732,17 @@ async def test_rescheduling(db: Session, scheduler: Scheduler): start_date, end_date = _get_start_and_end_time_for_scheduled_trigger( number_of_jobs=expected_call_counter, seconds_interval=1 ) - cron_trigger = schemas.ScheduleCronTrigger( + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger( second="*/1", start_date=start_date, end_date=end_date ) schedule_name = "schedule-name" project = config.default_project scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project, schedule_name, - schemas.ScheduleKinds.local_function, + mlrun.common.schemas.ScheduleKinds.local_function, bump_counter, cron_trigger, ) @@ -758,13 +775,13 @@ async def test_rescheduling_secrets_storing( scheduled_object = _create_mlrun_function_and_matching_scheduled_object(db, project) username = "some-username" access_key = "some-user-access-key" - cron_trigger = schemas.ScheduleCronTrigger(year="1999") + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(year="1999") scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(username=username, access_key=access_key), + mlrun.common.schemas.AuthInfo(username=username, access_key=access_key), project, name, - schemas.ScheduleKinds.job, + mlrun.common.schemas.ScheduleKinds.job, scheduled_object, cron_trigger, ) @@ -805,13 +822,13 @@ async def test_schedule_crud_secrets_handling( ) access_key = "some-user-access-key" username = "some-username" - cron_trigger = schemas.ScheduleCronTrigger(year="1999") + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(year="1999") scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(username=username, access_key=access_key), + mlrun.common.schemas.AuthInfo(username=username, access_key=access_key), project, schedule_name, - schemas.ScheduleKinds.job, + mlrun.common.schemas.ScheduleKinds.job, scheduled_object, cron_trigger, ) @@ -829,7 +846,7 @@ async def test_schedule_crud_secrets_handling( # update labels 
scheduler.update_schedule( db, - mlrun.api.schemas.AuthInfo(username=username, access_key=access_key), + mlrun.common.schemas.AuthInfo(username=username, access_key=access_key), project, schedule_name, labels={"label-key": "label-value"}, @@ -864,17 +881,17 @@ async def test_schedule_access_key_generation( project = config.default_project schedule_name = "schedule-name" scheduled_object = _create_mlrun_function_and_matching_scheduled_object(db, project) - cron_trigger = schemas.ScheduleCronTrigger(year="1999") + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(year="1999") access_key = "generated-access-key" mlrun.api.utils.auth.verifier.AuthVerifier().get_or_create_access_key = ( unittest.mock.Mock(return_value=access_key) ) scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project, schedule_name, - schemas.ScheduleKinds.job, + mlrun.common.schemas.ScheduleKinds.job, scheduled_object, cron_trigger, ) @@ -889,7 +906,7 @@ async def test_schedule_access_key_generation( ) scheduler.update_schedule( db, - mlrun.api.schemas.AuthInfo( + mlrun.common.schemas.AuthInfo( access_key=mlrun.model.Credentials.generate_access_key ), project, @@ -914,15 +931,13 @@ async def test_schedule_access_key_reference_handling( project = config.default_project schedule_name = "schedule-name" scheduled_object = _create_mlrun_function_and_matching_scheduled_object(db, project) - cron_trigger = schemas.ScheduleCronTrigger(year="1999") + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(year="1999") username = "some-user-name" access_key = "some-access-key" - secret_ref = ( - mlrun.model.Credentials.secret_reference_prefix - + k8s_secrets_mock.store_auth_secret(username, access_key) - ) - auth_info = mlrun.api.schemas.AuthInfo() + mocked_secret_ref, _ = k8s_secrets_mock.store_auth_secret(username, access_key) + secret_ref = mlrun.model.Credentials.secret_reference_prefix + mocked_secret_ref + auth_info = 
mlrun.common.schemas.AuthInfo() auth_info.access_key = secret_ref scheduler.create_schedule( @@ -930,7 +945,7 @@ async def test_schedule_access_key_reference_handling( auth_info, project, schedule_name, - schemas.ScheduleKinds.job, + mlrun.common.schemas.ScheduleKinds.job, scheduled_object, cron_trigger, labels={"label1": "value1", "label2": "value2"}, @@ -952,7 +967,7 @@ async def test_schedule_convert_from_old_credentials_to_new( project = config.default_project schedule_name = "schedule-name" scheduled_object = _create_mlrun_function_and_matching_scheduled_object(db, project) - cron_trigger = schemas.ScheduleCronTrigger(year="1999") + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(year="1999") username = "some-user-name" access_key = "some-access-key" @@ -960,16 +975,16 @@ async def test_schedule_convert_from_old_credentials_to_new( # to simulate an old schedule. scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project, schedule_name, - schemas.ScheduleKinds.job, + mlrun.common.schemas.ScheduleKinds.job, scheduled_object, cron_trigger, labels={"label1": "value1", "label2": "value2"}, ) - auth_info = mlrun.api.schemas.AuthInfo(username=username, access_key=access_key) + auth_info = mlrun.common.schemas.AuthInfo(username=username, access_key=access_key) mlrun.api.utils.auth.verifier.AuthVerifier().is_jobs_auth_required = ( unittest.mock.Mock(return_value=True) ) @@ -1013,10 +1028,10 @@ async def test_update_schedule( "label3": "value3", "label4": "value4", } - inactive_cron_trigger = schemas.ScheduleCronTrigger(year="1999") + inactive_cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(year="1999") schedule_name = "schedule-name" project_name = config.default_project - mlrun.new_project(project_name, save=False) + create_project(db, project_name) scheduled_object = _create_mlrun_function_and_matching_scheduled_object( db, project_name @@ -1025,10 +1040,10 @@ async def test_update_schedule( assert 
len(runs) == 0 scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project_name, schedule_name, - schemas.ScheduleKinds.job, + mlrun.common.schemas.ScheduleKinds.job, scheduled_object, inactive_cron_trigger, labels=labels_1, @@ -1040,7 +1055,7 @@ async def test_update_schedule( schedule, project_name, schedule_name, - schemas.ScheduleKinds.job, + mlrun.common.schemas.ScheduleKinds.job, inactive_cron_trigger, None, labels_1, @@ -1049,7 +1064,7 @@ async def test_update_schedule( # update labels scheduler.update_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project_name, schedule_name, labels=labels_2, @@ -1060,7 +1075,7 @@ async def test_update_schedule( schedule, project_name, schedule_name, - schemas.ScheduleKinds.job, + mlrun.common.schemas.ScheduleKinds.job, inactive_cron_trigger, None, labels_2, @@ -1069,7 +1084,7 @@ async def test_update_schedule( # update nothing scheduler.update_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project_name, schedule_name, ) @@ -1079,7 +1094,7 @@ async def test_update_schedule( schedule, project_name, schedule_name, - schemas.ScheduleKinds.job, + mlrun.common.schemas.ScheduleKinds.job, inactive_cron_trigger, None, labels_2, @@ -1088,7 +1103,7 @@ async def test_update_schedule( # update labels to empty dict scheduler.update_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project_name, schedule_name, labels={}, @@ -1099,7 +1114,7 @@ async def test_update_schedule( schedule, project_name, schedule_name, - schemas.ScheduleKinds.job, + mlrun.common.schemas.ScheduleKinds.job, inactive_cron_trigger, None, {}, @@ -1111,14 +1126,14 @@ async def test_update_schedule( number_of_jobs=expected_call_counter, seconds_interval=1 ) # this way we're leaving ourselves one second to create the schedule preventing transient test failure - cron_trigger = schemas.ScheduleCronTrigger( + cron_trigger = 
mlrun.common.schemas.ScheduleCronTrigger( second="*/1", start_date=start_date, end_date=end_date, ) scheduler.update_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project_name, schedule_name, cron_trigger=cron_trigger, @@ -1139,7 +1154,7 @@ async def test_update_schedule( schedule, project_name, schedule_name, - schemas.ScheduleKinds.job, + mlrun.common.schemas.ScheduleKinds.job, cron_trigger, next_run_time, {}, @@ -1162,7 +1177,7 @@ async def test_update_schedule_failure_not_found_in_db( project = config.default_project with pytest.raises(mlrun.errors.MLRunNotFoundError) as excinfo: scheduler.update_schedule( - db, mlrun.api.schemas.AuthInfo(), project, schedule_name + db, mlrun.common.schemas.AuthInfo(), project, schedule_name ) assert "Schedule not found" in str(excinfo.value) @@ -1178,12 +1193,12 @@ async def test_update_schedule_failure_not_found_in_scheduler( ) # create the schedule only in the db - inactive_cron_trigger = schemas.ScheduleCronTrigger(year="1999") + inactive_cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(year="1999") get_db().create_schedule( db, project_name, schedule_name, - schemas.ScheduleKinds.job, + mlrun.common.schemas.ScheduleKinds.job, scheduled_object, inactive_cron_trigger, 1, @@ -1192,7 +1207,7 @@ async def test_update_schedule_failure_not_found_in_scheduler( # update schedule should fail since the schedule job was not created in the scheduler with pytest.raises(mlrun.errors.MLRunNotFoundError) as excinfo: scheduler.update_schedule( - db, mlrun.api.schemas.AuthInfo(), project_name, schedule_name + db, mlrun.common.schemas.AuthInfo(), project_name, schedule_name ) job_id = scheduler._resolve_job_id(project_name, schedule_name) assert ( @@ -1211,43 +1226,48 @@ async def test_update_schedule_failure_not_found_in_scheduler( [(1, 2), (2, 3), (3, 4)], ) @pytest.mark.parametrize( - "schedule_kind", [schemas.ScheduleKinds.job, schemas.ScheduleKinds.local_function] + "schedule_kind", + [ + 
mlrun.common.schemas.ScheduleKinds.job, + mlrun.common.schemas.ScheduleKinds.local_function, + ], ) async def test_schedule_job_concurrency_limit( db: Session, scheduler: Scheduler, concurrency_limit: int, run_amount: int, - schedule_kind: schemas.ScheduleKinds, + schedule_kind: mlrun.common.schemas.ScheduleKinds, k8s_secrets_mock: tests.api.conftest.K8sSecretsMock, ): global call_counter call_counter = 0 - now = datetime.now(timezone.utc) - now_plus_1_seconds = now + timedelta(seconds=1) - now_plus_5_seconds = now + timedelta(seconds=5) - cron_trigger = schemas.ScheduleCronTrigger( - second="*/1", start_date=now_plus_1_seconds, end_date=now_plus_5_seconds - ) - schedule_name = "schedule-name" project_name = config.default_project - mlrun.new_project(project_name, save=False) + create_project(db, project_name) scheduled_object = ( _create_mlrun_function_and_matching_scheduled_object( db, project_name, handler="sleep_two_seconds" ) - if schedule_kind == schemas.ScheduleKinds.job + if schedule_kind == mlrun.common.schemas.ScheduleKinds.job else bump_counter_and_wait ) runs = get_db().list_runs(db, project=project_name) assert len(runs) == 0 + now = datetime.now(timezone.utc) + now_plus_1_seconds = now + timedelta(seconds=1) + now_plus_5_seconds = now + timedelta(seconds=5) + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger( + second="*/1", start_date=now_plus_1_seconds, end_date=now_plus_5_seconds + ) + schedule_name = "schedule-name" + scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project_name, schedule_name, schedule_kind, @@ -1276,7 +1296,7 @@ async def test_schedule_job_concurrency_limit( # wait so all runs will complete await asyncio.sleep(7 - random_sleep_time) - if schedule_kind == schemas.ScheduleKinds.job: + if schedule_kind == mlrun.common.schemas.ScheduleKinds.job: runs = get_db().list_runs(db, project=project_name) assert len(runs) == run_amount else: @@ -1299,12 +1319,12 @@ async def 
test_schedule_job_next_run_time( now = datetime.now(timezone.utc) now_plus_1_seconds = now + timedelta(seconds=1) now_plus_5_seconds = now + timedelta(seconds=5) - cron_trigger = schemas.ScheduleCronTrigger( + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger( second="*/1", start_date=now_plus_1_seconds, end_date=now_plus_5_seconds ) schedule_name = "schedule-name" project_name = config.default_project - mlrun.new_project(project_name, save=False) + create_project(db, project_name) scheduled_object = _create_mlrun_function_and_matching_scheduled_object( db, project_name, handler="sleep_two_seconds" @@ -1315,10 +1335,10 @@ async def test_schedule_job_next_run_time( scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project_name, schedule_name, - schemas.ScheduleKinds.job, + mlrun.common.schemas.ScheduleKinds.job, scheduled_object, cron_trigger, concurrency_limit=1, @@ -1337,7 +1357,7 @@ async def test_schedule_job_next_run_time( # the next run time should be updated to the next second after the invocation failure schedule_invocation_timestamp = datetime.now(timezone.utc) await scheduler.invoke_schedule( - db, mlrun.api.schemas.AuthInfo(), project_name, schedule_name + db, mlrun.common.schemas.AuthInfo(), project_name, schedule_name ) runs = get_db().list_runs(db, project=project_name) @@ -1442,7 +1462,7 @@ def _assert_schedule_secrets( def _assert_schedule( - schedule: schemas.ScheduleOutput, + schedule: mlrun.common.schemas.ScheduleOutput, project, name, kind, @@ -1462,13 +1482,13 @@ def _assert_schedule( def _create_do_nothing_schedule( db: Session, scheduler: Scheduler, project: str, name: str ): - cron_trigger = schemas.ScheduleCronTrigger(year="1999") + cron_trigger = mlrun.common.schemas.ScheduleCronTrigger(year="1999") scheduler.create_schedule( db, - mlrun.api.schemas.AuthInfo(), + mlrun.common.schemas.AuthInfo(), project, name, - schemas.ScheduleKinds.local_function, + 
mlrun.common.schemas.ScheduleKinds.local_function, do_nothing, cron_trigger, ) diff --git a/tests/artifacts/test_artifacts.py b/tests/artifacts/test_artifacts.py index eed506afedd7..2ec434b8ba91 100644 --- a/tests/artifacts/test_artifacts.py +++ b/tests/artifacts/test_artifacts.py @@ -468,40 +468,6 @@ def test_resolve_body_hash_path( assert expected_target_path == target_path -def test_export_import(): - project = mlrun.new_project("log-mod", save=False) - target_project = mlrun.new_project("log-mod2", save=False) - for mode in [False, True]: - mlrun.mlconf.artifacts.generate_target_path_from_artifact_hash = mode - - model = project.log_model( - "mymod", - body=b"123", - model_file="model.pkl", - extra_data={"kk": b"456"}, - artifact_path=results_dir, - ) - - for suffix in ["yaml", "json", "zip"]: - # export the artifact to a file - model.export(f"{results_dir}/a.{suffix}") - - # import and log the artifact to the new project - artifact = target_project.import_artifact( - f"{results_dir}/a.{suffix}", f"mod-{suffix}", artifact_path=results_dir - ) - assert artifact.kind == "model" - assert artifact.metadata.key == f"mod-{suffix}" - assert artifact.metadata.project == "log-mod2" - temp_path, model_spec, extra_dataitems = mlrun.artifacts.get_model( - artifact.uri - ) - with open(temp_path, "rb") as fp: - data = fp.read() - assert data == b"123" - assert extra_dataitems["kk"].get() == b"456" - - def test_inline_body(): project = mlrun.new_project("inline", save=False) diff --git a/tests/artifacts/test_model.py b/tests/artifacts/test_model.py deleted file mode 100644 index cef8a1133540..000000000000 --- a/tests/artifacts/test_model.py +++ /dev/null @@ -1,92 +0,0 @@ -# Copyright 2018 Iguazio -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import pathlib - -import pandas as pd - -import mlrun -from mlrun.artifacts.model import ModelArtifact, get_model, update_model -from mlrun.features import Feature -from tests.conftest import results - -results_dir = f"{results}/artifacts/" - -raw_data = { - "first_name": ["Jason", "Molly", "Tina", "Jake", "Amy"], - "last_name": ["Miller", "Jacobson", "Ali", "Milner", "Cooze"], - "age": [42, 52, 36, 24, 73], - "testScore": [25, 94, 57, 62, 70], -} - -expected_inputs = [ - {"name": "last_name", "value_type": "str"}, - {"name": "first_name", "value_type": "str"}, - {"name": "age", "value_type": "int"}, -] -expected_outputs = [{"name": "testScore", "value_type": "int"}] - - -def test_infer(): - model = ModelArtifact("my-model") - df = pd.DataFrame(raw_data, columns=["last_name", "first_name", "age", "testScore"]) - model.infer_from_df(df, ["testScore"]) - assert model.inputs.to_dict() == expected_inputs, "unexpected model inputs" - assert model.outputs.to_dict() == expected_outputs, "unexpected model outputs" - assert list(model.feature_stats.keys()) == [ - "last_name", - "first_name", - "age", - "testScore", - ], "wrong stat keys" - - -def test_model_update(): - path = pathlib.Path(__file__).absolute().parent - model = ModelArtifact( - "my-model", model_dir=str(path / "assets"), model_file="model.pkl" - ) - - target_path = results_dir + "model/" - - project = mlrun.new_project("test-proj", save=False) - artifact = project.log_artifact(model, upload=True, artifact_path=target_path) - - artifact_uri = 
f"store://artifacts/{artifact.project}/{artifact.db_key}" - updated_model_spec = update_model( - artifact_uri, - parameters={"a": 1}, - metrics={"b": 2}, - inputs=[Feature(name="f1")], - outputs=[Feature(name="f2")], - feature_vector="vec", - feature_weights=[1, 2], - key_prefix="test-", - labels={"lbl": "tst"}, - write_spec_copy=False, - ) - print(updated_model_spec.to_yaml()) - - model_path, model, extra_dataitems = get_model(artifact_uri) - - assert model_path.endswith(f"model/{model.model_file}"), "illegal model path" - assert model.parameters == {"a": 1}, "wrong parameters" - assert model.metrics == {"test-b": 2}, "wrong metrics" - - assert model.inputs[0].name == "f1", "wrong inputs" - assert model.outputs[0].name == "f2", "wrong outputs" - - assert model.feature_vector == "vec", "wrong feature_vector" - assert model.feature_weights == [1, 2], "wrong feature_weights" - assert model.labels == {"lbl": "tst"}, "wrong labels" diff --git a/tests/assets/notification.json b/tests/assets/notification.json new file mode 100644 index 000000000000..c30eb56b99dd --- /dev/null +++ b/tests/assets/notification.json @@ -0,0 +1 @@ +{"slack": {"webhook": "123456"},"ipython" : {"webhook": "1234"}} diff --git a/tests/automation/package_test/assets/ignored_vulnerabilities.json b/tests/automation/package_test/assets/ignored_vulnerabilities.json index e7fbb3e5f1d2..9a50f45fc0b5 100644 --- a/tests/automation/package_test/assets/ignored_vulnerabilities.json +++ b/tests/automation/package_test/assets/ignored_vulnerabilities.json @@ -1,47 +1,46 @@ -[ - [ - "mlrun", - "<=1.1.0rc1", - "1.0.3rc2", - "Mlrun 1.1.0rc1 and prior uses a version of 'TensorFlow' (2.4.1) that has known vulnerabilities.", - "48250", - null, - null - ], - [ - "mlrun", - "<1.0.3rc1", - "1.0.0", - "Mlrun 1.0.3rc1 adds \"pillow~=9.0\" to requirements to tackle vulnerabilities.", - "49220", - null, - null - ], - [ - "mlrun", - "<1.0.3rc1", - "1.0.0", - "Mlrun 1.0.3rc1 adds \"notebook~=6.4\" to requirements to tackle 
vulnerabilities.", - "49216", - null, - null - ], - [ - "mlrun", - "<1.0.3rc1", - "1.0.0", - "Mlrun 1.0.3rc1 adds command to install security fixes in Docker base image.\r\nhttps://github.com/mlrun/mlrun/pull/1997/commits/de4c87f478f8d76dd8e46942588c81ef0d0b481e", - "49213", - null, - null - ], - [ - "kubernetes", - ">0", - "12.0.1", - "Kubernetes (python client) uses Kubernetes API, which has an unfixed vulnerability, CVE-2021-29923: Go before 1.17 does not properly consider extraneous zero characters at the beginning of an IP address octet, which (in some situations) allows attackers to bypass access control that is based on IP addresses, because of unexpected octal interpretation. This affects net.ParseIP and net.ParseCIDR. Kubernetes interprets leading zeros on IPv4 addresses as decimal to keep backwards compatibility, but users relying on parser alignment will be impacted by this CVE.\\r\\nhttps://github.com/kubernetes/kubernetes/pull/104368\\r\\nhttps://github.com/kubernetes/kubernetes/issues/108074", - "45114", - null, - null +{ + "vulnerabilities": [ + { + "vulnerability_id": "11111", + "package_name": "mlrun", + "ignored": {}, + "ignored_reason": null, + "ignored_expires": null, + "vulnerable_spec": "<1.5.0", + "all_vulnerable_specs": [ + "<1.5.0" + ], + "analyzed_version": "1.3.0", + "advisory": "Mlrun 1.3.0 uses TensorFlow' (2.4.1) which is really terrible", + "is_transitive": false, + "published_date": null, + "fixed_versions": [], + "closest_versions_without_known_vulnerabilities": [], + "resources": [], + "CVE": "CVE-2021-41496", + "severity": null + }, + { + "vulnerability_id": "22222", + "package_name": "kubernetes", + "ignored": {}, + "ignored_reason": null, + "ignored_expires": null, + "vulnerable_spec": "<2.0.0", + "all_vulnerable_specs": [ + "<2.0.0" + ], + "analyzed_version": "1.24.0", + "advisory": "Kubernetes 1.x versions have unfixed vulnerability, CVE-2021-29923 which can blow up earth", + "is_transitive": false, + "published_date": null, + 
"fixed_versions": [], + "closest_versions_without_known_vulnerabilities": [], + "resources": [], + "CVE": "CVE-666-666666", + "severity": null, + "affected_versions": [], + "more_info_url": "https://WereAllGonnaDie.com/today" + } ] -] +} \ No newline at end of file diff --git a/tests/automation/package_test/test_package_test.py b/tests/automation/package_test/test_package_test.py index e451274c6881..bee15a9ada5c 100644 --- a/tests/automation/package_test/test_package_test.py +++ b/tests/automation/package_test/test_package_test.py @@ -26,17 +26,25 @@ def test_test_requirements_vulnerabilities(): cases = [ { "output": """ -[ - [ - "fastapi", - "<0.75.2", - "0.67.0", - "Fastapi 0.75.2 updates its dependency 'ujson' ranges to include a security fix.", - "48159", - null, - null - ] -]""", + { + "vulnerabilities": [ + { + "vulnerability_id": "44716", + "package_name": "numpy", + "vulnerable_spec": "<1.22.0", + "all_vulnerable_specs": [ + "<1.22.0" + ], + "analyzed_version": "1.21.6", + "advisory": "Numpy 1.22.0 includes a fix for CVE-2021-41496", + "CVE": "CVE-2021-41496", + "severity": null, + "affected_versions": [], + "more_info_url": "https://pyup.io/v/44716/f17" + } + ] + } +""", "expected_to_fail": True, }, { diff --git a/tests/common_fixtures.py b/tests/common_fixtures.py index ef0b4808182e..462fceaf9aa9 100644 --- a/tests/common_fixtures.py +++ b/tests/common_fixtures.py @@ -35,6 +35,7 @@ import mlrun.datastore import mlrun.db import mlrun.k8s_utils +import mlrun.projects.project import mlrun.utils import mlrun.utils.singleton from mlrun.api.db.sqldb.db import SQLDB @@ -42,6 +43,7 @@ from mlrun.api.initial_data import init_data from mlrun.api.utils.singletons.db import initialize_db from mlrun.config import config +from mlrun.lists import ArtifactList from mlrun.runtimes import BaseRuntime from mlrun.runtimes.function import NuclioStatus from mlrun.runtimes.utils import global_context @@ -80,6 +82,9 @@ def config_test_base(): mlrun.datastore.store_manager._db = 
None mlrun.datastore.store_manager._stores = {} + # no need to raise error when using nop_db + mlrun.mlconf.httpdb.nop_db.raise_error = False + # remove the is_running_as_api cache, so it won't pass between tests mlrun.config._is_running_as_api = None # remove singletons in case they were changed (we don't want changes to pass between tests) @@ -91,7 +96,6 @@ def config_test_base(): mlrun.api.utils.singletons.k8s._k8s = None mlrun.api.utils.singletons.logs_dir.logs_dir = None - mlrun.k8s_utils._k8s = None mlrun.runtimes.runtime_handler_instances_cache = {} mlrun.runtimes.utils.cached_mpijob_crd_version = None mlrun.runtimes.utils.cached_nuclio_version = None @@ -99,6 +103,10 @@ def config_test_base(): # TODO: update this to "sidecar" once the default mode is changed mlrun.config.config.log_collector.mode = "legacy" + # revert change of default project after project creation + mlrun.mlconf.default_project = "default" + mlrun.projects.project.pipeline_context.set(None) + @pytest.fixture def aioresponses_mock(): @@ -116,20 +124,27 @@ def db(): db_session = None try: config.httpdb.dsn = dsn - _init_engine(dsn) + _init_engine(dsn=dsn) init_data() initialize_db() db_session = create_session() db = SQLDB(dsn) db.initialize(db_session) + config.dbpath = dsn finally: if db_session is not None: db_session.close() mlrun.api.utils.singletons.db.initialize_db(db) + mlrun.api.utils.singletons.logs_dir.initialize_logs_dir() mlrun.api.utils.singletons.project_member.initialize_project_member() return db +@pytest.fixture +def ensure_default_project() -> mlrun.projects.project.MlrunProject: + return mlrun.get_or_create_project("default") + + @pytest.fixture() def db_session() -> Generator: db_session = None @@ -141,6 +156,14 @@ def db_session() -> Generator: db_session.close() +@pytest.fixture() +def running_as_api(): + old_is_running_as_api = mlrun.config.is_running_as_api + mlrun.config.is_running_as_api = unittest.mock.Mock(return_value=True) + yield + 
mlrun.config.is_running_as_api = old_is_running_as_api + + @pytest.fixture def patch_file_forbidden(monkeypatch): class MockV3ioClient: @@ -190,50 +213,81 @@ class RunDBMock: def __init__(self): self.kind = "http" self._pipeline = None - self._function = None - self._artifact = None + self._functions = {} + self._artifacts = {} + self._project_name = None + self._project = None self._runs = {} def reset(self): - self._function = None + self._functions = {} self._pipeline = None self._project_name = None self._project = None - self._artifact = None + self._artifacts = {} # Expected to return a hash-key def store_function(self, function, name, project="", tag=None, versioned=False): - self._function = function - return "1234-1234-1234-1234" + hash_key = mlrun.utils.fill_function_hash(function, tag) + self._functions[name] = function + return hash_key + + def store_artifact(self, key, artifact, uid, iter=None, tag="", project=""): + self._artifacts[key] = artifact + return artifact + + def read_artifact(self, key, tag=None, iter=None, project=""): + return self._artifacts.get(key, None) + + def list_artifacts( + self, + name="", + project="", + tag="", + labels=None, + since=None, + until=None, + kind=None, + category=None, + iter: int = None, + best_iteration: bool = False, + as_records: bool = False, + use_tag_as_uid: bool = None, + ): + def filter_artifact(artifact): + if artifact["metadata"].get("tag", None) == tag: + return True + + return ArtifactList(filter(filter_artifact, self._artifacts.values())) def store_run(self, struct, uid, project="", iter=0): - self._runs[uid] = { - "struct": struct, - "project": project, - "iter": iter, - } + if hasattr(struct, "to_dict"): + struct = struct.to_dict() - def read_run(self, uid, project, iter=0): - return self._runs.get(uid, {}) + if project: + struct["metadata"]["project"] = project - def store_artifact(self, key, artifact, uid, iter=None, tag="", project=""): - self._artifact = artifact + if iter: + 
struct["status"]["iteration"] = iter - def read_artifact(self, key, tag=None, iter=None, project=""): - return self._artifact + self._runs[uid] = struct - def get_function(self, function, project, tag): - return { - "name": function, - "metadata": "bla", - "uid": "1234-1234-1234-1234", - "project": project, - "tag": tag, - } + def read_run(self, uid, project, iter=0): + return self._runs.get(uid, {}) + + def get_function(self, function, project, tag, hash_key=None): + if function not in self._functions: + raise mlrun.errors.MLRunNotFoundError("Function not found") + return self._functions[function] def submit_job(self, runspec, schedule=None): return {"status": {"status_text": "just a status"}} + def watch_log(self, uid, project="", watch=True, offset=0): + # mock API updated the run status to completed + self._runs[uid]["status"] = {"state": "completed"} + return "completed", 0 + def submit_pipeline( self, project, @@ -250,13 +304,21 @@ def submit_pipeline( def store_project(self, name, project): self._project_name = name + + if isinstance(project, dict): + project = mlrun.projects.MlrunProject.from_dict(project) self._project = project def get_project(self, name): if self._project_name and name == self._project_name: return self._project - else: - raise mlrun.errors.MLRunNotFoundError("Project not found") + + elif name == config.default_project and not self._project: + project = mlrun.projects.MlrunProject(name) + self.store_project(name, project) + return project + + raise mlrun.errors.MLRunNotFoundError(f"Project '{name}' not found") def remote_builder( self, @@ -266,16 +328,17 @@ def remote_builder( skip_deployed=False, builder_env=None, ): - self._function = func.to_dict() + function = func.to_dict() status = NuclioStatus( state="ready", nuclio_name="test-nuclio-name", ) + self._functions[function["metadata"]["name"]] = function return { "data": { "status": status.to_dict(), - "metadata": self._function.get("metadata"), - "spec": self._function.get("spec"), + 
"metadata": function.get("metadata"), + "spec": function.get("spec"), } } @@ -291,10 +354,12 @@ def get_builder_status( def update_run(self, updates: dict, uid, project="", iter=0): for key, value in updates.items(): - update_in(self._runs[uid]["struct"], key, value) + update_in(self._runs[uid], key, value) + + def assert_no_mount_or_creds_configured(self, function_name=None): + function = self._get_function_internal(function_name) - def assert_no_mount_or_creds_configured(self): - env_list = self._function["spec"]["env"] + env_list = function["spec"]["env"] env_params = [item["name"] for item in env_list] for env_variable in [ "V3IO_USERNAME", @@ -304,15 +369,16 @@ def assert_no_mount_or_creds_configured(self): ]: assert env_variable not in env_params - volume_mounts = self._function["spec"]["volume_mounts"] - volumes = self._function["spec"]["volumes"] + volume_mounts = function["spec"]["volume_mounts"] + volumes = function["spec"]["volumes"] assert len(volumes) == 0 assert len(volume_mounts) == 0 def assert_v3io_mount_or_creds_configured( - self, v3io_user, v3io_access_key, cred_only=False + self, v3io_user, v3io_access_key, cred_only=False, function_name=None ): - env_list = self._function["spec"]["env"] + function = self._get_function_internal(function_name) + env_list = function["spec"]["env"] env_dict = {item["name"]: item["value"] for item in env_list} expected_env = { "V3IO_USERNAME": v3io_user, @@ -323,8 +389,8 @@ def assert_v3io_mount_or_creds_configured( result.pop("dictionary_item_removed") assert result == {} - volume_mounts = self._function["spec"]["volume_mounts"] - volumes = self._function["spec"]["volumes"] + volume_mounts = function["spec"]["volume_mounts"] + volumes = function["spec"]["volumes"] if cred_only: assert len(volumes) == 0 @@ -348,8 +414,8 @@ def assert_v3io_mount_or_creds_configured( assert deepdiff.DeepDiff(volumes, expected_volumes) == {} assert deepdiff.DeepDiff(volume_mounts, expected_mounts) == {} - def 
assert_pvc_mount_configured(self, pvc_params): - function_spec = self._function["spec"] + def assert_pvc_mount_configured(self, pvc_params, function_name=None): + function_spec = self._get_function_internal(function_name)["spec"] expected_volumes = [ { @@ -367,8 +433,9 @@ def assert_pvc_mount_configured(self, pvc_params): assert deepdiff.DeepDiff(function_spec["volumes"], expected_volumes) == {} assert deepdiff.DeepDiff(function_spec["volume_mounts"], expected_mounts) == {} - def assert_s3_mount_configured(self, s3_params): - env_list = self._function["spec"]["env"] + def assert_s3_mount_configured(self, s3_params, function_name=None): + function = self._get_function_internal(function_name) + env_list = function["spec"]["env"] param_names = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"] secret_name = s3_params.get("secret_name") non_anonymous = s3_params.get("non_anonymous") @@ -393,8 +460,9 @@ def assert_s3_mount_configured(self, s3_params): expected_envs["S3_NON_ANONYMOUS"] = "true" assert expected_envs == env_dict - def assert_env_variables(self, expected_env_dict): - env_list = self._function["spec"]["env"] + def assert_env_variables(self, expected_env_dict, function_name=None): + function = self._get_function_internal(function_name) + env_list = function["spec"]["env"] env_dict = {item["name"]: item["value"] for item in env_list} for key, value in expected_env_dict.items(): @@ -402,10 +470,16 @@ def assert_env_variables(self, expected_env_dict): def verify_authorization( self, - authorization_verification_input: mlrun.api.schemas.AuthorizationVerificationInput, + authorization_verification_input: mlrun.common.schemas.AuthorizationVerificationInput, ): pass + def _get_function_internal(self, function_name: str = None): + if function_name: + return self._functions[function_name] + + return list(self._functions.values())[0] + @pytest.fixture() def rundb_mock() -> RunDBMock: @@ -415,17 +489,19 @@ def rundb_mock() -> RunDBMock: mlrun.db.get_run_db = 
unittest.mock.Mock(return_value=mock_object) mlrun.get_run_db = unittest.mock.Mock(return_value=mock_object) - orig_use_remote_api = BaseRuntime._use_remote_api orig_get_db = BaseRuntime._get_db BaseRuntime._get_db = unittest.mock.Mock(return_value=mock_object) orig_db_path = config.dbpath config.dbpath = "http://localhost:12345" + + # Create the default project to mimic real MLRun DB (the default project is always available for use): + mlrun.get_or_create_project("default") + yield mock_object # Have to revert the mocks, otherwise scheduling tests (and possibly others) are failing mlrun.db.get_run_db = orig_get_run_db mlrun.get_run_db = orig_get_run_db - BaseRuntime._use_remote_api = orig_use_remote_api BaseRuntime._get_db = orig_get_db config.dbpath = orig_db_path diff --git a/tests/datastore/test_base.py b/tests/datastore/test_base.py index de4cacdf5013..ac24656c443d 100644 --- a/tests/datastore/test_base.py +++ b/tests/datastore/test_base.py @@ -40,6 +40,13 @@ def test_http_fs_parquet_as_df(): data_item.as_df() +def test_http_fs_parquet_with_params_as_df(): + data_item = mlrun.datastore.store_manager.object( + "https://s3.wasabisys.com/iguazio/data/market-palce/aggregate/metrics.pq?param1=1¶m2=2" + ) + data_item.as_df() + + def test_s3_fs_parquet_as_df(): data_item = mlrun.datastore.store_manager.object( "s3://aws-roda-hcls-datalake/gnomad/chrm/run-DataSink0-1-part-block-0-r-00009-snappy.parquet" diff --git a/tests/feature-store/test_infer.py b/tests/feature-store/test_infer.py index b7c2cd1e564c..129cd16c7f40 100644 --- a/tests/feature-store/test_infer.py +++ b/tests/feature-store/test_infer.py @@ -107,7 +107,7 @@ def test_target_no_time_column(): ) -def test_check_permissions(): +def test_check_permissions(rundb_mock, monkeypatch): data = pd.DataFrame( { "time_stamp": [ @@ -121,54 +121,35 @@ def test_check_permissions(): ) data_set1 = fstore.FeatureSet("fs1", entities=[Entity("string")]) - mlrun.db.FileRunDB.verify_authorization = unittest.mock.Mock( - 
side_effect=mlrun.errors.MLRunAccessDeniedError("") + monkeypatch.setattr( + rundb_mock, + "verify_authorization", + unittest.mock.Mock(side_effect=mlrun.errors.MLRunAccessDeniedError("")), ) - try: + with pytest.raises(mlrun.errors.MLRunAccessDeniedError): fstore.preview( data_set1, data, entity_columns=[Entity("string")], timestamp_key="time_stamp", ) - assert False - except mlrun.errors.MLRunAccessDeniedError: - pass - - try: + with pytest.raises(mlrun.errors.MLRunAccessDeniedError): fstore.ingest(data_set1, data, infer_options=fstore.InferOptions.default()) - assert False - except mlrun.errors.MLRunAccessDeniedError: - pass features = ["fs1.*"] feature_vector = fstore.FeatureVector("test", features) - try: - fstore.get_offline_features( - feature_vector, entity_timestamp_column="time_stamp" - ) - assert False - except mlrun.errors.MLRunAccessDeniedError: - pass + with pytest.raises(mlrun.errors.MLRunAccessDeniedError): + fstore.get_offline_features(feature_vector) - try: + with pytest.raises(mlrun.errors.MLRunAccessDeniedError): fstore.get_online_feature_service(feature_vector) - assert False - except mlrun.errors.MLRunAccessDeniedError: - pass - try: + with pytest.raises(mlrun.errors.MLRunAccessDeniedError): fstore.deploy_ingestion_service(featureset=data_set1) - assert False - except mlrun.errors.MLRunAccessDeniedError: - pass - try: + with pytest.raises(mlrun.errors.MLRunAccessDeniedError): data_set1.purge_targets() - assert False - except mlrun.errors.MLRunAccessDeniedError: - pass def test_check_timestamp_key_is_entity(): diff --git a/tests/feature-store/test_steps.py b/tests/feature-store/test_steps.py index 0d92511c6fa0..258981478e5d 100644 --- a/tests/feature-store/test_steps.py +++ b/tests/feature-store/test_steps.py @@ -43,7 +43,7 @@ def extract_meta(event): return event -def test_set_event_meta(): +def test_set_event_meta(rundb_mock): function = mlrun.new_function("test1", kind="serving") flow = function.set_topology("flow") 
flow.to(SetEventMetadata(id_path="myid", key_path="mykey")).to( @@ -60,7 +60,7 @@ def test_set_event_meta(): } -def test_set_event_random_id(): +def test_set_event_random_id(rundb_mock): function = mlrun.new_function("test2", kind="serving") flow = function.set_topology("flow") flow.to(SetEventMetadata(random_id=True)).to( @@ -458,6 +458,89 @@ def test_pandas_step_data_extractor( ) +@pytest.mark.parametrize( + "mapping", + [ + {"age": {"ranges": {"one": [0, 30], "two": ["a", "inf"]}}}, + {"names": {"A": 1, "B": False}}, + ], +) +def test_mapvalues_mixed_types_validator(rundb_mock, mapping): + data, _ = get_data() + data_to_ingest = data.copy() + # Define the corresponding FeatureSet + data_set_pandas = fstore.FeatureSet( + "fs-new", + entities=[fstore.Entity("id")], + description="feature set", + engine="pandas", + ) + # Pre-processing grpah steps + data_set_pandas.graph.to( + MapValues( + mapping=mapping, + with_original_features=True, + ) + ) + data_set_pandas._run_db = rundb_mock + + data_set_pandas.reload = unittest.mock.Mock() + data_set_pandas.save = unittest.mock.Mock() + data_set_pandas.purge_targets = unittest.mock.Mock() + # Create a temp directory: + output_path = tempfile.TemporaryDirectory() + + with pytest.raises( + mlrun.errors.MLRunInvalidArgumentError, + match=f"^MapValues - mapping values of the same column must be in the same type, which was not the case for" + f" Column '{list(mapping.keys())[0]}'$", + ): + fstore.ingest( + data_set_pandas, + data_to_ingest, + targets=[ParquetTarget(path=f"{output_path.name}/temp.parquet")], + ) + + +def test_mapvalues_combined_mapping_validator(rundb_mock): + data, _ = get_data() + data_to_ingest = data.copy() + # Define the corresponding FeatureSet + data_set_pandas = fstore.FeatureSet( + "fs-new", + entities=[fstore.Entity("id")], + description="feature set", + engine="pandas", + ) + # Pre-processing grpah steps + data_set_pandas.graph.to( + MapValues( + mapping={ + "age": {"ranges": {"one": [0, 30], "two": 
["a", "inf"]}, 4: "kid"} + }, + with_original_features=True, + ) + ) + data_set_pandas._run_db = rundb_mock + + data_set_pandas.reload = unittest.mock.Mock() + data_set_pandas.save = unittest.mock.Mock() + data_set_pandas.purge_targets = unittest.mock.Mock() + # Create a temp directory: + output_path = tempfile.TemporaryDirectory() + + with pytest.raises( + mlrun.errors.MLRunInvalidArgumentError, + match="^MapValues - mapping values of the same column can not combine ranges and single " + "replacement, which is the case for column 'age'$", + ): + fstore.ingest( + data_set_pandas, + data_to_ingest, + targets=[ParquetTarget(path=f"{output_path.name}/temp.parquet")], + ) + + @pytest.mark.parametrize("set_index_before", [True, False, 0]) @pytest.mark.parametrize("entities", [["id"], ["id", "name"]]) def test_pandas_step_data_validator(rundb_mock, entities, set_index_before): diff --git a/tests/frameworks/lgbm/test_lgbm.py b/tests/frameworks/lgbm/test_lgbm.py index e8d6c4f43c5c..cbf19e12fa7e 100644 --- a/tests/frameworks/lgbm/test_lgbm.py +++ b/tests/frameworks/lgbm/test_lgbm.py @@ -34,7 +34,7 @@ @pytest.mark.parametrize("algorithm_functionality", ALGORITHM_FUNCTIONALITIES) -def test_training_api_training(algorithm_functionality: str): +def test_training_api_training(rundb_mock, algorithm_functionality: str): # Run training: train_run = mlrun.new_function().run( artifact_path="./temp", @@ -53,7 +53,7 @@ def test_training_api_training(algorithm_functionality: str): @pytest.mark.parametrize("algorithm_functionality", ALGORITHM_FUNCTIONALITIES) -def test_sklearn_api_training(algorithm_functionality: str): +def test_sklearn_api_training(rundb_mock, algorithm_functionality: str): # Run training: train_run = mlrun.new_function().run( artifact_path="./temp", @@ -81,7 +81,7 @@ def test_sklearn_api_training(algorithm_functionality: str): @pytest.mark.parametrize("algorithm_functionality", ALGORITHM_FUNCTIONALITIES) -def test_sklearn_api_evaluation(algorithm_functionality: str): 
+def test_sklearn_api_evaluation(rundb_mock, algorithm_functionality: str): # Run training: train_run = mlrun.new_function().run( artifact_path="./temp2", diff --git a/tests/frameworks/test_ml_frameworks.py b/tests/frameworks/test_ml_frameworks.py index 108a04ff286a..134a780ac71b 100644 --- a/tests/frameworks/test_ml_frameworks.py +++ b/tests/frameworks/test_ml_frameworks.py @@ -13,6 +13,7 @@ # limitations under the License. # import json +import typing from typing import Dict, List, Tuple import pytest @@ -33,7 +34,7 @@ class FrameworkKeys: SKLEARN = "sklearn" -FRAMEWORKS = { # type: Dict[str, Tuple[MLFunctions, ArtifactsLibrary, MetricsLibrary]] +FRAMEWORKS = { FrameworkKeys.XGBOOST: ( XGBoostFunctions, XGBoostArtifactsLibrary, @@ -44,36 +45,47 @@ class FrameworkKeys: SKLearnArtifactsLibrary, MetricsLibrary, ), -} -FRAMEWORKS_KEYS = [ # type: List[str] +} # type: Dict[str, Tuple[MLFunctions, ArtifactsLibrary, MetricsLibrary]] +FRAMEWORKS_KEYS = [ FrameworkKeys.XGBOOST, FrameworkKeys.SKLEARN, -] -ALGORITHM_FUNCTIONALITIES = [ # type: List[str] +] # type: List[str] +ALGORITHM_FUNCTIONALITIES = [ algorithm_functionality.value for algorithm_functionality in AlgorithmFunctionality if "Unknown" not in algorithm_functionality.value -] +] # type: List[str] +FRAMEWORKS_ALGORITHM_FUNCTIONALITIES = [ + (framework, algorithm_functionality) + for framework in FRAMEWORKS_KEYS + for algorithm_functionality in ALGORITHM_FUNCTIONALITIES + if ( + framework is not FrameworkKeys.XGBOOST + or algorithm_functionality + != AlgorithmFunctionality.MULTI_OUTPUT_MULTICLASS_CLASSIFICATION.value + ) +] # type: List[Tuple[str, str]] + +def framework_algorithm_functionality_pair_ids( + framework_algorithm_functionality_pair: typing.Tuple[str, str] +) -> str: + framework, algorithm_functionality = framework_algorithm_functionality_pair + return f"{framework}-{algorithm_functionality}" -@pytest.mark.parametrize("framework", FRAMEWORKS_KEYS) -@pytest.mark.parametrize("algorithm_functionality", 
ALGORITHM_FUNCTIONALITIES) -def test_training(framework: str, algorithm_functionality: str): + +@pytest.mark.parametrize( + "framework_algorithm_functionality_pair", + FRAMEWORKS_ALGORITHM_FUNCTIONALITIES, + ids=framework_algorithm_functionality_pair_ids, +) +def test_training(framework_algorithm_functionality_pair: typing.Tuple[str, str]): + framework, algorithm_functionality = framework_algorithm_functionality_pair # Unpack the framework classes: (functions, artifacts_library, metrics_library) = FRAMEWORKS[ framework ] # type: MLFunctions, ArtifactsLibrary, MetricsLibrary - # Skips: - if ( - functions is XGBoostFunctions - and algorithm_functionality - == AlgorithmFunctionality.MULTI_OUTPUT_MULTICLASS_CLASSIFICATION.value - ): - pytest.skip( - "multiclass multi output classification are not supported in 'xgboost'." - ) - # Run training: train_run = mlrun.new_function().run( artifact_path="./temp", @@ -100,24 +112,21 @@ def test_training(framework: str, algorithm_functionality: str): assert len(train_run.status.results) == len(expected_results) -@pytest.mark.parametrize("framework", FRAMEWORKS_KEYS) -@pytest.mark.parametrize("algorithm_functionality", ALGORITHM_FUNCTIONALITIES) -def test_evaluation(framework: str, algorithm_functionality: str): +@pytest.mark.parametrize( + "framework_algorithm_functionality_pair", + FRAMEWORKS_ALGORITHM_FUNCTIONALITIES, + ids=framework_algorithm_functionality_pair_ids, +) +def test_evaluation( + rundb_mock, + framework_algorithm_functionality_pair: typing.Tuple[str, str], +): + framework, algorithm_functionality = framework_algorithm_functionality_pair # Unpack the framework classes: (functions, artifacts_library, metrics_library) = FRAMEWORKS[ framework ] # type: MLFunctions, ArtifactsLibrary, MetricsLibrary - # Skips: - if ( - functions is XGBoostFunctions - and algorithm_functionality - == AlgorithmFunctionality.MULTI_OUTPUT_MULTICLASS_CLASSIFICATION.value - ): - pytest.skip( - "multiclass multi output classification are not 
supported in 'xgboost'." - ) - # Run training: train_run = mlrun.new_function().run( artifact_path="./temp2", @@ -147,7 +156,8 @@ def test_evaluation(framework: str, algorithm_functionality: str): expected_artifacts = [ plan for plan in artifacts_library.get_plans(model=dummy_model, y=dummy_y) - if not ( # Count only pre and post prediction artifacts (evaluation artifacts). + if not ( + # Count only pre and post prediction artifacts (evaluation artifacts). plan.is_ready(stage=MLPlanStages.POST_FIT, is_probabilities=False) or plan.is_ready(stage=MLPlanStages.PRE_FIT, is_probabilities=False) ) diff --git a/tests/integration/azure_blob/test_azure_blob.py b/tests/integration/azure_blob/test_azure_blob.py index 242f7cc22fc1..093eca6e467a 100644 --- a/tests/integration/azure_blob/test_azure_blob.py +++ b/tests/integration/azure_blob/test_azure_blob.py @@ -16,6 +16,7 @@ import random from pathlib import Path +import pandas as pd import pytest import yaml @@ -29,6 +30,7 @@ config = yaml.safe_load(fp) test_filename = here / "test.txt" +test_csv_filename = here / "test_data.csv" with open(test_filename, "r") as f: test_string = f.read() @@ -166,3 +168,16 @@ def test_blob_upload(auth_method): response = upload_data_item.get() assert response.decode() == test_string, "Result differs from original test" + + +def test_as_df(auth_method): + source_df = pd.read_csv(test_csv_filename) + storage_options = verify_auth_parameters_and_configure_env(auth_method) + blob_path = "az://" + config["env"].get("AZURE_CONTAINER") + blob_url = blob_path + "/" + blob_dir + "/" + blob_file.replace("txt", "csv") + + upload_data_item = mlrun.run.get_dataitem(blob_url, storage_options) + upload_data_item.upload(test_csv_filename) + + result_df = upload_data_item.as_df() + assert result_df.equals(source_df) diff --git a/tests/integration/azure_blob/test_data.csv b/tests/integration/azure_blob/test_data.csv new file mode 100644 index 000000000000..ed8408276f9f --- /dev/null +++ 
b/tests/integration/azure_blob/test_data.csv @@ -0,0 +1,4 @@ +Name,Age,City +Alice,30,Los Angeles +Bob,35,Chicago +Jane,28,San Francisco diff --git a/tests/integration/sdk_api/artifacts/test_artifacts.py b/tests/integration/sdk_api/artifacts/test_artifacts.py index 66529f6567a2..6e9568a00f2b 100644 --- a/tests/integration/sdk_api/artifacts/test_artifacts.py +++ b/tests/integration/sdk_api/artifacts/test_artifacts.py @@ -12,11 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import os +import pathlib +import shutil +import unittest.mock + import pandas import mlrun import mlrun.artifacts import tests.integration.sdk_api.base +from tests import conftest + +results_dir = (pathlib.Path(conftest.results) / "artifacts").absolute() class TestArtifacts(tests.integration.sdk_api.base.TestMLRunIntegration): @@ -68,6 +76,72 @@ def test_list_artifacts_filter_by_kind(self): assert len(artifacts) == 1, "bad number of model artifacts" artifacts = db.list_artifacts( - project=prj, category=mlrun.api.schemas.ArtifactCategories.dataset + project=prj, category=mlrun.common.schemas.ArtifactCategories.dataset ) assert len(artifacts) == 1, "bad number of dataset artifacts" + + def test_export_import(self): + project = mlrun.new_project("log-mod") + target_project = mlrun.new_project("log-mod2") + for mode in [False, True]: + mlrun.mlconf.artifacts.generate_target_path_from_artifact_hash = mode + + model = project.log_model( + "mymod", + body=b"123", + model_file="model.pkl", + extra_data={"kk": b"456"}, + artifact_path=results_dir, + ) + + for suffix in ["yaml", "json", "zip"]: + # export the artifact to a file + model.export(f"{results_dir}/a.{suffix}") + + # import and log the artifact to the new project + artifact = target_project.import_artifact( + f"{results_dir}/a.{suffix}", + f"mod-{suffix}", + artifact_path=results_dir, + ) + assert artifact.kind == "model" + assert artifact.metadata.key == f"mod-{suffix}" + 
assert artifact.metadata.project == "log-mod2" + temp_path, model_spec, extra_dataitems = mlrun.artifacts.get_model( + artifact.uri + ) + with open(temp_path, "rb") as fp: + data = fp.read() + assert data == b"123" + assert extra_dataitems["kk"].get() == b"456" + + def test_import_remote_zip(self): + project = mlrun.new_project("log-mod") + target_project = mlrun.new_project("log-mod2") + model = project.log_model( + "mymod", + body=b"123", + model_file="model.pkl", + extra_data={"kk": b"456"}, + artifact_path=results_dir, + ) + + artifact_url = f"{results_dir}/a.zip" + model.export(artifact_url) + + # mock downloading the artifact from s3 by copying it locally to a temp path + mlrun.datastore.base.DataStore.download = unittest.mock.MagicMock( + side_effect=shutil.copyfile + ) + artifact = target_project.import_artifact( + f"s3://ֿ{results_dir}/a.zip", + "mod-zip", + artifact_path=results_dir, + ) + + temp_local_path = mlrun.datastore.base.DataStore.download.call_args[0][1] + assert artifact.metadata.project == "log-mod2" + # verify that the original artifact was not deleted + assert os.path.exists(artifact_url) + # verify that the temp path was deleted after the import + assert not os.path.exists(temp_local_path) diff --git a/tests/integration/sdk_api/base.py b/tests/integration/sdk_api/base.py index 5e7b3ee8ced4..891c649157c9 100644 --- a/tests/integration/sdk_api/base.py +++ b/tests/integration/sdk_api/base.py @@ -21,7 +21,7 @@ import pymysql import mlrun -import mlrun.api.schemas +import mlrun.common.schemas import tests.conftest from mlrun.db.httpdb import HTTPRunDB from mlrun.utils import create_logger, retry_until_successful @@ -142,8 +142,7 @@ def _run_api(self): { "MLRUN_VERSION": "0.0.0+unstable", "MLRUN_HTTPDB__DSN": self.db_dsn, - # integration tests run in docker, and do no support sidecars for log collection - "MLRUN__LOG_COLLECTOR__MODE": "legacy", + "MLRUN_LOG_LEVEL": "DEBUG", } ), cwd=TestMLRunIntegration.root_path, @@ -215,7 +214,7 @@ def 
_extend_current_env(env): @staticmethod def _check_api_is_healthy(url): health_url = f"{url}/{HTTPRunDB.get_api_path_prefix()}/healthz" - timeout = 30 + timeout = 90 if not tests.conftest.wait_for_server(health_url, timeout): raise RuntimeError(f"API did not start after {timeout} sec") diff --git a/tests/integration/sdk_api/httpdb/runs/__init__.py b/tests/integration/sdk_api/httpdb/runs/__init__.py new file mode 100644 index 000000000000..245d0063f465 --- /dev/null +++ b/tests/integration/sdk_api/httpdb/runs/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023 MLRun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/integration/sdk_api/httpdb/assets/big-run.json b/tests/integration/sdk_api/httpdb/runs/assets/big-run.json similarity index 100% rename from tests/integration/sdk_api/httpdb/assets/big-run.json rename to tests/integration/sdk_api/httpdb/runs/assets/big-run.json diff --git a/tests/integration/sdk_api/httpdb/runs/test_dask.py b/tests/integration/sdk_api/httpdb/runs/test_dask.py new file mode 100644 index 000000000000..cb716258d004 --- /dev/null +++ b/tests/integration/sdk_api/httpdb/runs/test_dask.py @@ -0,0 +1,56 @@ +# Copyright 2023 MLRun Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest + +import mlrun +import tests.conftest +import tests.integration.sdk_api.base + +has_dask = False +try: + import dask # noqa + + has_dask = True +except ImportError: + pass + + +def inc(x): + return x + 2 + + +def my_func(context, p1=1, p2="a-string"): + print(f"Run: {context.name} (uid={context.uid})") + print(f"Params: p1={p1}, p2={p2}\n") + + x = context.dask_client.submit(inc, p1) + + context.log_result("accuracy", x.result()) + context.log_metric("loss", 7) + context.log_artifact("chart", body="abc") + return f"tst-me-{context.iteration}" + + +@pytest.mark.skipif(not has_dask, reason="missing dask") +class TestDask(tests.integration.sdk_api.base.TestMLRunIntegration): + def test_dask_local(self, ensure_default_project): + spec = tests.conftest.tag_test( + mlrun.new_task(params={"p1": 3, "p2": "vv"}), "test_dask_local" + ) + function = mlrun.new_function(kind="dask") + function.spec.remote = False + run = function.run(spec, handler=my_func) + tests.conftest.verify_state(run) diff --git a/tests/integration/sdk_api/httpdb/test_runs.py b/tests/integration/sdk_api/httpdb/runs/test_runs.py similarity index 73% rename from tests/integration/sdk_api/httpdb/test_runs.py rename to tests/integration/sdk_api/httpdb/runs/test_runs.py index a02653d62941..95a3e2c26ae1 100644 --- a/tests/integration/sdk_api/httpdb/test_runs.py +++ b/tests/integration/sdk_api/httpdb/runs/test_runs.py @@ -18,8 +18,9 @@ import pytest import mlrun -import mlrun.api.schemas +import mlrun.common.schemas import tests.integration.sdk_api.base +from tests.conftest import 
examples_path class TestRuns(tests.integration.sdk_api.base.TestMLRunIntegration): @@ -76,9 +77,9 @@ def test_list_runs(self): runs = _list_and_assert_objects( 3, project=projects[0], - partition_by=mlrun.api.schemas.RunPartitionByField.name, - partition_sort_by=mlrun.api.schemas.SortField.created, - partition_order=mlrun.api.schemas.OrderType.asc, + partition_by=mlrun.common.schemas.RunPartitionByField.name, + partition_sort_by=mlrun.common.schemas.SortField.created, + partition_order=mlrun.common.schemas.OrderType.asc, ) # sorted by ascending created so only the first ones created for run in runs: @@ -88,9 +89,9 @@ def test_list_runs(self): runs = _list_and_assert_objects( 3, project=projects[0], - partition_by=mlrun.api.schemas.RunPartitionByField.name, - partition_sort_by=mlrun.api.schemas.SortField.updated, - partition_order=mlrun.api.schemas.OrderType.desc, + partition_by=mlrun.common.schemas.RunPartitionByField.name, + partition_sort_by=mlrun.common.schemas.SortField.updated, + partition_order=mlrun.common.schemas.OrderType.desc, ) # sorted by descending updated so only the third ones created for run in runs: @@ -100,9 +101,9 @@ def test_list_runs(self): runs = _list_and_assert_objects( 15, project=projects[0], - partition_by=mlrun.api.schemas.RunPartitionByField.name, - partition_sort_by=mlrun.api.schemas.SortField.updated, - partition_order=mlrun.api.schemas.OrderType.desc, + partition_by=mlrun.common.schemas.RunPartitionByField.name, + partition_sort_by=mlrun.common.schemas.SortField.updated, + partition_order=mlrun.common.schemas.OrderType.desc, rows_per_partition=5, iter=True, ) @@ -111,9 +112,9 @@ def test_list_runs(self): runs = _list_and_assert_objects( 10, project=projects[0], - partition_by=mlrun.api.schemas.RunPartitionByField.name, - partition_sort_by=mlrun.api.schemas.SortField.updated, - partition_order=mlrun.api.schemas.OrderType.desc, + partition_by=mlrun.common.schemas.RunPartitionByField.name, + 
partition_sort_by=mlrun.common.schemas.SortField.updated, + partition_order=mlrun.common.schemas.OrderType.desc, rows_per_partition=5, max_partitions=2, iter=True, @@ -124,9 +125,9 @@ def test_list_runs(self): runs = _list_and_assert_objects( 6, project=projects[0], - partition_by=mlrun.api.schemas.RunPartitionByField.name, - partition_sort_by=mlrun.api.schemas.SortField.updated, - partition_order=mlrun.api.schemas.OrderType.desc, + partition_by=mlrun.common.schemas.RunPartitionByField.name, + partition_sort_by=mlrun.common.schemas.SortField.updated, + partition_order=mlrun.common.schemas.OrderType.desc, rows_per_partition=4, max_partitions=2, iter=False, @@ -137,7 +138,7 @@ def test_list_runs(self): _list_and_assert_objects( 0, project=projects[0], - partition_by=mlrun.api.schemas.RunPartitionByField.name, + partition_by=mlrun.common.schemas.RunPartitionByField.name, ) # An invalid partition-by field - will be failed by fastapi due to schema validation. with pytest.raises(mlrun.errors.MLRunHTTPError) as excinfo: @@ -145,7 +146,7 @@ def test_list_runs(self): 0, project=projects[0], partition_by="key", - partition_sort_by=mlrun.api.schemas.SortField.updated, + partition_sort_by=mlrun.common.schemas.SortField.updated, ) assert ( excinfo.value.response.status_code @@ -172,6 +173,27 @@ def test_list_runs(self): assert run["metadata"]["uid"] in uid_list uid_list.remove(run["metadata"]["uid"]) + def test_job_file(self, ensure_default_project): + filename = f"{examples_path}/training.py" + fn = mlrun.code_to_function(filename=filename, kind="job") + assert fn.kind == "job", "kind not set, test failed" + assert fn.spec.build.functionSourceCode, "code not embedded" + assert fn.spec.build.origin_filename == filename, "did not record filename" + assert type(fn.metadata.labels) == dict, "metadata labels were not set" + run = fn.run(workdir=str(examples_path), local=True) + + project, uri, tag, hash_key = mlrun.utils.parse_versioned_object_uri( + run.spec.function + ) + 
local_fn = mlrun.get_run_db().get_function( + uri, project, tag=tag, hash_key=hash_key + ) + assert local_fn["spec"]["command"] == filename, "wrong command path" + assert ( + local_fn["spec"]["build"]["functionSourceCode"] + == fn.spec.build.functionSourceCode + ), "code was not copied to local function" + def _list_and_assert_objects(expected_number_of_runs: int, **kwargs): runs = mlrun.get_run_db().list_runs(**kwargs) diff --git a/tests/integration/sdk_api/httpdb/test_exception_handling.py b/tests/integration/sdk_api/httpdb/test_exception_handling.py index e05d7104faf7..122b7e77655e 100644 --- a/tests/integration/sdk_api/httpdb/test_exception_handling.py +++ b/tests/integration/sdk_api/httpdb/test_exception_handling.py @@ -15,7 +15,7 @@ import pytest import mlrun -import mlrun.api.schemas +import mlrun.common.schemas import mlrun.errors import tests.integration.sdk_api.base @@ -47,8 +47,8 @@ def test_exception_handling(self): # This is handled in the mlrun/api/main.py::http_status_error_handler invalid_project_name = "some_project" # Not using client class cause it does validation on client side and we want to fail on server side - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name=invalid_project_name) + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name=invalid_project_name) ) with pytest.raises( mlrun.errors.MLRunBadRequestError, diff --git a/tests/integration/sdk_api/hub/__init__.py b/tests/integration/sdk_api/hub/__init__.py new file mode 100644 index 000000000000..b3085be1eb56 --- /dev/null +++ b/tests/integration/sdk_api/hub/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/integration/sdk_api/marketplace/test_marketplace.py b/tests/integration/sdk_api/hub/test_hub.py similarity index 67% rename from tests/integration/sdk_api/marketplace/test_marketplace.py rename to tests/integration/sdk_api/hub/test_hub.py index a39f60618071..bf75dc285a4b 100644 --- a/tests/integration/sdk_api/marketplace/test_marketplace.py +++ b/tests/integration/sdk_api/hub/test_hub.py @@ -18,10 +18,10 @@ import tests.integration.sdk_api.base -class TestMarketplace(tests.integration.sdk_api.base.TestMLRunIntegration): +class TestHub(tests.integration.sdk_api.base.TestMLRunIntegration): @staticmethod def _assert_source_lists_match(expected_response): - response = mlrun.get_run_db().list_marketplace_sources() + response = mlrun.get_run_db().list_hub_sources() exclude_paths = [ "root['source']['metadata']['updated']", @@ -37,50 +37,50 @@ def _assert_source_lists_match(expected_response): == {} ) - def test_marketplace(self): + def test_hub(self): db = mlrun.get_run_db() - default_source = mlrun.api.schemas.IndexedMarketplaceSource( + default_source = mlrun.common.schemas.IndexedHubSource( index=-1, - source=mlrun.api.schemas.MarketplaceSource.generate_default_source(), + source=mlrun.common.schemas.HubSource.generate_default_source(), ) self._assert_source_lists_match([default_source]) - new_source = mlrun.api.schemas.IndexedMarketplaceSource( - source=mlrun.api.schemas.MarketplaceSource( - metadata=mlrun.api.schemas.MarketplaceObjectMetadata( + new_source = mlrun.common.schemas.IndexedHubSource( + source=mlrun.common.schemas.HubSource( 
+ metadata=mlrun.common.schemas.HubObjectMetadata( name="source-1", description="a private source" ), - spec=mlrun.api.schemas.MarketplaceSourceSpec( + spec=mlrun.common.schemas.HubSourceSpec( path="/local/path/to/source", channel="development" ), ) ) - db.create_marketplace_source(new_source) + db.create_hub_source(new_source) new_source.index = 1 self._assert_source_lists_match([new_source, default_source]) - new_source_2 = mlrun.api.schemas.IndexedMarketplaceSource( + new_source_2 = mlrun.common.schemas.IndexedHubSource( index=1, - source=mlrun.api.schemas.MarketplaceSource( - metadata=mlrun.api.schemas.MarketplaceObjectMetadata( + source=mlrun.common.schemas.HubSource( + metadata=mlrun.common.schemas.HubObjectMetadata( name="source-2", description="2nd private source" ), - spec=mlrun.api.schemas.MarketplaceSourceSpec( + spec=mlrun.common.schemas.HubSourceSpec( path="/local/path/to/source", channel="prod" ), ), ) - db.create_marketplace_source(new_source_2) + db.create_hub_source(new_source_2) new_source.index = 2 self._assert_source_lists_match([new_source_2, new_source, default_source]) new_source.index = 1 - db.store_marketplace_source(new_source.source.metadata.name, new_source) + db.store_hub_source(new_source.source.metadata.name, new_source) new_source_2.index = 2 self._assert_source_lists_match([new_source, new_source_2, default_source]) - db.delete_marketplace_source("source-1") + db.delete_hub_source("source-1") new_source_2.index = 1 self._assert_source_lists_match([new_source_2, default_source]) diff --git a/tests/integration/sdk_api/projects/test_project.py b/tests/integration/sdk_api/projects/test_project.py index 7d42b8dad200..3d7134c45630 100644 --- a/tests/integration/sdk_api/projects/test_project.py +++ b/tests/integration/sdk_api/projects/test_project.py @@ -18,7 +18,7 @@ import pytest import mlrun -import mlrun.api.schemas +import mlrun.common.schemas import tests.conftest import tests.integration.sdk_api.base @@ -42,6 +42,26 @@ def 
test_create_project_failure_already_exists(self): in str(exc.value) ) + def test_sync_functions(self): + project_name = "project-name" + project = mlrun.new_project(project_name) + project.set_function("hub://describe", "describe") + project_function_object = project.spec._function_objects + project_file_path = pathlib.Path(tests.conftest.results) / "project.yaml" + project.export(str(project_file_path)) + imported_project = mlrun.load_project("./", str(project_file_path)) + assert imported_project.spec._function_objects == {} + imported_project.sync_functions() + _assert_project_function_objects(imported_project, project_function_object) + + fn = project.get_function("describe") + assert fn.metadata.name == "describe", "func did not return" + + # test that functions can be fetched from the DB (w/o set_function) + mlrun.import_function("hub://sklearn_classifier", new_name="train").save() + fn = project.get_function("train") + assert fn.metadata.name == "train", "train func did not return" + def test_overwrite_project(self): project_name = "some-project" @@ -238,3 +258,19 @@ def _assert_projects(expected_project, project): ) assert expected_project.spec.desired_state == project.spec.desired_state assert expected_project.spec.desired_state == project.status.state + + +def _assert_project_function_objects(project, expected_function_objects): + project_function_objects = project.spec._function_objects + assert len(project_function_objects) == len(expected_function_objects) + for function_name, function_object in expected_function_objects.items(): + assert function_name in project_function_objects + assert ( + deepdiff.DeepDiff( + project_function_objects[function_name].to_dict(), + function_object.to_dict(), + ignore_order=True, + exclude_paths=["root['spec']['build']['code_origin']"], + ) + == {} + ) diff --git a/tests/integration/sdk_api/run/test_main.py b/tests/integration/sdk_api/run/test_main.py new file mode 100644 index 000000000000..c752d7c7e3a4 --- /dev/null 
+++ b/tests/integration/sdk_api/run/test_main.py @@ -0,0 +1,450 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime +import os +import pathlib +import sys +import traceback +from base64 import b64encode +from subprocess import PIPE, run +from sys import executable, stderr + +import pytest + +import mlrun +import tests.integration.sdk_api.base +from tests.conftest import examples_path, out_path, tests_root_directory + +code = """ +import mlrun, sys +if __name__ == "__main__": + context = mlrun.get_or_create_ctx("test1") + context.log_result("my_args", sys.argv) + context.commit(completed=True) +""" + +nonpy_code = """ +echo "abc123" $1 +""" + + +class TestMain(tests.integration.sdk_api.base.TestMLRunIntegration): + + assets_path = ( + pathlib.Path(__file__).absolute().parent.parent.parent.parent / "run" / "assets" + ) + + def custom_setup(self): + # ensure default project exists + mlrun.get_or_create_project("default") + + def test_main_run_basic(self): + out = self._exec_run( + f"{examples_path}/training.py", + self._compose_param_list(dict(p1=5, p2='"aaa"')), + "test_main_run_basic", + ) + print(out) + assert out.find("state: completed") != -1, out + + def test_main_run_wait_for_completion(self): + """ + Test that the run command waits for the run to complete before returning + (mainly sanity as this is expected when running local function) + """ + path = str(self.assets_path / "sleep.py") + time_to_sleep = 10 + start_time = 
datetime.datetime.now() + out = self._exec_run( + path, + self._compose_param_list(dict(time_to_sleep=time_to_sleep)) + + ["--handler", "handler"], + "test_main_run_wait_for_completion", + ) + end_time = datetime.datetime.now() + print(out) + assert out.find("state: completed") != -1, out + assert ( + end_time - start_time + ).seconds >= time_to_sleep, "run did not wait for completion" + + def test_main_run_hyper(self): + out = self._exec_run( + f"{examples_path}/training.py", + self._compose_param_list(dict(p2=[4, 5, 6]), "-x"), + "test_main_run_hyper", + ) + print(out) + assert out.find("state: completed") != -1, out + assert out.find("iterations:") != -1, out + + def test_main_run_args(self): + out = self._exec_run( + f"{tests_root_directory}/no_ctx.py -x " + "{p2}", + ["--uid", "123457"] + self._compose_param_list(dict(p1=5, p2="aaa")), + "test_main_run_args", + ) + print(out) + assert out.find("state: completed") != -1, out + db = mlrun.get_run_db() + state, log = db.get_log("123457") + print(log) + assert str(log).find(", -x, aaa") != -1, "params not detected in argv" + + def test_main_run_args_with_url_placeholder_missing_env(self): + args = [ + "--name", + "test_main_run_args_with_url_placeholder_missing_env", + "--dump", + "*", + "--arg1", + "value1", + "--arg2", + "value2", + ] + out = self._exec_main( + "run", + args, + raise_on_error=False, + ) + out_stdout = out.stdout.decode("utf-8") + print(out) + assert ( + out_stdout.find( + "command/url '*' placeholder is not allowed when code is not from env" + ) + != -1 + ), out + + def test_main_run_args_with_url_placeholder_from_env(self): + os.environ["MLRUN_EXEC_CODE"] = b64encode(code.encode("utf-8")).decode("utf-8") + args = [ + "--name", + "test_main_run_args_with_url_placeholder_from_env", + "--uid", + "123456789", + "--from-env", + "--dump", + "*", + "--arg1", + "value1", + "--arg2", + "value2", + ] + self._exec_main( + "run", + args, + raise_on_error=True, + ) + db = mlrun.get_run_db() + _run = 
db.read_run("123456789") + print(_run) + assert _run["status"]["results"]["my_args"] == [ + "main.py", + "--arg1", + "value1", + "--arg2", + "value2", + ] + assert _run["status"]["state"] == "completed" + + args = [ + "--name", + "test_main_run_args_with_url_placeholder_with_origin_file", + "--uid", + "987654321", + "--from-env", + "--dump", + "*", + "--origin-file", + "my_file.py", + "--arg3", + "value3", + ] + self._exec_main( + "run", + args, + raise_on_error=True, + ) + db = mlrun.get_run_db() + _run = db.read_run("987654321") + print(_run) + assert _run["status"]["results"]["my_args"] == [ + "my_file.py", + "--arg3", + "value3", + ] + assert _run["status"]["state"] == "completed" + + def test_main_with_url_placeholder(self): + os.environ["MLRUN_EXEC_CODE"] = b64encode(code.encode("utf-8")).decode("utf-8") + args = [ + "--name", + "test_main_with_url_placeholder", + "--uid", + "123456789", + "--from-env", + "*", + ] + self._exec_main( + "run", + args, + raise_on_error=True, + ) + db = mlrun.get_run_db() + _run = db.read_run("123456789") + print(_run) + assert _run["status"]["results"]["my_args"] == ["main.py"] + assert _run["status"]["state"] == "completed" + + @pytest.mark.parametrize( + "op,args,raise_on_error,expected_output", + [ + # bad flag before command + [ + "run", + [ + "--bad-flag", + "--name", + "test_main_run_basic", + "--dump", + f"{examples_path}/training.py", + ], + False, + "Error: Invalid value for '[URL]': URL (--bad-flag) cannot start with '-', " + "ensure the command options are typed correctly. Preferably use '--' to separate options and " + "arguments e.g. 'mlrun run --option1 --option2 -- [URL] [--arg1|arg1] [--arg2|arg2]'", + ], + # bad flag with no command + [ + "run", + ["--name", "test_main_run_basic", "--bad-flag"], + False, + "Error: Invalid value for '[URL]': URL (--bad-flag) cannot start with '-', " + "ensure the command options are typed correctly. Preferably use '--' to separate options and " + "arguments e.g. 
'mlrun run --option1 --option2 -- [URL] [--arg1|arg1] [--arg2|arg2]'", + ], + # bad flag after -- separator + [ + "run", + ["--name", "test_main_run_basic", "--", "-notaflag"], + False, + "Error: Invalid value for '[URL]': URL (-notaflag) cannot start with '-', " + "ensure the command options are typed correctly. Preferably use '--' to separate options and " + "arguments e.g. 'mlrun run --option1 --option2 -- [URL] [--arg1|arg1] [--arg2|arg2]'", + ], + # correct command with -- separator + [ + "run", + [ + "--name", + "test_main_run_basic", + "--", + f"{examples_path}/training.py", + "--some-arg", + ], + True, + "'status': 'completed'", + ], + ], + ) + def test_main_run_args_validation(self, op, args, raise_on_error, expected_output): + out = self._exec_main( + op, + args, + raise_on_error=raise_on_error, + ) + if not raise_on_error: + out = out.stderr.decode("utf-8") + + assert out.find(expected_output) != -1, out + + def test_main_run_args_from_env(self): + os.environ["MLRUN_EXEC_CODE"] = b64encode(code.encode("utf-8")).decode("utf-8") + os.environ["MLRUN_EXEC_CONFIG"] = ( + '{"spec":{"parameters":{"x": "bbb"}},' + '"metadata":{"uid":"123459", "name":"tst", "labels": {"kind": "job"}}}' + ) + + out = self._exec_run( + "'main.py -x {x}'", + ["--from-env"], + "test_main_run_args_from_env", + ) + db = mlrun.get_run_db() + run_object = db.read_run("123459") + print(out) + assert run_object["status"]["state"] == "completed", out + assert run_object["status"]["results"]["my_args"] == [ + "main.py", + "-x", + "bbb", + ], "params not detected in argv" + + @pytest.mark.skipif(sys.platform == "win32", reason="skip on windows") + def test_main_run_nonpy_from_env(self): + os.environ["MLRUN_EXEC_CODE"] = b64encode(nonpy_code.encode("utf-8")).decode( + "utf-8" + ) + os.environ[ + "MLRUN_EXEC_CONFIG" + ] = '{"spec":{},"metadata":{"uid":"123411", "name":"tst", "labels": {"kind": "job"}}}' + + # --kfp flag will force the logs to print (for the assert) + out = self._exec_run( + 
"bash {codefile} xx", + ["--from-env", "--mode", "pass", "--kfp"], + "test_main_run_nonpy_from_env", + ) + db = mlrun.get_run_db() + run_object = db.read_run("123411") + assert run_object["status"]["state"] == "completed", out + state, log = db.get_log("123411") + print(state, log) + assert str(log).find("abc123 xx") != -1, "incorrect output" + + def test_main_run_pass(self): + out = self._exec_run( + "python -c print(56)", + ["--mode", "pass", "--uid", "123458"], + "test_main_run_pass", + ) + print(out) + assert out.find("state: completed") != -1, out + db = mlrun.get_run_db() + state, log = db.get_log("123458") + assert str(log).find("56") != -1, "incorrect output" + + def test_main_run_pass_args(self): + out = self._exec_run( + "'python -c print({x})'", + ["--mode", "pass", "--uid", "123451", "-p", "x=33"], + "test_main_run_pass", + ) + print(out) + assert out.find("state: completed") != -1, out + db = mlrun.get_run_db() + state, log = db.get_log("123451") + print(log) + assert str(log).find("33") != -1, "incorrect output" + + def test_main_run_archive(self): + args = f"--source {examples_path}/archive.zip --handler handler -p p1=1" + out = self._exec_run("./myfunc.py", args.split(), "test_main_run_archive") + assert out.find("state: completed") != -1, out + + def test_main_local_source(self): + args = f"--source {examples_path} --handler my_func" + with pytest.raises(Exception) as e: + self._exec_run("./handler.py", args.split(), "test_main_local_source") + assert ( + f"source ({examples_path}) must be a compressed (tar.gz / zip) file, " + f"a git repo, a file path or in the project's context (.)" in str(e.value) + ) + + def test_main_run_archive_subdir(self): + runtime = '{"spec":{"pythonpath":"./subdir"}}' + args = f"--source {examples_path}/archive.zip -r {runtime}" + out = self._exec_run( + "./subdir/func2.py", args.split(), "test_main_run_archive_subdir" + ) + print(out) + assert out.find("state: completed") != -1, out + + def 
test_main_local_project(self): + mlrun.get_or_create_project("testproject") + project_path = str(self.assets_path) + args = "-f simple -p x=2 --dump" + out = self._exec_main("run", args.split(), cwd=project_path) + assert out.find("state: completed") != -1, out + assert out.find("y: 4") != -1, out # y = x * 2 + + def test_main_local_flag(self): + fn = mlrun.code_to_function( + filename=f"{examples_path}/handler.py", kind="job", handler="my_func" + ) + yaml_path = f"{out_path}/myfunc.yaml" + fn.export(yaml_path) + args = f"-f {yaml_path} --local" + out = self._exec_run("", args.split(), "test_main_local_flag") + print(out) + assert out.find("state: completed") != -1, out + + def test_main_run_class(self): + function_path = str(self.assets_path / "handler.py") + + out = self._exec_run( + function_path, + self._compose_param_list(dict(x=8)) + ["--handler", "mycls::mtd"], + "test_main_run_class", + ) + assert out.find("state: completed") != -1, out + assert out.find("rx: 8") != -1, out + + def test_run_from_module(self): + args = [ + "--name", + "test1", + "--dump", + "--handler", + "json.dumps", + "-p", + "obj=[6,7]", + ] + out = self._exec_main("run", args) + assert out.find("state: completed") != -1, out + assert out.find("return: '[6, 7]'") != -1, out + + def test_main_env_file(self): + # test run with env vars loaded from a .env file + function_path = str(self.assets_path / "handler.py") + envfile = str(self.assets_path / "envfile") + + out = self._exec_run( + function_path, + ["--handler", "env_file_test", "--env-file", envfile], + "test_main_env_file", + ) + assert out.find("state: completed") != -1, out + assert out.find("ENV_ARG1: '123'") != -1, out + assert out.find("kfp_ttl: 12345") != -1, out + + @staticmethod + def _exec_main(op, args, cwd=examples_path, raise_on_error=True): + cmd = [executable, "-m", "mlrun", op] + if args: + cmd += args + out = run(cmd, stdout=PIPE, stderr=PIPE, cwd=cwd) + if out.returncode != 0: + print(out.stderr.decode("utf-8"), 
file=stderr) + print(out.stdout.decode("utf-8"), file=stderr) + print(traceback.format_exc()) + if raise_on_error: + raise Exception(out.stderr.decode("utf-8")) + else: + # return out so that we can check the error message on stdout and stderr + return out + + return out.stdout.decode("utf-8") + + def _exec_run(self, cmd, args, test, raise_on_error=True): + args = args + ["--name", test, "--dump", cmd] + return self._exec_main("run", args, raise_on_error=raise_on_error) + + @staticmethod + def _compose_param_list(params: dict, flag="-p"): + composed_params = [] + for k, v in params.items(): + composed_params += [flag, f"{k}={v}"] + return composed_params diff --git a/tests/launcher/__init__.py b/tests/launcher/__init__.py new file mode 100644 index 000000000000..b3085be1eb56 --- /dev/null +++ b/tests/launcher/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/tests/launcher/assets/sample_function.py b/tests/launcher/assets/sample_function.py new file mode 100644 index 000000000000..ede8b032dfd2 --- /dev/null +++ b/tests/launcher/assets/sample_function.py @@ -0,0 +1,20 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +def hello_word(context): + return "hello world" + + +def handler_v2(context): + return "hello world v2" diff --git a/tests/launcher/test_factory.py b/tests/launcher/test_factory.py new file mode 100644 index 000000000000..b077aeb14ad0 --- /dev/null +++ b/tests/launcher/test_factory.py @@ -0,0 +1,91 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import typing +from contextlib import nullcontext as does_not_raise + +import pytest + +import mlrun.api.launcher +import mlrun.launcher.base +import mlrun.launcher.factory +import mlrun.launcher.local +import mlrun.launcher.remote + + +@pytest.mark.parametrize( + "is_remote, local, expected_instance", + [ + # runtime is remote and user didn't specify local - submit job flow + ( + True, + False, + mlrun.launcher.remote.ClientRemoteLauncher, + ), + # runtime is remote but specify local - run local flow + ( + True, + True, + mlrun.launcher.local.ClientLocalLauncher, + ), + # runtime is local and user specify local - run local flow + ( + False, + True, + mlrun.launcher.local.ClientLocalLauncher, + ), + # runtime is local and user didn't specify local - run local flow + ( + False, + False, + mlrun.launcher.local.ClientLocalLauncher, + ), + ], +) +def test_create_client_launcher( + is_remote: bool, + local: bool, + expected_instance: typing.Union[ + mlrun.launcher.remote.ClientRemoteLauncher, + mlrun.launcher.local.ClientLocalLauncher, + ], +): + launcher = mlrun.launcher.factory.LauncherFactory.create_launcher(is_remote, local) + assert isinstance(launcher, expected_instance) + + if local: + assert launcher._is_run_local + + elif not is_remote: + assert not launcher._is_run_local + + +@pytest.mark.parametrize( + "is_remote, local, expectation", + [ + (True, False, does_not_raise()), + (False, False, does_not_raise()), + # local run is not allowed when running as API + (True, True, pytest.raises(mlrun.errors.MLRunInternalServerError)), + (False, True, pytest.raises(mlrun.errors.MLRunInternalServerError)), + ], +) +def test_create_server_side_launcher(running_as_api, is_remote, local, expectation): + """Test that the server side launcher is created when we are running as API""" + with expectation: + launcher = mlrun.launcher.factory.LauncherFactory.create_launcher( + is_remote, local + ) + assert isinstance(launcher, mlrun.api.launcher.ServerSideLauncher) diff 
--git a/tests/launcher/test_local.py b/tests/launcher/test_local.py new file mode 100644 index 000000000000..8b6ec761d32b --- /dev/null +++ b/tests/launcher/test_local.py @@ -0,0 +1,118 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import pathlib + +import pytest + +import mlrun.launcher.local + +assets_path = pathlib.Path(__file__).parent / "assets" +func_path = assets_path / "sample_function.py" +handler = "hello_word" + + +def test_launch_local(): + launcher = mlrun.launcher.local.ClientLocalLauncher(local=True) + runtime = mlrun.code_to_function( + name="test", kind="job", filename=str(func_path), handler=handler + ) + result = launcher.launch(runtime) + assert result.status.state == "completed" + assert result.status.results.get("return") == "hello world" + + +def test_override_handler(): + launcher = mlrun.launcher.local.ClientLocalLauncher(local=True) + runtime = mlrun.code_to_function( + name="test", kind="job", filename=str(func_path), handler=handler + ) + result = launcher.launch(runtime, handler="handler_v2") + assert result.status.state == "completed" + assert result.status.results.get("return") == "hello world v2" + + +def test_launch_remote_job_locally(): + launcher = mlrun.launcher.local.ClientLocalLauncher(local=False) + runtime = mlrun.code_to_function( + name="test", kind="job", filename=str(func_path), handler=handler + ) + with pytest.raises(mlrun.errors.MLRunRuntimeError) as exc: + launcher.launch(runtime) + 
assert "Remote function cannot be executed locally" in str(exc.value) + + +def test_create_local_function_for_execution(): + launcher = mlrun.launcher.local.ClientLocalLauncher(local=False) + runtime = mlrun.code_to_function( + name="test", kind="job", filename=str(func_path), handler=handler + ) + run = mlrun.run.RunObject() + runtime = launcher._create_local_function_for_execution( + runtime=runtime, + run=run, + ) + assert runtime.metadata.project == "default" + assert runtime.metadata.name == "test" + assert run.spec.handler == handler + assert runtime.kind == "" + assert runtime._is_run_local + + +def test_create_local_function_for_execution_with_enrichment(): + launcher = mlrun.launcher.local.ClientLocalLauncher(local=False) + runtime = mlrun.code_to_function( + name="test", kind="job", filename=str(func_path), handler=handler + ) + runtime.spec.allow_empty_resources = True + run = mlrun.run.RunObject() + runtime = launcher._create_local_function_for_execution( + runtime=runtime, + run=run, + local_code_path="some_path.py", + project="some_project", + name="other_name", + workdir="some_workdir", + handler="handler_v2", + ) + assert runtime.spec.command == "some_path.py" + assert runtime.metadata.project == "some_project" + assert runtime.metadata.name == "other_name" + assert runtime.spec.workdir == "some_workdir" + assert run.spec.handler == "handler_v2" + assert runtime.kind == "" + assert runtime._is_run_local + assert runtime.spec.allow_empty_resources + + +def test_validate_inputs(): + launcher = mlrun.launcher.local.ClientLocalLauncher(local=False) + runtime = mlrun.code_to_function( + name="test", kind="job", filename=str(func_path), handler=handler + ) + run = mlrun.run.RunObject(spec=mlrun.model.RunSpec(inputs={"input1": 1})) + with pytest.raises(mlrun.errors.MLRunInvalidArgumentTypeError) as exc: + launcher._validate_runtime(runtime, run) + assert "Inputs should be of type Dict[str,str]" in str(exc.value) + + +def test_validate_runtime_success(): + 
launcher = mlrun.launcher.local.ClientLocalLauncher(local=False) + runtime = mlrun.code_to_function( + name="test", kind="local", filename=str(func_path), handler=handler + ) + run = mlrun.run.RunObject( + spec=mlrun.model.RunSpec(inputs={"input1": ""}, output_path="./some_path") + ) + launcher._validate_runtime(runtime, run) diff --git a/tests/launcher/test_remote.py b/tests/launcher/test_remote.py new file mode 100644 index 000000000000..d92258971280 --- /dev/null +++ b/tests/launcher/test_remote.py @@ -0,0 +1,137 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import pathlib +import unittest.mock + +import pytest + +import mlrun.config +import mlrun.launcher.remote + +assets_path = pathlib.Path(__file__).parent / "assets" +func_path = assets_path / "sample_function.py" +handler = "hello_word" + + +def test_launch_remote_job(rundb_mock): + launcher = mlrun.launcher.remote.ClientRemoteLauncher() + mlrun.config.config.artifact_path = "v3io:///users/admin/mlrun" + runtime = mlrun.code_to_function( + name="test", + kind="job", + filename=str(func_path), + handler=handler, + image="mlrun/mlrun", + ) + + # store the run is done by the API so we need to mock it + uid = "123" + run = mlrun.run.RunObject( + metadata=mlrun.model.RunMetadata(uid=uid), + ) + rundb_mock.store_run(run, uid) + result = launcher.launch(runtime, run) + assert result.status.state == "completed" + + +def test_launch_remote_job_no_watch(rundb_mock): + launcher = mlrun.launcher.remote.ClientRemoteLauncher() + mlrun.config.config.artifact_path = "v3io:///users/admin/mlrun" + runtime = mlrun.code_to_function( + name="test", + kind="job", + filename=str(func_path), + handler=handler, + image="mlrun/mlrun", + ) + result = launcher.launch(runtime, watch=False) + assert result.status.state == "created" + + +def test_validate_inputs(): + launcher = mlrun.launcher.remote.ClientRemoteLauncher() + runtime = mlrun.code_to_function( + name="test", kind="job", filename=str(func_path), handler=handler + ) + run = mlrun.run.RunObject(spec=mlrun.model.RunSpec(inputs={"input1": 1})) + with pytest.raises(mlrun.errors.MLRunInvalidArgumentTypeError) as exc: + launcher._validate_runtime(runtime, run) + assert "Inputs should be of type Dict[str,str]" in str(exc.value) + + +def test_validate_runtime_success(): + launcher = mlrun.launcher.remote.ClientRemoteLauncher() + runtime = mlrun.code_to_function( + name="test", kind="local", filename=str(func_path), handler=handler + ) + run = mlrun.run.RunObject( + spec=mlrun.model.RunSpec(inputs={"input1": ""}, 
output_path="./some_path") + ) + launcher._validate_runtime(runtime, run) + + +@pytest.mark.parametrize( + "kind, requirements, expected_base_image, expected_image", + [ + ("job", [], None, "mlrun/mlrun"), + ("job", ["pandas"], "mlrun/mlrun", ""), + ("nuclio", ["pandas"], None, "mlrun/mlrun"), + ("serving", ["pandas"], None, "mlrun/mlrun"), + ], +) +def test_prepare_image_for_deploy( + kind, requirements, expected_base_image, expected_image +): + launcher = mlrun.launcher.remote.ClientRemoteLauncher() + runtime = mlrun.code_to_function( + name="test", + kind=kind, + filename=str(func_path), + handler=handler, + image="mlrun/mlrun", + requirements=requirements, + ) + launcher.prepare_image_for_deploy(runtime) + assert runtime.spec.build.base_image == expected_base_image + assert runtime.spec.image == expected_image + + +def test_run_error_status(rundb_mock): + launcher = mlrun.launcher.remote.ClientRemoteLauncher() + mlrun.config.config.artifact_path = "v3io:///users/admin/mlrun" + runtime = mlrun.code_to_function( + name="test", + kind="job", + filename=str(func_path), + handler=handler, + image="mlrun/mlrun", + ) + + # store the run is done by the API so we need to mock it + uid = "123" + run = mlrun.run.RunObject( + metadata=mlrun.model.RunMetadata(uid=uid), + ) + rundb_mock.store_run(run, uid) + + result = mlrun.run.RunObject( + metadata=mlrun.model.RunMetadata(uid=uid), + status=mlrun.model.RunStatus(state="error", reason="some error"), + ) + runtime._get_db_run = unittest.mock.MagicMock(return_value=result.to_dict()) + + with pytest.raises(mlrun.runtimes.utils.RunError) as exc: + launcher.launch(runtime, run, watch=True) + assert "some error" in str(exc.value) diff --git a/tests/model_monitoring/test_features_drift_table.py b/tests/model_monitoring/test_features_drift_table.py index cbc72b50b562..9ddfc2e0c013 100644 --- a/tests/model_monitoring/test_features_drift_table.py +++ b/tests/model_monitoring/test_features_drift_table.py @@ -100,7 +100,7 @@ def 
plot_produce(context: mlrun.MLClientCtx): ) -def test_plot_produce(): +def test_plot_produce(rundb_mock): # Create a temp directory: output_path = tempfile.TemporaryDirectory() @@ -118,7 +118,7 @@ def test_plot_produce(): # Check the plot was saved properly (only the drift table plot should appear): artifact_directory_content = os.listdir( - os.path.dirname(train_run.outputs["drift_table_plot"]) + os.path.dirname(train_run.status.artifacts[0]["spec"]["target_path"]) ) assert len(artifact_directory_content) == 1 assert artifact_directory_content[0] == "drift_table_plot.html" diff --git a/tests/model_monitoring/test_target_path.py b/tests/model_monitoring/test_target_path.py new file mode 100644 index 000000000000..097afc9d57bc --- /dev/null +++ b/tests/model_monitoring/test_target_path.py @@ -0,0 +1,73 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +from unittest import mock + +import mlrun.config +import mlrun.utils.model_monitoring + +TEST_PROJECT = "test-model-endpoints" + + +@mock.patch.dict(os.environ, {"MLRUN_ARTIFACT_PATH": "s3://some-bucket/"}, clear=True) +def test_get_file_target_path(): + + # offline target with relative path + offline_parquet_relative = mlrun.mlconf.get_model_monitoring_file_target_path( + project=TEST_PROJECT, + kind="parquet", + target="offline", + artifact_path=os.environ["MLRUN_ARTIFACT_PATH"], + ) + assert ( + offline_parquet_relative + == os.environ["MLRUN_ARTIFACT_PATH"] + "model-endpoints/parquet" + ) + + # online target + online_target = mlrun.mlconf.get_model_monitoring_file_target_path( + project=TEST_PROJECT, kind="some_kind", target="online" + ) + assert ( + online_target + == f"v3io:///users/pipelines/{TEST_PROJECT}/model-endpoints/some_kind" + ) + + # offline target with absolute path + mlrun.mlconf.model_endpoint_monitoring.offline_storage_path = ( + "schema://projects/test-path" + ) + offline_parquet_abs = mlrun.mlconf.get_model_monitoring_file_target_path( + project=TEST_PROJECT, kind="parquet", target="offline" + ) + assert ( + offline_parquet_abs + f"/{TEST_PROJECT}/parquet" + == f"schema://projects/test-path/{TEST_PROJECT}/parquet" + ) + + +def test_get_stream_path(): + # default stream path + stream_path = mlrun.utils.model_monitoring.get_stream_path(project=TEST_PROJECT) + assert ( + stream_path == f"v3io:///users/pipelines/{TEST_PROJECT}/model-endpoints/stream" + ) + + # kafka stream path from env + os.environ["STREAM_PATH"] = "kafka://some_kafka_bootstrap_servers:8080" + stream_path = mlrun.utils.model_monitoring.get_stream_path(project=TEST_PROJECT) + assert ( + stream_path + == f"kafka://some_kafka_bootstrap_servers:8080?topic=monitoring_stream_{TEST_PROJECT}" + ) diff --git a/mlrun/runtimes/package/__init__.py b/tests/package/__init__.py similarity index 100% rename from mlrun/runtimes/package/__init__.py rename to tests/package/__init__.py 
diff --git a/tests/package/packager_tester.py b/tests/package/packager_tester.py new file mode 100644 index 000000000000..a2813c302162 --- /dev/null +++ b/tests/package/packager_tester.py @@ -0,0 +1,140 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import sys +from abc import ABC +from typing import Any, Callable, List, NamedTuple, Tuple, Union + +import cloudpickle + +from mlrun import Packager + +# When using artifact type "object", these instructions will be common to most artifacts in the tests: +COMMON_OBJECT_INSTRUCTIONS = { + "pickle_module_name": "cloudpickle", + "pickle_module_version": cloudpickle.__version__, + "python_version": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}", +} + + +class PackTest(NamedTuple): + """ + Tuple for creating a test to run in the `test_packager_pack` test of "test_packagers.py". + + :param pack_handler: The handler to run as a MLRun function for packing. + :param log_hint: The log hint to pass to the pack handler. + :param validation_function: Function to assert a success packing. Will run without MLRun. It expects to + receive the logged result / Artifact object. + :param pack_parameters: The parameters to pass to the pack handler. + :param validation_parameters: Additional parameters to pass to the validation function. + :param default_artifact_type_object: Optional field to hold a dummy object to test the default artifact type method + of the packager. 
Make sure to not pass an artifact type in the log hint, so it + will be tested. + :param exception: If an exception should be raised during the test, this should be part of the + expected exception message. Default is None (the test should succeed). + """ + + pack_handler: str + log_hint: Union[str, dict] + validation_function: Callable[[Any, ...], bool] + pack_parameters: dict = {} + validation_parameters: dict = {} + default_artifact_type_object: Any = None + exception: str = None + + +class UnpackTest(NamedTuple): + """ + Tuple for creating a test to run in the `test_packager_unpack` test of "test_packagers.py". + + :param prepare_input_function: Function to prepare the input to pass to the unpack handler. It should return a tuple + of strings: the input path to pass as input to the function and the root directory to + delete after the test where all files that were generated are stored. + :param unpack_handler: The handler to run as a MLRun function for unpacking. Must accept "obj" as the + argument to unpack. + :param prepare_parameters: The parameters to pass to the prepare function. + :param unpack_parameters: The parameters to pass to the unpack handler. + :param exception: If an exception should be raised during the test, this should be part of the expected + exception message. Default is None (the test should succeed). + """ + + prepare_input_function: Callable[[...], Tuple[str, str]] + unpack_handler: str + prepare_parameters: dict = {} + unpack_parameters: dict = {} + exception: str = None + + +class PackToUnpackTest(NamedTuple): + """ + Tuple for creating a test to run in the `test_packager_pack_to_unpack` test of "test_packagers.py". + + :param pack_handler: The handler to run as a MLRun function for packing. + :param log_hint: The log hint to pass to the pack handler. Result will skip the + `expected_instructions` and `unpack_handler` variables (hence they are + optional). + :param pack_parameters: The parameters to pass to the pack handler. 
+ :param expected_instructions: The expected instructions the packed artifact should have. + :param unpack_handler: The handler to run as a MLRun function for unpacking. Must accept "obj" as the + argument to unpack. + :param unpack_parameters: The parameters to pass to the unpack handler. + :param default_artifact_type_object: Optional field to hold a dummy object to test the default artifact type method + of the packager. Make sure to not pass an artifact type in the log hint, so it + will be tested. + :param exception: If an exception should be raised during the test, this should be part of the + expected exception message. Default is None (the test should succeed). + """ + + pack_handler: str + log_hint: Union[str, dict] + pack_parameters: dict = {} + expected_instructions: dict = {} + unpack_handler: str = None + unpack_parameters: dict = {} + default_artifact_type_object: Any = None + exception: str = None + + +class PackagerTester(ABC): + """ + A simple class for all testers to inherit from, so they will be able to be added to the tests in + "test_packagers.py". + """ + + # The packager being tested by this tester: + PACKAGER_IN_TEST: Packager = None + + # The list of tests tuples to include from this tester in the tests of "test_packagers.py": + TESTS: List[Union[PackTest, UnpackTest, PackToUnpackTest]] = [] + + +class NewClass: + """ + Class to use for testing the default class. + """ + + # It is declared in this file so that it won't be part of the MLRun function module when a tester of + # `default_packager_tester.py` is running. For more information, see the long exception at `packagers_manager.py`'s + # `PackagersManager._unpack_package` function. 
+ + def __init__(self, a: int, b: int, c: int): + self.a = a + self.b = b + self.c = c + + def __eq__(self, other): + return self.a == other.a and self.b == other.b and self.c == other.c + + def __str__(self): + return str(self.a + self.b + self.c) diff --git a/tests/package/packagers/__init__.py b/tests/package/packagers/__init__.py new file mode 100644 index 000000000000..4f418a506ca1 --- /dev/null +++ b/tests/package/packagers/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx diff --git a/tests/package/packagers/test_numpy_packagers.py b/tests/package/packagers/test_numpy_packagers.py new file mode 100644 index 000000000000..3cb16295ceb4 --- /dev/null +++ b/tests/package/packagers/test_numpy_packagers.py @@ -0,0 +1,105 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import tempfile +from pathlib import Path +from typing import Dict, List, Union + +import numpy as np +import pytest + +from mlrun.package.packagers.numpy_packagers import NumPySupportedFormat + + +def _test( + obj: Union[np.ndarray, Dict[str, np.ndarray], List[np.ndarray]], + file_format: str, + **save_kwargs, +): + # Create a temporary directory for the test outputs: + test_directory = tempfile.TemporaryDirectory() + + # Set up the main directory to archive and the output path for the archive file: + file_path = Path(test_directory.name) / f"my_array.{file_format}" + assert not file_path.exists() + + # Archive the files: + formatter = NumPySupportedFormat.get_format_handler(fmt=file_format) + formatter.save(obj=obj, file_path=str(file_path), **save_kwargs) + assert file_path.exists() + + # Extract the files: + saved_object = formatter.load(file_path=str(file_path)) + if isinstance(obj, np.ndarray): + assert (saved_object == obj).all() + elif isinstance(obj, dict): + for original, saved in zip(obj.values(), saved_object.values()): + assert (original == saved).all() + else: + for original, saved in zip(obj, saved_object.values()): + assert (original == saved).all() + + # Clean the test outputs: + test_directory.cleanup() + + +@pytest.mark.parametrize( + "obj", + [ + np.random.random((10, 30)), + np.random.random(100), + np.random.randint(0, 255, (150, 200)), + ], +) +@pytest.mark.parametrize( + "file_format", + NumPySupportedFormat.get_single_array_formats(), +) +def test_formatter_single_array(obj: np.ndarray, file_format: str): + """ + Test the formatters for saving and writing a numpy array. + + :param obj: The array to write. + :param file_format: The numpy format to use. 
+ """ + _test(file_format=file_format, obj=obj) + + +@pytest.mark.parametrize( + "obj", + [ + {f"array_{i}": np.random.random(size=(10, 30)) for i in range(5)}, + [np.random.random(size=777) for i in range(10)], + ], +) +@pytest.mark.parametrize( + "file_format", + NumPySupportedFormat.get_multi_array_formats(), +) +@pytest.mark.parametrize( + "save_kwargs", [{"is_compressed": boolean_value} for boolean_value in [True, False]] +) +def test_formatter_multiple_arrays( + obj: Union[Dict[str, np.ndarray], List[np.ndarray]], + file_format: str, + save_kwargs: bool, +): + """ + Test the formatters for saving and writing a numpy array. + + :param obj: The array to write. + :param file_format: The numpy format to use. + :param save_kwargs: Save kwargs to use. + """ + _test(obj=obj, file_format=file_format, save_kwargs=save_kwargs) diff --git a/tests/package/packagers/test_pandas_packagers.py b/tests/package/packagers/test_pandas_packagers.py new file mode 100644 index 000000000000..33f903f3e3fa --- /dev/null +++ b/tests/package/packagers/test_pandas_packagers.py @@ -0,0 +1,178 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import importlib +import tempfile +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest + +from mlrun.package.packagers.pandas_packagers import PandasSupportedFormat + +# Set up the format requirements dictionary: +FORMAT_REQUIREMENTS = { + PandasSupportedFormat.PARQUET: "pyarrow", + PandasSupportedFormat.H5: "tables", + PandasSupportedFormat.XLSX: "openpyxl", + PandasSupportedFormat.XML: "lxml", + PandasSupportedFormat.HTML: "lxml", + PandasSupportedFormat.FEATHER: "pyarrow", + PandasSupportedFormat.ORC: "pyarrow", +} + + +def check_skipping_pandas_format(fmt: str): + if fmt in FORMAT_REQUIREMENTS: + try: + importlib.import_module(FORMAT_REQUIREMENTS[fmt]) + except ModuleNotFoundError: + return True + + # TODO: Remove when padnas>=1.5 in requirements + if fmt == PandasSupportedFormat.ORC: + # ORC format is added only since pandas 1.5.0, so we skip if pandas is older than this: + v1, v2, v3 = pd.__version__.split(".") + if int(v1) == 1 and int(v2) < 5: + return True + return False + + +def get_test_dataframes(): + # Configurations: + _n_rows = 100 + _n_columns = 24 + _single_level_column_names = [f"column_{i}" for i in range(_n_columns)] + _multi_level_column_names = [ + [f"{chr(n)}1" for n in range(ord("A"), ord("A") + 2)], + [f"{chr(n)}2" for n in range(ord("A"), ord("A") + 3)], + [f"{chr(n)}3" for n in range(ord("A"), ord("A") + 4)], + ] # 2 * 3 * 4 = 24 (_n_columns) + _column_levels_names = ["letter_level_1", "letter_level_2", "letter_level_3"] + _single_index = [i for i in range(0, _n_rows * 2, 2)] + _multi_index = [ + list(range(2)), + list(range(5)), + list(range(10)), + ] # 2 * 5 * 10 = 100 (_n_rows) + + # Initialize the data and options for dataframes: + data = np.random.randint(0, 256, (_n_rows, _n_columns)) + columns_options = [ + # Single level: + _single_level_column_names, + # Multi-level: + pd.MultiIndex.from_product(_multi_level_column_names), + # Multi-level with names: + pd.MultiIndex.from_product( + 
_multi_level_column_names, + names=_column_levels_names, + ), + ] + index_options = [ + # Default: + None, + # Single level: + _single_index, + # Single level with name: + pd.Index(data=_single_index, name="my_index"), + # Multi-level: + pd.MultiIndex.from_product(_multi_index), + # Multi-level with names: + pd.MultiIndex.from_product( + _multi_index, names=["index_5", "index_10", "index_20"] + ), + ] + + # Initialize the dataframes: + dataframes = [] + for columns in columns_options: + for index in index_options: + df = pd.DataFrame(data=data, columns=columns, index=index) + dataframes.append(df) + # Add same name of columns and indexes scenarios if index has a name: + if index is not None and all( + index_name is not None for index_name in df.index.names + ): + same_name_df = df.copy() + if isinstance(df.index, pd.MultiIndex): + if isinstance(df.columns, pd.MultiIndex): + same_name_df.index.set_names( + names=df.columns.names[: len(df.index.names)], inplace=True + ) + else: # Single index + same_name_df.index.set_names( + names=df.columns[: len(df.index.names)], inplace=True + ) + else: # Single index + if isinstance(df.columns, pd.MultiIndex): + same_name_df.index.set_names( + names=str(df.columns.names[0]), inplace=True + ) + else: # Single index + same_name_df.index.set_names( + names=str(df.columns[0]), inplace=True + ) + dataframes.append(same_name_df) + + return dataframes + + +@pytest.mark.parametrize("obj", get_test_dataframes()) +@pytest.mark.parametrize( + "file_format", + PandasSupportedFormat.get_all_formats(), +) +def test_formatter( + obj: pd.DataFrame, + file_format: str, +): + """ + Test the pandas formatters for writing and reading dataframes. + + :param obj: The dataframe to write. + :param file_format: The pandas file format to use. 
+ """ + # Check if needed to skip this file format test: + if check_skipping_pandas_format(fmt=file_format): + pytest.skip( + f"Skipping test of pandas file format '{file_format}' " + f"due to missing requirement: '{FORMAT_REQUIREMENTS[file_format]}'" + ) + + # Create a temporary directory for the test outputs: + test_directory = tempfile.TemporaryDirectory() + + # Set up the main directory to archive and the output path for the archive file: + file_path = Path(test_directory.name) / f"my_array.{file_format}" + assert not file_path.exists() + + # Save the dataframe to file: + formatter = PandasSupportedFormat.get_format_handler(fmt=file_format) + read_kwargs = formatter.to(obj=obj.copy(), file_path=str(file_path)) + assert file_path.exists() + + # Read the file: + saved_object = formatter.read(file_path=str(file_path), **read_kwargs) + + # Assert equality post reading: + assert isinstance(saved_object, type(obj)) + assert list(saved_object.columns) == list(obj.columns) + assert saved_object.columns.names == obj.columns.names + assert saved_object.index.names == obj.index.names + assert (saved_object == obj).all().all() + + # Clean the test outputs: + test_directory.cleanup() diff --git a/tests/package/packagers_testers/__init__.py b/tests/package/packagers_testers/__init__.py new file mode 100644 index 000000000000..4f418a506ca1 --- /dev/null +++ b/tests/package/packagers_testers/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +# flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx diff --git a/tests/package/packagers_testers/default_packager_tester.py b/tests/package/packagers_testers/default_packager_tester.py new file mode 100644 index 000000000000..b555542d2d8a --- /dev/null +++ b/tests/package/packagers_testers/default_packager_tester.py @@ -0,0 +1,81 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os +import tempfile +from typing import Tuple + +import cloudpickle + +from mlrun.package import DefaultPackager +from tests.package.packager_tester import ( + COMMON_OBJECT_INSTRUCTIONS, + NewClass, + PackagerTester, + PackTest, + PackToUnpackTest, + UnpackTest, +) + + +def pack_some_class() -> NewClass: + return NewClass(a=1, b=2, c=3) + + +def unpack_some_class(obj: NewClass): + assert type(obj).__name__ == NewClass.__name__ + assert obj == NewClass(a=1, b=2, c=3) + + +def validate_some_class_result(result: str) -> bool: + return result == "6" + + +def prepare_new_class() -> Tuple[str, str]: + temp_directory = tempfile.mkdtemp() + pkl_path = os.path.join(temp_directory, "my_class.pkl") + some_class = NewClass(a=1, b=2, c=3) + with open(pkl_path, "wb") as pkl_file: + cloudpickle.dump(some_class, pkl_file) + + return pkl_path, temp_directory + + +class DefaultPackagerTester(PackagerTester): + """ + A tester for the `DefaultPackager`. + """ + + PACKAGER_IN_TEST = DefaultPackager + + TESTS = [ + PackTest( + pack_handler="pack_some_class", + log_hint="my_result : result", + validation_function=validate_some_class_result, + ), + UnpackTest( + prepare_input_function=prepare_new_class, + unpack_handler="unpack_some_class", + ), + PackToUnpackTest( + pack_handler="pack_some_class", + log_hint="my_object", + expected_instructions={ + "object_module_name": "tests", + **COMMON_OBJECT_INSTRUCTIONS, + }, + unpack_handler="unpack_some_class", + ), + ] diff --git a/tests/package/packagers_testers/numpy_packagers_testers.py b/tests/package/packagers_testers/numpy_packagers_testers.py new file mode 100644 index 000000000000..b13ba5201413 --- /dev/null +++ b/tests/package/packagers_testers/numpy_packagers_testers.py @@ -0,0 +1,326 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import tempfile +from typing import Dict, List, Tuple + +import numpy as np + +from mlrun.package.packagers.numpy_packagers import ( + NumPyNDArrayDictPackager, + NumPyNDArrayListPackager, + NumPyNDArrayPackager, + NumPyNumberPackager, + NumPySupportedFormat, +) +from tests.package.packager_tester import ( + COMMON_OBJECT_INSTRUCTIONS, + PackagerTester, + PackTest, + PackToUnpackTest, + UnpackTest, +) + +# Common instructions for "object" artifacts of numpy objects: +_COMMON_OBJECT_INSTRUCTIONS = { + **COMMON_OBJECT_INSTRUCTIONS, + "object_module_name": "numpy", + "object_module_version": np.__version__, +} + + +_ARRAY_SAMPLE = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + + +def pack_array() -> np.ndarray: + return _ARRAY_SAMPLE + + +def validate_array(result: List[List[int]]) -> bool: + return (np.array(result) == _ARRAY_SAMPLE).all() + + +def unpack_array(obj: np.ndarray): + assert isinstance(obj, np.ndarray) + assert (obj == _ARRAY_SAMPLE).all() + + +def prepare_array_file(file_format: str) -> Tuple[str, str]: + temp_directory = tempfile.mkdtemp() + file_path = os.path.join(temp_directory, f"my_array.{file_format}") + formatter = NumPySupportedFormat.get_format_handler(fmt=file_format) + formatter.save(obj=_ARRAY_SAMPLE, file_path=file_path) + return file_path, temp_directory + + +class NumPyNDArrayPackagerTester(PackagerTester): + """ + A tester for the `NumPyNDArrayPackager`. 
+ """ + + PACKAGER_IN_TEST = NumPyNDArrayPackager + + TESTS = [ + PackTest( + pack_handler="pack_array", + log_hint="my_result", + validation_function=validate_array, + pack_parameters={}, + default_artifact_type_object=np.ones(1), + ), + *[ + UnpackTest( + prepare_input_function=prepare_array_file, + unpack_handler="unpack_array", + prepare_parameters={"file_format": file_format}, + ) + for file_format in NumPySupportedFormat.get_single_array_formats() + ], + PackToUnpackTest( + pack_handler="pack_array", + log_hint="my_result: result", + ), + PackToUnpackTest( + pack_handler="pack_array", + log_hint="my_result: object", + expected_instructions=_COMMON_OBJECT_INSTRUCTIONS, + unpack_handler="unpack_array", + ), + PackToUnpackTest( + pack_handler="pack_array", + log_hint="my_result: dataset", + unpack_handler="unpack_array", + ), + *[ + PackToUnpackTest( + pack_handler="pack_array", + log_hint={ + "key": "my_array", + "artifact_type": "file", + "file_format": file_format, + }, + expected_instructions={"file_format": file_format}, + unpack_handler="unpack_array", + ) + for file_format in NumPySupportedFormat.get_single_array_formats() + ], + ] + + +_NUMBER_SAMPLE = np.float64(5.10203) + + +def pack_number() -> np.number: + return _NUMBER_SAMPLE + + +def validate_number(result: float) -> bool: + return np.float64(result) == _NUMBER_SAMPLE + + +def unpack_number(obj: np.float64): + assert isinstance(obj, np.float64) + assert obj == _NUMBER_SAMPLE + + +class NumPyNumberPackagerTester(PackagerTester): + """ + A tester for the `NumPyNumberPackager`. 
+ """ + + PACKAGER_IN_TEST = NumPyNumberPackager + + TESTS = [ + PackTest( + pack_handler="pack_number", + log_hint="my_result", + validation_function=validate_number, + ), + PackToUnpackTest( + pack_handler="pack_number", + log_hint="my_result", + ), + PackToUnpackTest( + pack_handler="pack_number", + log_hint="my_result: object", + expected_instructions=_COMMON_OBJECT_INSTRUCTIONS, + unpack_handler="unpack_number", + ), + ] + + +_ARRAY_DICT_SAMPLE = {f"my_array_{i}": _ARRAY_SAMPLE * i for i in range(1, 5)} + + +def pack_array_dict() -> Dict[str, np.ndarray]: + return _ARRAY_DICT_SAMPLE + + +def unpack_array_dict(obj: Dict[str, np.ndarray]): + assert isinstance(obj, dict) and all( + isinstance(key, str) and isinstance(value, np.ndarray) + for key, value in obj.items() + ) + assert obj.keys() == _ARRAY_DICT_SAMPLE.keys() + for obj_array, sample_array in zip(obj.values(), _ARRAY_DICT_SAMPLE.values()): + assert (obj_array == sample_array).all() + + +def validate_array_dict(result: Dict[str, list]) -> bool: + # Numppy arrays are serialized as lists: + for key in _ARRAY_DICT_SAMPLE: + array = result.pop(key) + if not (np.array(array) == _ARRAY_DICT_SAMPLE[key]).all(): + return False + return len(result) == 0 + + +def prepare_array_dict_file(file_format: str, **save_kwargs) -> Tuple[str, str]: + temp_directory = tempfile.mkdtemp() + file_path = os.path.join(temp_directory, f"my_file.{file_format}") + formatter = NumPySupportedFormat.get_format_handler(fmt=file_format) + formatter.save(obj=_ARRAY_DICT_SAMPLE, file_path=file_path, **save_kwargs) + return file_path, temp_directory + + +class NumPyNDArrayDictPackagerTester(PackagerTester): + """ + A tester for the `NumPyNDArrayDictPackager`. 
+ """ + + PACKAGER_IN_TEST = NumPyNDArrayDictPackager + + TESTS = [ + PackTest( + pack_handler="pack_array_dict", + log_hint="my_result: result", + validation_function=validate_array_dict, + ), + *[ + UnpackTest( + prepare_input_function=prepare_array_dict_file, + unpack_handler="unpack_array_dict", + prepare_parameters={"file_format": file_format}, + ) + for file_format in NumPySupportedFormat.get_multi_array_formats() + ], + PackToUnpackTest( + pack_handler="pack_array_dict", + log_hint="my_array: result", + ), + PackToUnpackTest( + pack_handler="pack_array_dict", + log_hint="my_array: object", + expected_instructions={ + **COMMON_OBJECT_INSTRUCTIONS, + "object_module_name": dict.__module__, + }, + unpack_handler="unpack_array_dict", + ), + *[ + PackToUnpackTest( + pack_handler="pack_array_dict", + log_hint={ + "key": "my_array", + "file_format": file_format, + }, + expected_instructions={ + "file_format": file_format, + }, + unpack_handler="unpack_array_dict", + ) + for file_format in NumPySupportedFormat.get_multi_array_formats() + ], + ] + + +_ARRAY_LIST_SAMPLE = list(_ARRAY_DICT_SAMPLE.values()) + + +def pack_array_list() -> List[np.ndarray]: + return _ARRAY_LIST_SAMPLE + + +def unpack_array_list(obj: List[np.ndarray]): + assert isinstance(obj, list) and all(isinstance(value, np.ndarray) for value in obj) + for obj_array, sample_array in zip(obj, _ARRAY_LIST_SAMPLE): + assert (obj_array == sample_array).all() + + +def validate_array_list(result: List[list]) -> bool: + # Numppy arrays are serialized as lists: + for result_array, sample_array in zip(result, _ARRAY_LIST_SAMPLE): + if not (np.array(result_array) == sample_array).all(): + return False + return True + + +def prepare_array_list_file(file_format: str, **save_kwargs) -> Tuple[str, str]: + temp_directory = tempfile.mkdtemp() + file_path = os.path.join(temp_directory, f"my_file.{file_format}") + formatter = NumPySupportedFormat.get_format_handler(fmt=file_format) + formatter.save(obj=_ARRAY_LIST_SAMPLE, 
file_path=file_path, **save_kwargs) + return file_path, temp_directory + + +class NumPyNDArrayListPackagerTester(PackagerTester): + """ + A tester for the `NumPyNDArrayListPackager`. + """ + + PACKAGER_IN_TEST = NumPyNDArrayListPackager + + TESTS = [ + PackTest( + pack_handler="pack_array_list", + log_hint="my_result: result", + validation_function=validate_array_list, + ), + *[ + UnpackTest( + prepare_input_function=prepare_array_list_file, + unpack_handler="unpack_array_list", + prepare_parameters={"file_format": file_format}, + ) + for file_format in NumPySupportedFormat.get_multi_array_formats() + ], + PackToUnpackTest( + pack_handler="pack_array_list", + log_hint="my_array: result", + ), + PackToUnpackTest( + pack_handler="pack_array_list", + log_hint="my_array: object", + expected_instructions={ + **COMMON_OBJECT_INSTRUCTIONS, + "object_module_name": dict.__module__, + }, + unpack_handler="unpack_array_list", + ), + *[ + PackToUnpackTest( + pack_handler="pack_array_list", + log_hint={ + "key": "my_array", + "file_format": file_format, + }, + expected_instructions={ + "file_format": file_format, + }, + unpack_handler="unpack_array_list", + ) + for file_format in NumPySupportedFormat.get_multi_array_formats() + ], + ] diff --git a/tests/package/packagers_testers/pandas_packagers_testers.py b/tests/package/packagers_testers/pandas_packagers_testers.py new file mode 100644 index 000000000000..ff4c66959edf --- /dev/null +++ b/tests/package/packagers_testers/pandas_packagers_testers.py @@ -0,0 +1,302 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import itertools +import os +import tempfile +from typing import Tuple + +import numpy as np +import pandas as pd + +from mlrun.package.packagers.pandas_packagers import ( + PandasDataFramePackager, + PandasSeriesPackager, + PandasSupportedFormat, +) +from tests.package.packager_tester import ( + COMMON_OBJECT_INSTRUCTIONS, + PackagerTester, + PackTest, + PackToUnpackTest, + UnpackTest, +) + +# Common instructions for "object" artifacts of pandas objects: +_COMMON_OBJECT_INSTRUCTIONS = { + **COMMON_OBJECT_INSTRUCTIONS, + "object_module_name": "pandas", + "object_module_version": pd.__version__, +} + +# Seed for reproducible tests: +np.random.seed(99) + + +def _prepare_result(dataframe: pd.DataFrame): + if len(dataframe.index.names) > 1: + orient = "split" + elif dataframe.index.name is not None: + orient = "dict" + else: + orient = "list" + return PandasDataFramePackager._prepare_result(obj=dataframe.to_dict(orient=orient)) + + +_DATAFRAME_SAMPLES = [ + pd.DataFrame( + data=np.random.randint(0, 256, (1000, 10)), + columns=[f"column_{i}" for i in range(10)], + ), + pd.DataFrame( + data=np.random.randint(0, 256, (1000, 10)), + columns=[f"column_{i}" for i in range(10)], + index=[i for i in range(1000)], + ), + pd.DataFrame( + data={ + **{f"column_{i}": np.random.randint(0, 256, 1000) for i in range(7)}, + **{f"column_{i}": np.arange(0, 1000) for i in range(7, 10)}, + }, + ).set_index(keys=["column_7", "column_8", "column_9"]), +] + + +def pack_dataframe(i: int) -> pd.DataFrame: + return _DATAFRAME_SAMPLES[i] + + +def validate_dataframe(result: dict, i: 
int) -> bool: + # Pandas dataframes are serialized as dictionaries: + return result == _prepare_result(dataframe=_DATAFRAME_SAMPLES[i]) + + +def unpack_dataframe(obj: pd.DataFrame, i: int): + assert isinstance(obj, pd.DataFrame) + assert list(obj.columns) == list(_DATAFRAME_SAMPLES[i].columns) + assert obj.columns.names == _DATAFRAME_SAMPLES[i].columns.names + assert obj.index.names == _DATAFRAME_SAMPLES[i].index.names + assert (obj == _DATAFRAME_SAMPLES[i]).all().all() + + +def prepare_dataframe_file(file_format: str, i: int) -> Tuple[str, str]: + temp_directory = tempfile.mkdtemp() + file_path = os.path.join(temp_directory, f"my_dataframe.{file_format}") + formatter = PandasSupportedFormat.get_format_handler(fmt=file_format) + formatter.to(obj=_DATAFRAME_SAMPLES[i], file_path=file_path) + return file_path, temp_directory + + +class PandasDataFramePackagerTester(PackagerTester): + """ + A tester for the `PandasDataFramePackager`. + """ + + PACKAGER_IN_TEST = PandasDataFramePackager + + TESTS = list( + itertools.chain.from_iterable( + [ + *[ + [ + PackTest( + pack_handler="pack_dataframe", + pack_parameters={"i": i}, + log_hint="my_result: result", + validation_function=validate_dataframe, + validation_parameters={"i": i}, + ), + UnpackTest( + prepare_input_function=prepare_dataframe_file, + unpack_handler="unpack_dataframe", + prepare_parameters={"file_format": "parquet", "i": i}, + unpack_parameters={"i": i}, + ), + PackToUnpackTest( + pack_handler="pack_dataframe", + pack_parameters={"i": i}, + log_hint="my_dataframe: object", + expected_instructions=_COMMON_OBJECT_INSTRUCTIONS, + unpack_handler="unpack_dataframe", + unpack_parameters={"i": i}, + ), + PackToUnpackTest( + pack_handler="pack_dataframe", + pack_parameters={"i": i}, + log_hint="my_dataframe: dataset", + unpack_handler="unpack_dataframe", + unpack_parameters={"i": i}, + ), + *[ + PackToUnpackTest( + pack_handler="pack_dataframe", + pack_parameters={"i": i}, + log_hint={ + "key": "my_dataframe", + 
"artifact_type": "file", + "file_format": file_format, + }, + expected_instructions={ + "file_format": file_format, + "read_kwargs": { + "unflatten_kwargs": { + "columns": [ + column + if not isinstance(column, tuple) + else list(column) + for column in _DATAFRAME_SAMPLES[ + i + ].columns + ], + "columns_levels": list( + _DATAFRAME_SAMPLES[i].columns.names + ), + "index_levels": list( + _DATAFRAME_SAMPLES[i].index.names + ), + } + } + if file_format + not in [ + PandasSupportedFormat.PARQUET, + PandasSupportedFormat.H5, + ] + else {}, + }, + unpack_handler="unpack_dataframe", + unpack_parameters={"i": i}, + ) + for file_format in ["parquet", "csv"] + ], + ] + for i in range(len(_DATAFRAME_SAMPLES)) + ] + ] + ) + ) + + +_SERIES_SAMPLES = [ + pd.Series(data=np.random.randint(0, 256, (100,))), + pd.Series(data=np.random.randint(0, 256, (100,)), name="my_series"), + pd.DataFrame(data=np.random.randint(0, 256, (10, 10))).mean(), + pd.DataFrame(data=np.random.randint(0, 256, (10, 10)))[0], + pd.DataFrame(data=np.random.randint(0, 256, (10, 3)), columns=["a", "b", "c"])["a"], + pd.DataFrame( + data=np.random.randint(0, 256, (10, 4)), + columns=["a", "b", "c", "d"], + index=pd.MultiIndex.from_product( + [[1, 2, 3, 4, 5], ["A", "B"]], names=["number", "letter"] + ), + )["a"], +] + + +def pack_series(i: int) -> pd.Series: + return _SERIES_SAMPLES[i] + + +def validate_series(result: dict, i: int) -> bool: + return result == _prepare_result(dataframe=pd.DataFrame(_SERIES_SAMPLES[i])) + + +def prepare_series_file(file_format: str, i: int) -> Tuple[str, str]: + temp_directory = tempfile.mkdtemp() + file_path = os.path.join(temp_directory, f"my_series.{file_format}") + formatter = PandasSupportedFormat.get_format_handler(fmt=file_format) + formatter.to(obj=pd.DataFrame(_SERIES_SAMPLES[i]), file_path=file_path) + return file_path, temp_directory + + +def unpack_series(obj: pd.Series, i: int): + assert isinstance(obj, pd.Series) + assert obj.name == _SERIES_SAMPLES[i].name + assert 
obj.index.names == _SERIES_SAMPLES[i].index.names + assert (obj == _SERIES_SAMPLES[i]).all() + + +class PandasSeriesPackagerTester(PackagerTester): + """ + A tester for the `PandasSeriesPackager`. + """ + + PACKAGER_IN_TEST = PandasSeriesPackager + + TESTS = list( + itertools.chain.from_iterable( + [ + *[ + [ + PackTest( + pack_handler="pack_series", + pack_parameters={"i": i}, + log_hint="my_result: result", + validation_function=validate_series, + validation_parameters={"i": i}, + ), + PackToUnpackTest( + pack_handler="pack_series", + pack_parameters={"i": i}, + log_hint="my_dataframe: object", + expected_instructions=_COMMON_OBJECT_INSTRUCTIONS, + unpack_handler="unpack_series", + unpack_parameters={"i": i}, + ), + PackToUnpackTest( + pack_handler="pack_series", + pack_parameters={"i": i}, + log_hint={ + "key": "my_series", + "artifact_type": "file", + }, + expected_instructions={ + "file_format": "parquet" if i in [1, 4, 5] else "csv", + "read_kwargs": { + "unflatten_kwargs": { + # Unnamed series will have a column named 0 by default when cast to dataframe. 
+ # Because we cast to dataframe before writing to file, 0 will be written for + # unnamed series samples: + "columns": [ + _SERIES_SAMPLES[i].name + if _SERIES_SAMPLES[i].name is not None + else 0 + ], + "columns_levels": [None], + "index_levels": list( + _SERIES_SAMPLES[i].index.names + ), + } + } + if i not in [1, 4, 5] + else {}, + "column_name": _SERIES_SAMPLES[i].name, + }, + unpack_handler="unpack_series", + unpack_parameters={"i": i}, + ), + ] + for i in range(len(_SERIES_SAMPLES)) + ], + [ + UnpackTest( + prepare_input_function=prepare_series_file, + unpack_handler="unpack_series", + prepare_parameters={"file_format": "parquet", "i": i}, + unpack_parameters={"i": i}, + ) + for i in [1, 4, 5] + ], + ] + ) + ) diff --git a/tests/package/packagers_testers/python_standard_library_packagers_testers.py b/tests/package/packagers_testers/python_standard_library_packagers_testers.py new file mode 100644 index 000000000000..ef8f2615df67 --- /dev/null +++ b/tests/package/packagers_testers/python_standard_library_packagers_testers.py @@ -0,0 +1,938 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import ast +import os +import pathlib +import tempfile +from typing import Tuple + +from mlrun import MLClientCtx +from mlrun.package.packagers.python_standard_library_packagers import ( + BoolPackager, + BytearrayPackager, + BytesPackager, + DictPackager, + FloatPackager, + FrozensetPackager, + IntPackager, + ListPackager, + PathPackager, + SetPackager, + StrPackager, + TuplePackager, +) +from mlrun.package.utils import ArchiveSupportedFormat, StructFileSupportedFormat +from tests.package.packager_tester import ( + COMMON_OBJECT_INSTRUCTIONS, + PackagerTester, + PackTest, + PackToUnpackTest, + UnpackTest, +) + +# ---------------------------------------------------------------------------------------------------------------------- +# builtins packagers: +# ---------------------------------------------------------------------------------------------------------------------- + +_INT_SAMPLE = 7 + + +def pack_int() -> int: + return _INT_SAMPLE + + +def validate_int(result: int) -> bool: + return result == _INT_SAMPLE + + +def unpack_int(obj: int): + assert isinstance(obj, int) + assert obj == _INT_SAMPLE + + +class IntPackagerTester(PackagerTester): + """ + A tester for the `IntPackager`. 
+ """ + + PACKAGER_IN_TEST = IntPackager + + TESTS = [ + PackTest( + pack_handler="pack_int", + log_hint="my_result", + validation_function=validate_int, + ), + PackToUnpackTest( + pack_handler="pack_int", + log_hint="my_result", + ), + PackToUnpackTest( + pack_handler="pack_int", + log_hint="my_result: object", + expected_instructions={ + **COMMON_OBJECT_INSTRUCTIONS, + "object_module_name": int.__module__, + }, + unpack_handler="unpack_int", + ), + ] + + +_FLOAT_SAMPLE = 0.97123 + + +def pack_float() -> float: + return _FLOAT_SAMPLE + + +def validate_float(result: float) -> bool: + return result == _FLOAT_SAMPLE + + +def unpack_float(obj: float): + assert isinstance(obj, float) + assert obj == _FLOAT_SAMPLE + + +class FloatPackagerTester(PackagerTester): + """ + A tester for the `FloatPackager`. + """ + + PACKAGER_IN_TEST = FloatPackager + + TESTS = [ + PackTest( + pack_handler="pack_float", + log_hint="my_result", + validation_function=validate_float, + ), + PackToUnpackTest( + pack_handler="pack_float", + log_hint="my_result", + ), + PackToUnpackTest( + pack_handler="pack_float", + log_hint="my_result: object", + expected_instructions={ + **COMMON_OBJECT_INSTRUCTIONS, + "object_module_name": float.__module__, + }, + unpack_handler="unpack_float", + ), + ] + + +_BOOL_SAMPLE = True + + +def pack_bool() -> float: + return _BOOL_SAMPLE + + +def validate_bool(result: bool) -> bool: + return result is _BOOL_SAMPLE + + +def unpack_bool(obj: bool): + assert isinstance(obj, bool) + assert obj is _BOOL_SAMPLE + + +class BoolPackagerTester(PackagerTester): + """ + A tester for the `BoolPackager`. 
+ """ + + PACKAGER_IN_TEST = BoolPackager + + TESTS = [ + PackTest( + pack_handler="pack_bool", + log_hint="my_result", + validation_function=validate_bool, + ), + PackToUnpackTest( + pack_handler="pack_bool", + log_hint="my_result", + ), + PackToUnpackTest( + pack_handler="pack_bool", + log_hint="my_result: object", + expected_instructions={ + **COMMON_OBJECT_INSTRUCTIONS, + "object_module_name": bool.__module__, + }, + unpack_handler="unpack_bool", + ), + ] + + +_STR_RESULT_SAMPLE = "I'm a string." +_STR_FILE_SAMPLE = "Something written in a file..." +_STR_DIRECTORY_FILES_SAMPLE = "I'm text file number {}" + + +def pack_str() -> str: + return _STR_RESULT_SAMPLE + + +def pack_str_path_file(context: MLClientCtx) -> str: + file_path = os.path.join(context.artifact_path, "my_file.txt") + with open(file_path, "w") as file: + file.write(_STR_FILE_SAMPLE) + return file_path + + +def pack_str_path_directory(context: MLClientCtx) -> str: + directory_path = os.path.join(context.artifact_path, "my_directory") + os.makedirs(directory_path) + for i in range(5): + with open(os.path.join(directory_path, f"file_{i}.txt"), "w") as file: + file.write(_STR_DIRECTORY_FILES_SAMPLE.format(i)) + return directory_path + + +def validate_str_result(result: str) -> bool: + return result == _STR_RESULT_SAMPLE + + +def unpack_str(obj: str): + assert isinstance(obj, str) + assert obj == _STR_RESULT_SAMPLE + + +def unpack_str_path_file(obj: str): + assert isinstance(obj, str) + with open(obj, "r") as file: + file_content = file.read() + assert file_content == _STR_FILE_SAMPLE + + +def unpack_str_path_directory(obj: str): + assert isinstance(obj, str) + for i in range(5): + with open(os.path.join(obj, f"file_{i}.txt"), "r") as file: + file_content = file.read() + assert file_content == _STR_DIRECTORY_FILES_SAMPLE.format(i) + + +def prepare_str_path_file() -> Tuple[str, str]: + temp_directory = tempfile.mkdtemp() + file_path = os.path.join(temp_directory, "my_file.txt") + with open(file_path, 
"w") as file: + file.write(_STR_FILE_SAMPLE) + return file_path, temp_directory + + +class StrPackagerTester(PackagerTester): + """ + A tester for the `StrPackager`. + """ + + PACKAGER_IN_TEST = StrPackager + + TESTS = [ + PackTest( + pack_handler="pack_str", + log_hint="my_result", + validation_function=validate_str_result, + pack_parameters={}, + ), + UnpackTest( + prepare_input_function=prepare_str_path_file, + unpack_handler="unpack_str_path_file", + ), + PackToUnpackTest( + pack_handler="pack_str", + log_hint="my_result", + ), + PackToUnpackTest( + pack_handler="pack_str", + log_hint="my_result: object", + expected_instructions={ + **COMMON_OBJECT_INSTRUCTIONS, + "object_module_name": str.__module__, + }, + unpack_handler="unpack_str", + ), + PackToUnpackTest( + pack_handler="pack_str_path_file", + log_hint="my_file: path", + expected_instructions={"is_directory": False}, + unpack_handler="unpack_str_path_file", + ), + *[ + PackToUnpackTest( + pack_handler="pack_str_path_directory", + log_hint={ + "key": "my_dir", + "artifact_type": "path", + "archive_format": archive_format, + }, + expected_instructions={ + "is_directory": True, + "archive_format": archive_format, + }, + unpack_handler="unpack_str_path_directory", + ) + for archive_format in ArchiveSupportedFormat.get_all_formats() + ], + ] + + +_DICT_SAMPLE = {"a1": {"a2": [1, 2, 3], "b2": [4, 5, 6]}, "b1": {"b2": [4, 5, 6]}} + + +def pack_dict() -> dict: + return _DICT_SAMPLE + + +def unpack_dict(obj: dict): + assert isinstance(obj, dict) + assert obj == _DICT_SAMPLE + + +def validate_dict_result(result: dict) -> bool: + return result == _DICT_SAMPLE + + +def prepare_dict_file(file_format: str) -> Tuple[str, str]: + temp_directory = tempfile.mkdtemp() + file_path = os.path.join(temp_directory, f"my_file.{file_format}") + formatter = StructFileSupportedFormat.get_format_handler(fmt=file_format) + formatter.write(obj=_DICT_SAMPLE, file_path=file_path) + return file_path, temp_directory + + +class 
DictPackagerTester(PackagerTester): + """ + A tester for the `DictPackager`. + """ + + PACKAGER_IN_TEST = DictPackager + + TESTS = [ + PackTest( + pack_handler="pack_dict", + log_hint="my_dict", + validation_function=validate_dict_result, + ), + *[ + UnpackTest( + prepare_input_function=prepare_dict_file, + unpack_handler="unpack_dict", + prepare_parameters={"file_format": file_format}, + ) + for file_format in StructFileSupportedFormat.get_all_formats() + ], + PackToUnpackTest( + pack_handler="pack_dict", + log_hint="my_dict", + ), + PackToUnpackTest( + pack_handler="pack_dict", + log_hint="my_dict: object", + expected_instructions={ + **COMMON_OBJECT_INSTRUCTIONS, + "object_module_name": dict.__module__, + }, + unpack_handler="unpack_dict", + ), + *[ + PackToUnpackTest( + pack_handler="pack_dict", + log_hint={ + "key": "my_dict", + "artifact_type": "file", + "file_format": file_format, + }, + expected_instructions={ + "file_format": file_format, + }, + unpack_handler="unpack_dict", + ) + for file_format in StructFileSupportedFormat.get_all_formats() + ], + ] + + +_LIST_SAMPLE = [1, 2, 3, {"a": 1, "b": 2}] + + +def pack_list() -> list: + return _LIST_SAMPLE + + +def unpack_list(obj: list): + assert isinstance(obj, list) + assert obj == _LIST_SAMPLE + + +def validate_list_result(result: list) -> bool: + return result == _LIST_SAMPLE + + +def prepare_list_file(file_format: str) -> Tuple[str, str]: + temp_directory = tempfile.mkdtemp() + file_path = os.path.join(temp_directory, f"my_file.{file_format}") + formatter = StructFileSupportedFormat.get_format_handler(fmt=file_format) + formatter.write(obj=_LIST_SAMPLE, file_path=file_path) + return file_path, temp_directory + + +class ListPackagerTester(PackagerTester): + """ + A tester for the `ListPackager`. 
+ """ + + PACKAGER_IN_TEST = ListPackager + + TESTS = [ + PackTest( + pack_handler="pack_list", + log_hint="my_list", + validation_function=validate_list_result, + ), + *[ + UnpackTest( + prepare_input_function=prepare_list_file, + unpack_handler="unpack_list", + prepare_parameters={"file_format": file_format}, + ) + for file_format in StructFileSupportedFormat.get_all_formats() + ], + PackToUnpackTest( + pack_handler="pack_list", + log_hint="my_list", + ), + PackToUnpackTest( + pack_handler="pack_list", + log_hint="my_list: object", + expected_instructions={ + **COMMON_OBJECT_INSTRUCTIONS, + "object_module_name": tuple.__module__, + }, + unpack_handler="unpack_list", + ), + *[ + PackToUnpackTest( + pack_handler="pack_list", + log_hint={ + "key": "my_list", + "artifact_type": "file", + "file_format": file_format, + }, + expected_instructions={ + "file_format": file_format, + }, + unpack_handler="unpack_list", + ) + for file_format in StructFileSupportedFormat.get_all_formats() + ], + ] + + +_TUPLE_SAMPLE = (1, 2, 3) + + +def pack_tuple() -> tuple: + return _TUPLE_SAMPLE + + +def unpack_tuple(obj: tuple): + assert isinstance(obj, tuple) + assert obj == _TUPLE_SAMPLE + + +def validate_tuple_result(result: list) -> bool: + # Tuples are serialized as lists: + return tuple(result) == _TUPLE_SAMPLE + + +def prepare_tuple_file(file_format: str) -> Tuple[str, str]: + temp_directory = tempfile.mkdtemp() + file_path = os.path.join(temp_directory, f"my_file.{file_format}") + formatter = StructFileSupportedFormat.get_format_handler(fmt=file_format) + formatter.write(obj=list(_TUPLE_SAMPLE), file_path=file_path) + return file_path, temp_directory + + +class TuplePackagerTester(PackagerTester): + """ + A tester for the `TuplePackager`. 
+ """ + + PACKAGER_IN_TEST = TuplePackager + + TESTS = [ + PackTest( + pack_handler="pack_tuple", + log_hint="my_tuple", + validation_function=validate_tuple_result, + ), + *[ + UnpackTest( + prepare_input_function=prepare_tuple_file, + unpack_handler="unpack_tuple", + prepare_parameters={"file_format": file_format}, + ) + for file_format in StructFileSupportedFormat.get_all_formats() + ], + PackToUnpackTest( + pack_handler="pack_tuple", + log_hint="my_tuple", + ), + PackToUnpackTest( + pack_handler="pack_tuple", + log_hint="my_tuple: object", + expected_instructions={ + **COMMON_OBJECT_INSTRUCTIONS, + "object_module_name": tuple.__module__, + }, + unpack_handler="unpack_tuple", + ), + *[ + PackToUnpackTest( + pack_handler="pack_tuple", + log_hint={ + "key": "my_tuple", + "artifact_type": "file", + "file_format": file_format, + }, + expected_instructions={ + "file_format": file_format, + }, + unpack_handler="unpack_tuple", + ) + for file_format in StructFileSupportedFormat.get_all_formats() + ], + ] + + +_SET_SAMPLE = {1, 2, 3} + + +def pack_set() -> set: + return _SET_SAMPLE + + +def unpack_set(obj: set): + assert isinstance(obj, set) + assert obj == _SET_SAMPLE + + +def validate_set_result(result: list) -> bool: + # Sets are serialized as lists: + return set(result) == _SET_SAMPLE + + +def prepare_set_file(file_format: str) -> Tuple[str, str]: + temp_directory = tempfile.mkdtemp() + file_path = os.path.join(temp_directory, f"my_file.{file_format}") + formatter = StructFileSupportedFormat.get_format_handler(fmt=file_format) + formatter.write(obj=list(_SET_SAMPLE), file_path=file_path) + return file_path, temp_directory + + +class SetPackagerTester(PackagerTester): + """ + A tester for the `SetPackager`. 
+ """ + + PACKAGER_IN_TEST = SetPackager + + TESTS = [ + PackTest( + pack_handler="pack_set", + log_hint="my_set", + validation_function=validate_set_result, + ), + *[ + UnpackTest( + prepare_input_function=prepare_set_file, + unpack_handler="unpack_set", + prepare_parameters={"file_format": file_format}, + ) + for file_format in StructFileSupportedFormat.get_all_formats() + ], + PackToUnpackTest( + pack_handler="pack_set", + log_hint="my_set", + ), + PackToUnpackTest( + pack_handler="pack_set", + log_hint="my_set: object", + expected_instructions={ + **COMMON_OBJECT_INSTRUCTIONS, + "object_module_name": set.__module__, + }, + unpack_handler="unpack_set", + ), + *[ + PackToUnpackTest( + pack_handler="pack_set", + log_hint={ + "key": "my_set", + "artifact_type": "file", + "file_format": file_format, + }, + expected_instructions={ + "file_format": file_format, + }, + unpack_handler="unpack_set", + ) + for file_format in StructFileSupportedFormat.get_all_formats() + ], + ] + + +_FROZENSET_SAMPLE = frozenset([1, 2, 3]) + + +def pack_frozenset() -> frozenset: + return _FROZENSET_SAMPLE + + +def unpack_frozenset(obj: frozenset): + assert isinstance(obj, frozenset) + assert obj == _FROZENSET_SAMPLE + + +def validate_frozenset_result(result: list) -> bool: + # Frozen sets are serialized as lists: + return frozenset(result) == _FROZENSET_SAMPLE + + +def prepare_frozenset_file(file_format: str) -> Tuple[str, str]: + temp_directory = tempfile.mkdtemp() + file_path = os.path.join(temp_directory, f"my_file.{file_format}") + formatter = StructFileSupportedFormat.get_format_handler(fmt=file_format) + formatter.write(obj=list(_FROZENSET_SAMPLE), file_path=file_path) + return file_path, temp_directory + + +class FrozensetPackagerTester(PackagerTester): + """ + A tester for the `FrozensetPackager`. 
+ """ + + PACKAGER_IN_TEST = FrozensetPackager + + TESTS = [ + PackTest( + pack_handler="pack_frozenset", + log_hint="my_frozenset", + validation_function=validate_frozenset_result, + ), + *[ + UnpackTest( + prepare_input_function=prepare_frozenset_file, + unpack_handler="unpack_frozenset", + prepare_parameters={"file_format": file_format}, + ) + for file_format in StructFileSupportedFormat.get_all_formats() + ], + PackToUnpackTest( + pack_handler="pack_frozenset", + log_hint="my_frozenset", + ), + PackToUnpackTest( + pack_handler="pack_frozenset", + log_hint="my_frozenset: object", + expected_instructions={ + **COMMON_OBJECT_INSTRUCTIONS, + "object_module_name": set.__module__, + }, + unpack_handler="unpack_frozenset", + ), + *[ + PackToUnpackTest( + pack_handler="pack_frozenset", + log_hint={ + "key": "my_frozenset", + "artifact_type": "file", + "file_format": file_format, + }, + expected_instructions={ + "file_format": file_format, + }, + unpack_handler="unpack_frozenset", + ) + for file_format in StructFileSupportedFormat.get_all_formats() + ], + ] + + +_BYTEARRAY_SAMPLE = bytearray([1, 2, 3]) + + +def pack_bytearray() -> bytearray: + return _BYTEARRAY_SAMPLE + + +def unpack_bytearray(obj: bytearray): + assert isinstance(obj, bytearray) + assert obj == _BYTEARRAY_SAMPLE + + +def validate_bytearray_result(result: str) -> bool: + # Byte arrays are serialized as strings (not decoded): + return bytearray(ast.literal_eval(result)) == _BYTEARRAY_SAMPLE + + +def prepare_bytearray_file(file_format: str) -> Tuple[str, str]: + temp_directory = tempfile.mkdtemp() + file_path = os.path.join(temp_directory, f"my_file.{file_format}") + formatter = StructFileSupportedFormat.get_format_handler(fmt=file_format) + formatter.write(obj=list(_BYTEARRAY_SAMPLE), file_path=file_path) + return file_path, temp_directory + + +class BytearrayPackagerTester(PackagerTester): + """ + A tester for the `BytearrayPackager`. 
+ """ + + PACKAGER_IN_TEST = BytearrayPackager + + TESTS = [ + PackTest( + pack_handler="pack_bytearray", + log_hint="my_bytearray", + validation_function=validate_bytearray_result, + ), + *[ + UnpackTest( + prepare_input_function=prepare_bytearray_file, + unpack_handler="unpack_bytearray", + prepare_parameters={"file_format": file_format}, + ) + for file_format in StructFileSupportedFormat.get_all_formats() + ], + PackToUnpackTest( + pack_handler="pack_bytearray", + log_hint="my_bytearray", + ), + PackToUnpackTest( + pack_handler="pack_bytearray", + log_hint="my_bytearray: object", + expected_instructions={ + **COMMON_OBJECT_INSTRUCTIONS, + "object_module_name": set.__module__, + }, + unpack_handler="unpack_bytearray", + ), + *[ + PackToUnpackTest( + pack_handler="pack_bytearray", + log_hint={ + "key": "my_bytearray", + "artifact_type": "file", + "file_format": file_format, + }, + expected_instructions={ + "file_format": file_format, + }, + unpack_handler="unpack_bytearray", + ) + for file_format in StructFileSupportedFormat.get_all_formats() + ], + ] + + +_BYTES_SAMPLE = b"I'm a byte string." + + +def pack_bytes() -> bytes: + return _BYTES_SAMPLE + + +def unpack_bytes(obj: bytes): + assert isinstance(obj, bytes) + assert obj == _BYTES_SAMPLE + + +def validate_bytes_result(result: str) -> bool: + # Bytes are serialized as strings (not decoded): + return ast.literal_eval(result) == _BYTES_SAMPLE + + +def prepare_bytes_file(file_format: str) -> Tuple[str, str]: + temp_directory = tempfile.mkdtemp() + file_path = os.path.join(temp_directory, f"my_file.{file_format}") + formatter = StructFileSupportedFormat.get_format_handler(fmt=file_format) + formatter.write(obj=list(_BYTES_SAMPLE), file_path=file_path) + return file_path, temp_directory + + +class BytesPackagerTester(PackagerTester): + """ + A tester for the `BytesPackager`. 
+ """ + + PACKAGER_IN_TEST = BytesPackager + + TESTS = [ + PackTest( + pack_handler="pack_bytes", + log_hint="my_bytes", + validation_function=validate_bytes_result, + ), + *[ + UnpackTest( + prepare_input_function=prepare_bytes_file, + unpack_handler="unpack_bytes", + prepare_parameters={"file_format": file_format}, + ) + for file_format in StructFileSupportedFormat.get_all_formats() + ], + PackToUnpackTest( + pack_handler="pack_bytes", + log_hint="my_bytes", + ), + PackToUnpackTest( + pack_handler="pack_bytes", + log_hint="my_bytes: object", + expected_instructions={ + **COMMON_OBJECT_INSTRUCTIONS, + "object_module_name": set.__module__, + }, + unpack_handler="unpack_bytes", + ), + *[ + PackToUnpackTest( + pack_handler="pack_bytes", + log_hint={ + "key": "my_bytes", + "artifact_type": "file", + "file_format": file_format, + }, + expected_instructions={ + "file_format": file_format, + }, + unpack_handler="unpack_bytes", + ) + for file_format in StructFileSupportedFormat.get_all_formats() + ], + ] + + +# ---------------------------------------------------------------------------------------------------------------------- +# pathlib packagers: +# ---------------------------------------------------------------------------------------------------------------------- + + +_PATH_RESULT_SAMPLE = pathlib.Path("I'm a path.") + + +def pack_path() -> pathlib.Path: + return _PATH_RESULT_SAMPLE + + +def pack_path_file(context: MLClientCtx) -> pathlib.Path: + file_path = pathlib.Path(context.artifact_path) / "my_file.txt" + with open(file_path, "w") as file: + file.write(_STR_FILE_SAMPLE) + return file_path + + +def pack_path_directory(context: MLClientCtx) -> pathlib.Path: + directory_path = pathlib.Path(context.artifact_path) / "my_directory" + os.makedirs(directory_path) + for i in range(5): + with open(directory_path / f"file_{i}.txt", "w") as file: + file.write(_STR_DIRECTORY_FILES_SAMPLE.format(i)) + return directory_path + + +def validate_path_result(result: pathlib.Path) 
-> bool: + return pathlib.Path(result) == _PATH_RESULT_SAMPLE + + +def unpack_path(obj: pathlib.Path): + assert isinstance(obj, pathlib.Path) + assert obj == _PATH_RESULT_SAMPLE + + +def unpack_path_file(obj: pathlib.Path): + assert isinstance(obj, pathlib.Path) + with open(obj, "r") as file: + file_content = file.read() + assert file_content == _STR_FILE_SAMPLE + + +def unpack_path_directory(obj: pathlib.Path): + assert isinstance(obj, pathlib.Path) + for i in range(5): + with open(obj / f"file_{i}.txt", "r") as file: + file_content = file.read() + assert file_content == _STR_DIRECTORY_FILES_SAMPLE.format(i) + + +class PathPackagerTester(PackagerTester): + """ + A tester for the `PathPackager`. + """ + + PACKAGER_IN_TEST = PathPackager + + TESTS = [ + PackTest( + pack_handler="pack_path", + log_hint="my_result: result", + validation_function=validate_path_result, + pack_parameters={}, + ), + UnpackTest( + prepare_input_function=prepare_str_path_file, # Using str preparing method - same thing + unpack_handler="unpack_path_file", + ), + PackToUnpackTest( + pack_handler="pack_path", + log_hint="my_result: result", + ), + PackToUnpackTest( + pack_handler="pack_path", + log_hint="my_result: object", + expected_instructions={ + **COMMON_OBJECT_INSTRUCTIONS, + "object_module_name": pathlib.Path.__module__, + }, + unpack_handler="unpack_path", + ), + PackToUnpackTest( + pack_handler="pack_path_file", + log_hint="my_file", + expected_instructions={"is_directory": False}, + unpack_handler="unpack_path_file", + ), + *[ + PackToUnpackTest( + pack_handler="pack_path_directory", + log_hint={ + "key": "my_dir", + "archive_format": archive_format, + }, + expected_instructions={ + "is_directory": True, + "archive_format": archive_format, + }, + unpack_handler="unpack_path_directory", + ) + for archive_format in ArchiveSupportedFormat.get_all_formats() + ], + ] diff --git a/tests/package/test_context_handler.py b/tests/package/test_context_handler.py new file mode 100644 index 
000000000000..ec019934ccee --- /dev/null +++ b/tests/package/test_context_handler.py @@ -0,0 +1,109 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from types import FunctionType + +import pytest + +import mlrun +from mlrun import MLClientCtx +from mlrun.package import ContextHandler +from mlrun.runtimes import RunError + + +def test_init(): + """ + During the context handler's initialization, it collects the default packagers found in the class variables + `_MLRUN_REQUIREMENTS_PACKAGERS`, `_EXTENDED_PACKAGERS` and `_MLRUN_FRAMEWORKS_PACKAGERS` so this test is making sure + there is no error raised during the init collection of packagers when new ones are being added. 
+ """ + ContextHandler() + + +def _look_for_context_via_get_or_create(not_a_context=None): + assert not isinstance(not_a_context, MLClientCtx) + context_handler = ContextHandler() + context_handler.look_for_context(args=(), kwargs={}) + return context_handler.is_context_available() + + +def _look_for_context_via_header(context: MLClientCtx): + context_handler = ContextHandler() + context_handler.look_for_context(args=(), kwargs={"context": context}) + return context_handler.is_context_available() + + +@pytest.mark.parametrize( + "func", + [_look_for_context_via_get_or_create, _look_for_context_via_header], +) +@pytest.mark.parametrize("run_via_mlrun", [True, False]) +def test_look_for_context(rundb_mock, func: FunctionType, run_via_mlrun: bool): + """ + Test the `look_for_context` method of the context handler. The method should find or create a context only when it + is being run through MLRun. + + :param rundb_mock: A runDB mock fixture. + :param func: The function to run in the test. + :param run_via_mlrun: Boolean flag to expect to find a context (run via MLRun) as True and to not find a context + as False. + """ + if not run_via_mlrun: + assert not func(None) + return + run = mlrun.new_function().run(handler=func, returns=["result:result"]) + assert run.status.results["result"] + + +def collect_custom_packagers(): + return + + +@pytest.mark.parametrize( + "packager, expected_result", + [ + ("tests.package.test_packagers_manager.PackagerA", True), + ("tests.package.packagers_testers.default_packager_tester.SomeClass", False), + ], +) +@pytest.mark.parametrize("is_mandatory", [True, False]) +def test_custom_packagers( + rundb_mock, packager: str, expected_result: bool, is_mandatory: bool +): + """ + Test the custom packagers collection from a project during the `look_for_context` method. + + :param rundb_mock: A runDB mock fixture. + :param packager: The custom packager to collect. + :param expected_result: Whether the packager collection should succeed. 
+ :param is_mandatory: If the packager is mandatory for the run or not. Mandatory packagers will always raise + exception if they couldn't be collected. + """ + project = mlrun.get_or_create_project(name="default") + project.add_custom_packager( + packager=packager, + is_mandatory=is_mandatory, + ) + project.save_to_db() + mlrun_function = project.set_function( + func=__file__, name="test_custom_packagers", image="mlrun/mlrun" + ) + if expected_result or not is_mandatory: + mlrun_function.run(handler="collect_custom_packagers", local=True) + return + try: + mlrun_function.run(handler="collect_custom_packagers", local=True) + assert False + except RunError: + pass diff --git a/tests/package/test_packagers.py b/tests/package/test_packagers.py new file mode 100644 index 000000000000..221c1cce2599 --- /dev/null +++ b/tests/package/test_packagers.py @@ -0,0 +1,309 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import inspect +import shutil +import tempfile +import typing +from typing import List, Tuple, Type, Union + +import pytest + +import mlrun +from mlrun.package import ArtifactType, LogHintKey, PackagersManager +from mlrun.package.utils import LogHintUtils +from mlrun.runtimes import KubejobRuntime + +from .packager_tester import PackagerTester, PackTest, PackToUnpackTest, UnpackTest +from .packagers_testers.default_packager_tester import DefaultPackagerTester +from .packagers_testers.numpy_packagers_testers import ( + NumPyNDArrayDictPackagerTester, + NumPyNDArrayListPackagerTester, + NumPyNDArrayPackagerTester, + NumPyNumberPackagerTester, +) +from .packagers_testers.pandas_packagers_testers import ( + PandasDataFramePackagerTester, + PandasSeriesPackagerTester, +) +from .packagers_testers.python_standard_library_packagers_testers import ( + BoolPackagerTester, + BytearrayPackagerTester, + BytesPackagerTester, + DictPackagerTester, + FloatPackagerTester, + FrozensetPackagerTester, + IntPackagerTester, + ListPackagerTester, + PathPackagerTester, + SetPackagerTester, + StrPackagerTester, + TuplePackagerTester, +) + +# All the testers to be included in the tests: +_PACKAGERS_TESTERS = [ + DefaultPackagerTester, + BoolPackagerTester, + BytearrayPackagerTester, + BytesPackagerTester, + DictPackagerTester, + FloatPackagerTester, + FrozensetPackagerTester, + IntPackagerTester, + ListPackagerTester, + SetPackagerTester, + StrPackagerTester, + TuplePackagerTester, + PathPackagerTester, + NumPyNDArrayPackagerTester, + NumPyNumberPackagerTester, + NumPyNDArrayDictPackagerTester, + NumPyNDArrayListPackagerTester, + PandasDataFramePackagerTester, + PandasSeriesPackagerTester, +] + + +def _get_tests_tuples( + test_type: Union[Type[PackTest], Type[UnpackTest], Type[PackToUnpackTest]] +) -> List[Tuple[Type[PackagerTester], PackTest]]: + return [ + (tester, test) + for tester in _PACKAGERS_TESTERS + for test in tester.TESTS + if isinstance(test, test_type) + ] + + +def 
_setup_test( + tester: Type[PackagerTester], + test: Union[PackTest, UnpackTest, PackToUnpackTest], + test_directory: str, +) -> KubejobRuntime: + # Enabled logging tuples only if the tuple test is about to be setup: + if isinstance(test, (PackTest, PackToUnpackTest)) and tester is TuplePackagerTester: + mlrun.mlconf.packagers.pack_tuples = True + + # Create a project for this tester: + project = mlrun.get_or_create_project(name="default", context=test_directory) + + # Create a MLRun function using the tester source file (all the functions must be located in it): + return project.set_function( + func=inspect.getfile(tester), + name=tester.__name__.lower(), + kind="job", + image="mlrun/mlrun", + ) + + +def _get_key_and_artifact_type( + tester: Type[PackagerTester], test: Union[PackTest, PackToUnpackTest] +) -> Tuple[str, str]: + # Parse the log hint (in case it is a string): + log_hint = LogHintUtils.parse_log_hint(log_hint=test.log_hint) + + # Extract the key: + key = log_hint[LogHintKey.KEY] + + # Get the artifact type (either from the log hint or from the packager - the default artifact type): + artifact_type = ( + log_hint[LogHintKey.ARTIFACT_TYPE] + if LogHintKey.ARTIFACT_TYPE in log_hint + else tester.PACKAGER_IN_TEST.get_default_packing_artifact_type( + obj=test.default_artifact_type_object + ) + ) + + return key, artifact_type + + +@pytest.mark.parametrize( + "tester, test", + _get_tests_tuples(test_type=PackTest), +) +def test_packager_pack(rundb_mock, tester: Type[PackagerTester], test: PackTest): + """ + Test a packager's packing. + + :param rundb_mock: A runDB mock fixture. + :param tester: The `PackagerTester` class to get the functions to run from. + :param test: The `PackTest` tuple with the test parameters. 
+ """ + # Set up the test, creating a project and a MLRun function: + test_directory = tempfile.TemporaryDirectory() + mlrun_function = _setup_test( + tester=tester, test=test, test_directory=test_directory.name + ) + + # Run the packing handler: + try: + pack_run = mlrun_function.run( + name="pack", + handler=test.pack_handler, + params=test.pack_parameters, + returns=[test.log_hint], + artifact_path=test_directory.name, + local=True, + ) + + # Verify the packaged output: + key, artifact_type = _get_key_and_artifact_type(tester=tester, test=test) + if artifact_type == ArtifactType.RESULT: + assert key in pack_run.status.results + assert test.validation_function( + pack_run.status.results[key], **test.validation_parameters + ) + else: + assert key in pack_run.outputs + assert test.validation_function( + pack_run._artifact(key=key), **test.validation_parameters + ) + except Exception as exception: + # An error was raised, check if the test failed or should have failed: + if test.exception is None: + raise exception + # Make sure the expected exception was raised: + assert test.exception in str(exception) + + # Clear the tests outputs: + test_directory.cleanup() + + +@pytest.mark.parametrize( + "tester, test", + _get_tests_tuples(test_type=UnpackTest), +) +def test_packager_unpack(rundb_mock, tester: Type[PackagerTester], test: UnpackTest): + """ + Test a packager's unpacking. + + :param rundb_mock: A runDB mock fixture. + :param tester: The `PackagerTester` class to get the functions to run from. + :param test: The `UnpackTest` tuple with the test parameters. 
+ """ + # Create the input path to send for unpacking: + input_path, temp_directory = test.prepare_input_function(**test.prepare_parameters) + + # Set up the test, creating a project and a MLRun function: + test_directory = tempfile.TemporaryDirectory() + mlrun_function = _setup_test( + tester=tester, test=test, test_directory=test_directory.name + ) + + # Run the packing handler: + try: + mlrun_function.run( + name="unpack", + handler=test.unpack_handler, + inputs={"obj": input_path}, + params=test.unpack_parameters, + artifact_path=test_directory.name, + local=True, + ) + except Exception as exception: + # An error was raised, check if the test failed or should have failed: + if test.exception is None: + raise exception + # Make sure the expected exception was raised: + assert test.exception in str(exception) + + # Clear the tests outputs: + shutil.rmtree(temp_directory) + test_directory.cleanup() + + +@pytest.mark.parametrize( + "tester, test", + _get_tests_tuples(test_type=PackToUnpackTest), +) +def test_packager_pack_to_unpack( + rundb_mock, tester: Type[PackagerTester], test: PackToUnpackTest +): + """ + Test a packager's packing and unpacking by running two MLRun functions one after the other, one will return the + value the packager should pack and the other should get the data item to make the packager unpack. + + :param rundb_mock: A runDB mock fixture. + :param tester: The `PackagerTester` class to get the functions to run from. + :param test: The `PackToUnpackTest` tuple with the test parameters. 
+ """ + # Set up the test, creating a project and a MLRun function: + test_directory = tempfile.TemporaryDirectory() + mlrun_function = _setup_test( + tester=tester, test=test, test_directory=test_directory.name + ) + + # Run the packing handler: + try: + pack_run = mlrun_function.run( + name="pack", + handler=test.pack_handler, + params=test.pack_parameters, + returns=[test.log_hint], + artifact_path=test_directory.name, + local=True, + ) + + # Verify the outputs are logged (artifact type as "result" will stop the test here as it cannot be unpacked): + key, artifact_type = _get_key_and_artifact_type(tester=tester, test=test) + if artifact_type == ArtifactType.RESULT: + assert key in pack_run.status.results + return + assert key in pack_run.outputs + + # Validate the packager manager notes and packager instructions: + unpackaging_instructions = pack_run._artifact(key=key)["spec"][ + "unpackaging_instructions" + ] + assert ( + unpackaging_instructions["packager_name"] + == tester.PACKAGER_IN_TEST.__name__ + ) + if tester.PACKAGER_IN_TEST.PACKABLE_OBJECT_TYPE is not ...: + # Check the object name noted match the packager handled type (at least subclass of it): + packable_object_type_name = PackagersManager._get_type_name( + typ=tester.PACKAGER_IN_TEST.PACKABLE_OBJECT_TYPE + if tester.PACKAGER_IN_TEST.PACKABLE_OBJECT_TYPE.__module__ != "typing" + else typing.get_origin(tester.PACKAGER_IN_TEST.PACKABLE_OBJECT_TYPE) + ) + assert unpackaging_instructions[ + "object_type" + ] == packable_object_type_name or issubclass( + PackagersManager._get_type_from_name( + type_name=unpackaging_instructions["object_type"] + ), + tester.PACKAGER_IN_TEST.PACKABLE_OBJECT_TYPE, + ) + assert unpackaging_instructions["artifact_type"] == artifact_type + assert unpackaging_instructions["instructions"] == test.expected_instructions + + # Run the unpacking handler: + mlrun_function.run( + name="unpack", + handler=test.unpack_handler, + inputs={"obj": pack_run.outputs[key]}, + 
params=test.unpack_parameters, + artifact_path=test_directory.name, + local=True, + ) + except Exception as exception: + # An error was raised, check if the test failed or should have failed: + if test.exception is None: + raise exception + # Make sure the expected exception was raised: + assert test.exception in str(exception) + + # Clear the tests outputs: + test_directory.cleanup() diff --git a/tests/package/test_packagers_manager.py b/tests/package/test_packagers_manager.py new file mode 100644 index 000000000000..feda78b662da --- /dev/null +++ b/tests/package/test_packagers_manager.py @@ -0,0 +1,452 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import shutil +import tempfile +import zipfile +from typing import Any, Dict, List, Tuple, Type, Union + +import pytest + +from mlrun import DataItem +from mlrun.artifacts import Artifact +from mlrun.errors import MLRunInvalidArgumentError +from mlrun.package import ( + DefaultPackager, + MLRunPackageCollectionError, + MLRunPackageUnpackingError, + Packager, + PackagersManager, +) + + +class PackagerA(Packager): + """ + A simple packager to pack strings as results. 
+ """ + + PACKABLE_OBJECT_TYPE = str + + @classmethod + def get_default_packing_artifact_type(cls, obj: Any) -> str: + return "result" + + @classmethod + def get_default_unpacking_artifact_type(cls, data_item: DataItem) -> str: + return "result" + + @classmethod + def get_supported_artifact_types(cls) -> List[str]: + return ["result"] + + @classmethod + def is_packable(cls, obj: Any, artifact_type: str = None) -> bool: + return type(obj) is cls.PACKABLE_OBJECT_TYPE and artifact_type == "result" + + @classmethod + def pack( + cls, obj: str, artifact_type: str = None, configurations: dict = None + ) -> dict: + return {f"{configurations['key']}_from_PackagerA": obj} + + @classmethod + def unpack( + cls, + data_item: DataItem, + artifact_type: str = None, + instructions: dict = None, + ) -> str: + pass + + +class PackagerB(DefaultPackager): + """ + A default packager for strings. The artifact types "b1" and "b2" will be used to verify the future clear feature. + """ + + PACKABLE_OBJECT_TYPE = str + DEFAULT_PACKING_ARTIFACT_TYPE = "b1" + DEFAULT_UNPACKING_ARTIFACT_TYPE = "b1" + + @classmethod + def pack_result(cls, obj: Any, key: str) -> dict: + return {f"{key}_from_PackagerB": obj} + + @classmethod + def pack_b1( + cls, + obj: str, + key: str, + fmt: str, + ) -> Tuple[Artifact, dict]: + # Create a temp directory: + path = tempfile.mkdtemp() + + # Create a file: + file_path = os.path.join(path, f"{key}.{fmt}") + with open(file_path, "w") as file: + file.write(obj) + + # Note for clearance: + cls.add_future_clearing_path(path=file_path) + + return Artifact(key=key, src_path=file_path), {"temp_dir": path} + + @classmethod + def pack_b2( + cls, + obj: str, + key: str, + amount_of_files: int, + ) -> Tuple[Artifact, dict]: + # Create a temp directory: + path = tempfile.mkdtemp() + + # Create some files in it: + files = [] + for i in range(amount_of_files): + file_path = os.path.join(path, f"{i}.txt") + files.append(file_path) + with open(file_path, "w") as file: + 
file.write(obj) + + # Zip them: + zip_path = os.path.join(path, f"{key}.zip") + with zipfile.ZipFile(zip_path, "w") as zip_file: + for txt_file_path in files: + zip_file.write(txt_file_path) + + # Note for clearance: + cls.add_future_clearing_path(path=path) + + return Artifact(key=key, src_path=zip_path), { + "temp_dir": path, + "amount_of_files": amount_of_files, + } + + @classmethod + def unpack_b1(cls, data_item: DataItem): + pass + + @classmethod + def unpack_b2(cls, data_item: DataItem, length: int): + pass + + +class PackagerC(PackagerA): + """ + Another packager to test collecting an inherited class of `Packager`. In addition, it is used to test the arbitrary + log hint keys. + """ + + PACKABLE_OBJECT_TYPE = float + + @classmethod + def pack( + cls, obj: float, artifact_type: str = None, configurations: dict = None + ) -> dict: + return {configurations["key"]: round(obj, configurations["n_round"])} + + @classmethod + def unpack( + cls, + data_item: DataItem, + artifact_type: str = None, + instructions: dict = None, + ) -> float: + return data_item.key * 2 + + +class NotAPackager: + """ + Simple class to test an exception will be raised when trying to collect it. 
+ """ + + pass + + +@pytest.mark.parametrize( + "packagers_to_collect, validation", + [ + (["tests.package.test_packagers_manager.PackagerA"], [PackagerA]), + ( + [ + "tests.package.test_packagers_manager.PackagerA", + "tests.package.test_packagers_manager.PackagerC", + ], + [PackagerA, PackagerC], + ), + ( + ["tests.package.test_packagers_manager.*"], + [PackagerA, PackagerB, PackagerC], + ), + ( + ["tests.package.module_not_exist.PackagerA"], + "The packager 'PackagerA' could not be collected from the module 'tests.package.module_not_exist'", + ), + ( + ["tests.package.test_packagers_manager.PackagerNotExist"], + "The packager 'PackagerNotExist' could not be collected as it does not exist in the module", + ), + ( + ["tests.package.test_packagers_manager.NotAPackager"], + "The packager 'NotAPackager' could not be collected as it is not a `mlrun.Packager`", + ), + ], +) +def test_collect_packagers( + packagers_to_collect: List[str], validation: Union[List[Type[Packager]], str] +): + """ + Test the manager's `collect_packagers` method. + + :param packagers_to_collect: The packagers to collect. + :param validation: The packager classes that should have been collected. A string means an error should + be raised. 
+ """ + # Prepare the test: + packagers_manager = PackagersManager() + + # Try to collect the packagers: + try: + packagers_manager.collect_packagers(packagers=packagers_to_collect) + except MLRunPackageCollectionError as error: + # Catch only if the validation is a string, otherwise it is a legitimate exception: + if isinstance(validation, str): + # Make sure the correct error was raised: + assert validation in str(error) + return + raise error + + # Validate only the required packagers were collected: + for packager in validation: + assert packager in packagers_manager._packagers + + +@pytest.mark.parametrize( + "packagers_to_collect, result_key_suffix", + [ + ([PackagerA, PackagerB], "_from_PackagerB"), + ([PackagerB, PackagerA], "_from_PackagerA"), + ], +) +@pytest.mark.parametrize("set_via_default_priority", [True, False]) +def test_packagers_priority( + packagers_to_collect: List[Type[Packager]], + result_key_suffix: str, + set_via_default_priority: bool, +): + """ + Test the priority of the collected packagers (last collected will be set with the highest priority). + + :param packagers_to_collect: The packagers to collect + :param result_key_suffix: The suffix the result key should have if it was collected by the right packager. + :param set_via_default_priority: Whether to set the priority via the class or the default priority in collection. + """ + # Reset priorities (when performing multiple runs the class priority is remained set from previous run): + PackagerA.PRIORITY = ... + PackagerB.PRIORITY = ... 
+ + # Collect the packagers: + packagers_manager = PackagersManager() + for packager, priority in zip(packagers_to_collect, [2, 1]): + if not set_via_default_priority: + packager.PRIORITY = priority + packagers_manager.collect_packagers( + packagers=[packager], default_priority=priority + ) + if set_via_default_priority: + assert packager.PRIORITY == priority + + # Pack a string as a result: + key = "some_key" + packagers_manager.pack( + obj="some string", log_hint={"key": key, "artifact_type": "result"} + ) + + # Make sure the correct packager packed the result by the suffix: + assert f"{key}{result_key_suffix}" in packagers_manager.results + + +def test_clear_packagers_outputs(): + """ + Test the manager's `clear_packagers_outputs` method. + """ + # Prepare the test: + packagers_manager = PackagersManager() + packagers_manager.collect_packagers(packagers=[PackagerB]) + + # Pack objects that will create temporary files and directories: + packagers_manager.pack( + obj="I'm a test.", + log_hint={"key": "a", "artifact_type": "b1", "fmt": "txt"}, + ) + packagers_manager.pack( + obj="I'm another test.", + log_hint={ + "key": "b", + "artifact_type": "b2", + "amount_of_files": 3, + }, + ) + + # Get the created files: + a_temp_dir = packagers_manager.artifacts[0].spec.unpackaging_instructions[ + "instructions" + ]["temp_dir"] + a_file = os.path.join(a_temp_dir, "a.txt") + b_temp_dir = packagers_manager.artifacts[1].spec.unpackaging_instructions[ + "instructions" + ]["temp_dir"] + + # Assert they do exist before clearing up: + assert os.path.exists(a_file) + assert os.path.exists(b_temp_dir) + + # Clear: + packagers_manager.clear_packagers_outputs() + + # Assert the clearance: + assert not os.path.exists(a_file) + assert not os.path.exists(b_temp_dir) + + # Remove remained directory (we tested the clearance of a file and a directory, so we need to delete the directory + # of the cleared file (it's directory was not marked as future clear)): + shutil.rmtree(a_temp_dir) + + 
+@pytest.mark.parametrize( + "key, obj, expected_results", + [ + ( + "*list_", + [0.12111, 0.56111], + {"list_0": 0.12, "list_1": 0.56}, + ), + ( + "*set_", + {0.12111, 0.56111}, + {"set_0": 0.12, "set_1": 0.56}, + ), + ( + "*", + (0.12111, 0.56111), + {"0": 0.12, "1": 0.56}, + ), + ( + "*error", + 0.12111, + "The log hint key '*error' has an iterable unpacking prefix ('*')", + ), + ( + "**dict_", + {"a": 0.12111, "b": 0.56111}, + {"dict_a": 0.12, "dict_b": 0.56}, + ), + ("**", {"a": 0.12111, "b": 0.56111}, {"a": 0.12, "b": 0.56}), + ( + "**error", + 0.12111, + "The log hint key '**error' has a dictionary unpacking prefix ('**')", + ), + ], +) +def test_arbitrary_log_hint( + key: str, + obj: Union[list, dict, tuple, set], + expected_results: Union[Dict[str, float], str], +): + """ + Test the arbitrary log hint key prefixes "*" and "**". + + :param key: The key to use in the log hint + :param obj: The object to pack + :param expected_results: The expected results that should be packed. A string means an error should be raised. 
+ """ + # Prepare the test: + packagers_manager = PackagersManager() + packagers_manager.collect_packagers(packagers=[PackagerC]) + + # Pack an arbitrary amount of objects: + try: + packagers_manager.pack( + obj=obj, log_hint={"key": key, "artifact_type": "result", "n_round": 2} + ) + except MLRunInvalidArgumentError as error: + # Catch only if the expected results is a string, otherwise it is a legitimate exception: + if isinstance(expected_results, str): + assert expected_results in str(error) + return + raise error + + # Validate multiple packages were packed: + assert packagers_manager.results == expected_results + + +class _DummyDataItem: + def __init__(self, key: str, is_artifact: bool = False): + self.key = key + self.artifact_url = "" + self._is_artifact = is_artifact + + def get_artifact_type(self) -> bool: + return self._is_artifact + + +@pytest.mark.parametrize( + "data, type_hint, expected_results", + [ + ( + 0.5, + Union[int, bytes, float, int], + 1.0, + ), + ( + 0.5, + Union[int, bytes, int], + "Could not unpack data item with the hinted type", + ), + ], +) +def test_plural_type_hint_unpacking( + data: Any, + type_hint: Any, + expected_results: Union[Any, str], +): + """ + Test unpacking when plural type hint is given (for example: a union of types). + + :param data: The data of the data item to unpack. + :param type_hint: The plural type hint of ths data item. + :param expected_results: The expected results that should be unpacked. A string means an error should be raised. 
+ """ + # Prepare the test: + packagers_manager = PackagersManager() + packagers_manager.collect_packagers(packagers=[PackagerC]) + + # Pack an arbitrary amount of objects: + try: + value = packagers_manager.unpack( + data_item=_DummyDataItem(key=data), type_hint=type_hint + ) + except MLRunPackageUnpackingError as error: + # Catch only if the expected results is a string, otherwise it is a legitimate exception: + if isinstance(expected_results, str): + assert expected_results in str(error) + return + raise error + + # Validate multiple packages were packed: + assert value == expected_results diff --git a/tests/package/test_usage.py b/tests/package/test_usage.py new file mode 100644 index 000000000000..a044d54bf92d --- /dev/null +++ b/tests/package/test_usage.py @@ -0,0 +1,266 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os +import tempfile +from typing import Tuple, Union + +import numpy as np +import pandas as pd +import pytest +from sklearn.impute import SimpleImputer +from sklearn.pipeline import Pipeline +from sklearn.preprocessing import OrdinalEncoder + +import mlrun + +RETURNS_LOG_HINTS = [ + "my_array", + "my_df", + "my_file: path", + {"key": "my_dict", "artifact_type": "object"}, + "my_list: file", + "my_int", + "my_str : result", + "my_object: object", +] + + +def log_artifacts_and_results() -> Tuple[ + np.ndarray, pd.DataFrame, str, dict, list, int, str, Pipeline +]: + encoder_to_imputer = Pipeline( + steps=[ + ( + "imputer", + SimpleImputer(missing_values="", strategy="constant", fill_value="C"), + ), + ("encoder", OrdinalEncoder()), + ] + ) + encoder_to_imputer.fit([["A"], ["B"], ["C"]]) + + context = mlrun.get_or_create_ctx(name="ctx") + context.log_result(key="manually_logged_result", value=10) + + file_path = os.path.join(context.artifact_path, "my_file.txt") + with open(file_path, "w") as file: + file.write("123") + + return ( + np.ones((10, 20)), + pd.DataFrame(np.zeros((20, 10))), + file_path, + {"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}, + [["A"], ["B"], [""]], + 3, + "hello", + encoder_to_imputer, + ) + + +def _assert_parsing( + my_array: np.ndarray, + my_df: mlrun.DataItem, + my_file: Union[int, mlrun.DataItem], + my_dict: dict, + my_list: list, + my_object: Pipeline, + my_int: int, + my_str: str, +): + assert isinstance(my_array, np.ndarray) + assert np.all(my_array == np.ones((10, 20))) + + assert isinstance(my_df, mlrun.DataItem) + my_df = my_df.as_df() + assert my_df.shape == (20, 10) + assert my_df.sum().sum() == 0 + + assert isinstance(my_file, mlrun.DataItem) + my_file = my_file.local() + with open(my_file, "r") as file: + file_content = file.read() + assert file_content == "123" + + assert isinstance(my_dict, dict) + assert my_dict == {"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]} + + assert isinstance(my_list, list) + assert my_list == [["A"], ["B"], 
[""]] + + assert isinstance(my_object, Pipeline) + assert my_object.transform(my_list).tolist() == [[0], [1], [2]] + + return [my_str] * my_int + + +def parse_inputs_from_type_annotations( + my_array: np.ndarray, + my_df: mlrun.DataItem, + my_file: Union[int, mlrun.DataItem], + my_dict: dict, + my_list: list, + my_object: Pipeline, + my_int: int, + my_str: str, +): + _assert_parsing( + my_array=my_array, + my_df=my_df, + my_file=my_file, + my_dict=my_dict, + my_list=my_list, + my_object=my_object, + my_int=my_int, + my_str=my_str, + ) + + +def parse_inputs_from_mlrun_function( + my_array, my_df, my_file, my_dict, my_list, my_object, my_int, my_str +): + _assert_parsing( + my_array=my_array, + my_df=my_df, + my_file=my_file, + my_dict=my_dict, + my_list=my_list, + my_object=my_object, + my_int=my_int, + my_str=my_str, + ) + + +@pytest.mark.parametrize("is_enabled", [True, False]) +@pytest.mark.parametrize("returns", [RETURNS_LOG_HINTS, []]) +def test_mlconf_packagers_enabled(rundb_mock, is_enabled: bool, returns: list): + """ + Test the packagers logging given the returns parameter in the `run` method and MLRun's `mlconf.packagers.enabled` + configuration. + + :param rundb_mock: A runDB mock fixture. + :param is_enabled: The `mlconf.packagers.enabled` configuration value. + :param returns: Log hints to pass in the 'returns' parameter. 
+ """ + # Set the configuration: + mlrun.mlconf.packagers.enabled = is_enabled + + # Create the function: + mlrun_function = mlrun.code_to_function(filename=__file__, kind="job") + artifact_path = tempfile.TemporaryDirectory() + + # Run the logging function: + log_artifacts_and_results_run = mlrun_function.run( + handler="log_artifacts_and_results", + returns=returns, + artifact_path=artifact_path.name, + local=True, + ) + + # There should always be at least one output - the manually logged result: + if is_enabled and returns: + # Plus all configured returning values: + assert len(log_artifacts_and_results_run.outputs) == 1 + len(RETURNS_LOG_HINTS) + else: + # Plus the default logged output as string MLRun did before packagers and log hints: + assert len(log_artifacts_and_results_run.outputs) == 1 + 1 + + +def test_parse_inputs_from_type_annotations(rundb_mock): + """ + Run the `parse_inputs_from_type_annotations` function with MLRun to see the packagers are parsing the given inputs + (`DataItem`s) to the written type hints. + + :param rundb_mock: A runDB mock fixture. 
+ """ + # Create the function: + mlrun_function = mlrun.code_to_function(filename=__file__, kind="job") + artifact_path = tempfile.TemporaryDirectory() + + # Run the logging functions: + log_artifacts_and_results_run = mlrun_function.run( + handler="log_artifacts_and_results", + returns=RETURNS_LOG_HINTS, + artifact_path=artifact_path.name, + local=True, + ) + + # Run the function that will parse the data items: + mlrun_function.run( + handler="parse_inputs_from_type_annotations", + inputs={ + "my_list": log_artifacts_and_results_run.outputs["my_list"], + "my_array": log_artifacts_and_results_run.outputs["my_array"], + "my_df": log_artifacts_and_results_run.outputs["my_df"], + "my_file": log_artifacts_and_results_run.outputs["my_file"], + "my_object": log_artifacts_and_results_run.outputs["my_object"], + "my_dict": log_artifacts_and_results_run.outputs["my_dict"], + }, + params={ + "my_int": log_artifacts_and_results_run.outputs["my_int"], + "my_str": log_artifacts_and_results_run.outputs["my_str"], + }, + artifact_path=artifact_path.name, + local=True, + ) + + # Clean the test outputs: + artifact_path.cleanup() + + +def test_parse_inputs_from_mlrun_function(rundb_mock): + """ + Run the `parse_inputs_from_mlrun_function` function with MLRun to see the packagers are parsing the given inputs + (`DataItem`s) to the provided configuration in the `run` method. + + :param rundb_mock: A runDB mock fixture. 
+ """ + # Create the function: + mlrun_function = mlrun.code_to_function(filename=__file__, kind="job") + artifact_path = tempfile.TemporaryDirectory() + + # Run the logging functions: + log_artifacts_and_results_run = mlrun_function.run( + handler="log_artifacts_and_results", + returns=RETURNS_LOG_HINTS, + artifact_path=artifact_path.name, + local=True, + ) + + # Run the function that will parse the data items: + mlrun_function.run( + handler="parse_inputs_from_mlrun_function", + inputs={ + "my_list:list": log_artifacts_and_results_run.outputs["my_list"], + "my_array : numpy.ndarray": log_artifacts_and_results_run.outputs[ + "my_array" + ], + "my_df": log_artifacts_and_results_run.outputs["my_df"], + "my_file": log_artifacts_and_results_run.outputs["my_file"], + "my_object: sklearn.pipeline.Pipeline": log_artifacts_and_results_run.outputs[ + "my_object" + ], + "my_dict: dict": log_artifacts_and_results_run.outputs["my_dict"], + }, + params={ + "my_int": log_artifacts_and_results_run.outputs["my_int"], + "my_str": log_artifacts_and_results_run.outputs["my_str"], + }, + artifact_path=artifact_path.name, + local=True, + ) + + # Clean the test outputs: + artifact_path.cleanup() diff --git a/tests/package/utils/__init__.py b/tests/package/utils/__init__.py new file mode 100644 index 000000000000..4f418a506ca1 --- /dev/null +++ b/tests/package/utils/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +# flake8: noqa - this is until we take care of the F401 violations with respect to __all__ & sphinx diff --git a/tests/package/utils/test_archiver.py b/tests/package/utils/test_archiver.py new file mode 100644 index 000000000000..645977dc7aa7 --- /dev/null +++ b/tests/package/utils/test_archiver.py @@ -0,0 +1,111 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import tempfile +from pathlib import Path +from typing import List + +import numpy as np +import pytest + +from mlrun.package.utils import ArchiveSupportedFormat + + +@pytest.mark.parametrize( + "archive_format", + ArchiveSupportedFormat.get_all_formats(), +) +@pytest.mark.parametrize( + "directory_layout", + [ + ["my_file.bin"], + ["empty_dir"], + ["a.bin", "b.bin"], + ["inner_dir", os.path.join("inner_dir", "my_file.bin")], + [ + "a.bin", + "b.bin", + "inner_dir", + os.path.join("inner_dir", "my_file.bin"), + os.path.join("inner_dir", "empty_dir"), + "empty_dir", + ], + ], +) +def test_archiver(archive_format: str, directory_layout: List[str]): + """ + Test the archivers for creating archives of multiple layouts and extracting them while keeping their original + layout, names and data. + + :param archive_format: The archive format to use. + :param directory_layout: The layout to archive. 
+ """ + # Create a temporary directory for the test outputs: + test_directory = tempfile.TemporaryDirectory() + + # Generate random array for the content of the files: + files_content: np.ndarray = np.random.random(size=100) + + # Set up the main directory to archive and the output path for the archive file: + directory_name = "my_dir" + directory_path = Path(test_directory.name) / directory_name + output_path = Path(test_directory.name) / "output_path" + os.makedirs(directory_path) + os.makedirs(output_path) + + # Create the files according to the layout provided: + for path in directory_layout: + full_path = directory_path / path + if "." in path: + files_content.tofile(full_path) + assert full_path.is_file() + else: + os.makedirs(full_path) + assert full_path.is_dir() + assert full_path.exists() + assert len(list(directory_path.rglob("*"))) == len(directory_layout) + + # Archive the files: + archiver = ArchiveSupportedFormat.get_format_handler(fmt=archive_format) + archive_path = Path( + archiver.create_archive( + directory_path=str(directory_path), output_path=str(output_path) + ) + ) + assert archive_path.exists() + assert archive_path == output_path / f"{directory_name}.{archive_format}" + + # Extract the files: + extracted_dir_path = Path( + archiver.extract_archive( + archive_path=str(archive_path), output_path=str(output_path) + ) + ) + assert extracted_dir_path.exists() + assert extracted_dir_path == output_path / directory_name + + # Validate all files were extracted as they originally were: + for path in directory_layout: + full_path = extracted_dir_path / path + assert full_path.exists() + if "." 
in path: + assert full_path.is_file() + np.testing.assert_equal(np.fromfile(file=full_path), files_content) + else: + assert full_path.is_dir() + assert len(list(extracted_dir_path.rglob("*"))) == len(directory_layout) + + # Clean the test outputs: + test_directory.cleanup() diff --git a/tests/package/utils/test_formatter.py b/tests/package/utils/test_formatter.py new file mode 100644 index 000000000000..ef2deba6100b --- /dev/null +++ b/tests/package/utils/test_formatter.py @@ -0,0 +1,60 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import tempfile +from pathlib import Path +from typing import Union + +import pytest + +from mlrun.package.utils import StructFileSupportedFormat + + +@pytest.mark.parametrize( + "obj", + [ + {"a": 1, "b": 2}, + [1, 2, 3], + [{"a": [1, 2, 3], "b": [1, 2, 3]}, {"c": [4, 5, 6]}, [1, 2, 3, 4, 5, 6]], + ], +) +@pytest.mark.parametrize( + "file_format", + StructFileSupportedFormat.get_all_formats(), +) +def test_formatter(obj: Union[list, dict], file_format: str): + """ + Test the formatters for writing and reading python objects. + + :param obj: The object to write. + :param file_format: The struct file format to use. 
+ """ + # Create a temporary directory for the test outputs: + test_directory = tempfile.TemporaryDirectory() + + # Set up the main directory to archive and the output path for the archive file: + file_path = Path(test_directory.name) / f"my_struct.{file_format}" + assert not file_path.exists() + + # Archive the files: + formatter = StructFileSupportedFormat.get_format_handler(fmt=file_format) + formatter.write(obj=obj, file_path=str(file_path)) + assert file_path.exists() + + # Extract the files: + read_object = formatter.read(file_path=str(file_path)) + assert read_object == obj + + # Clean the test outputs: + test_directory.cleanup() diff --git a/tests/package/utils/test_log_hint_utils.py b/tests/package/utils/test_log_hint_utils.py new file mode 100644 index 000000000000..d40e7bbf071a --- /dev/null +++ b/tests/package/utils/test_log_hint_utils.py @@ -0,0 +1,79 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from typing import Union + +import pytest + +from mlrun.errors import MLRunInvalidArgumentError +from mlrun.package.utils.log_hint_utils import LogHintKey, LogHintUtils + + +@pytest.mark.parametrize( + "log_hint, expected_log_hint", + [ + ("some_key", {LogHintKey.KEY: "some_key"}), + ( + "some_key:artifact", + {LogHintKey.KEY: "some_key", LogHintKey.ARTIFACT_TYPE: "artifact"}, + ), + ( + "some_key :artifact", + {LogHintKey.KEY: "some_key", LogHintKey.ARTIFACT_TYPE: "artifact"}, + ), + ( + "some_key: artifact", + {LogHintKey.KEY: "some_key", LogHintKey.ARTIFACT_TYPE: "artifact"}, + ), + ( + "some_key : artifact", + {LogHintKey.KEY: "some_key", LogHintKey.ARTIFACT_TYPE: "artifact"}, + ), + ( + "some_key:", + "Incorrect log hint pattern. The ':' in a log hint should specify", + ), + ( + "some_key : artifact : error", + "Incorrect log hint pattern. Log hints can have only a single ':' in them", + ), + ({LogHintKey.KEY: "some_key"}, {LogHintKey.KEY: "some_key"}), + ( + {LogHintKey.KEY: "some_key", LogHintKey.ARTIFACT_TYPE: "artifact"}, + {LogHintKey.KEY: "some_key", LogHintKey.ARTIFACT_TYPE: "artifact"}, + ), + ( + {LogHintKey.ARTIFACT_TYPE: "artifact"}, + "A log hint dictionary must include the 'key'", + ), + ], +) +def test_parse_log_hint( + log_hint: Union[str, dict], expected_log_hint: Union[str, dict] +): + """ + Test the `LogHintUtils.parse_log_hint` function with multiple types. + + :param log_hint: The log hint to parse. + :param expected_log_hint: The expected parsed log hint dictionary. A string value indicates the parsing should fail + with the provided error message in the variable. 
+ """ + try: + parsed_log_hint = LogHintUtils.parse_log_hint(log_hint=log_hint) + assert parsed_log_hint == expected_log_hint + except MLRunInvalidArgumentError as error: + if isinstance(expected_log_hint, str): + assert expected_log_hint in str(error) + else: + raise error diff --git a/tests/package/utils/test_pickler.py b/tests/package/utils/test_pickler.py new file mode 100644 index 000000000000..80b969c762d5 --- /dev/null +++ b/tests/package/utils/test_pickler.py @@ -0,0 +1,87 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import tempfile +from pathlib import Path +from typing import Union + +import cloudpickle +import numpy as np +import pytest + +from mlrun.errors import MLRunInvalidArgumentError +from mlrun.package.utils import Pickler + + +@pytest.mark.parametrize( + "pickle_module_name, expected_notes", + [ + ( + "pickle", + { + "object_module_name": "numpy", + "pickle_module_name": "pickle", + "python_version": Pickler._get_python_version(), + "object_module_version": np.__version__, + }, + ), + ( + "cloudpickle", + { + "object_module_name": "numpy", + "pickle_module_name": "cloudpickle", + "python_version": Pickler._get_python_version(), + "object_module_version": np.__version__, + "pickle_module_version": cloudpickle.__version__, + }, + ), + ("numpy", "A pickle module is expected to have a"), + ], +) +def test_pickler(pickle_module_name: str, expected_notes: Union[dict, str]): + """ + Test the `Pickler` with multiple pickling modules. + + :param pickle_module_name: The pickle module name to use. + :param expected_notes: The expected pickling notes. A string value indicates the `Pickler` should fail with the + provided error message in the variable. 
+ """ + # Create the test temporary directory: + test_directory = tempfile.TemporaryDirectory() + + # Prepare the pickle path and the object to pickle: + output_path = Path(test_directory.name) / "my_array.pkl" + array = np.random.random(size=100) + + # Pickle: + try: + _, notes = Pickler.pickle( + obj=array, + pickle_module_name=pickle_module_name, + output_path=str(output_path), + ) + except MLRunInvalidArgumentError as error: + if isinstance(expected_notes, str): + assert expected_notes in str(error) + return + raise error + assert output_path.exists() + assert notes == expected_notes + + # Unpickle: + pickled_array = Pickler.unpickle(pickle_path=str(output_path), **notes) + np.testing.assert_equal(pickled_array, array) + + # Delete the test directory (with the pickle file that was created): + test_directory.cleanup() diff --git a/tests/package/utils/test_type_hint_utils.py b/tests/package/utils/test_type_hint_utils.py new file mode 100644 index 000000000000..6fbaa50b22e9 --- /dev/null +++ b/tests/package/utils/test_type_hint_utils.py @@ -0,0 +1,240 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import collections +import typing + +import pytest + +from mlrun.errors import MLRunInvalidArgumentError +from mlrun.package.utils.type_hint_utils import TypeHintUtils + + +class SomeClass: + """ + To add a custom type for the type hinting test. 
+ """ + + pass + + +class AnotherClass(SomeClass): + """ + To add a custom inheriting class for match test. + """ + + pass + + +@pytest.mark.parametrize( + "type_hint, expected_result", + [ + (typing.Optional[int], True), + (typing.Union[str, int], True), + (typing.List, True), + (typing.Tuple[int, str], True), + (typing.TypeVar("A", int, str), True), + (typing.ForwardRef("pandas.DataFrame"), True), + (list, False), + (int, False), + (SomeClass, False), + # TODO: Uncomment once we support Python >= 3.9: + # (list[int], True), + # (tuple[int, str], True), + # TODO: Uncomment once we support Python >= 3.10: + # (str | int, True), + ], +) +def test_is_typing_type(type_hint: typing.Type, expected_result: bool): + """ + Test the `TypeHintUtils.is_typing_type` function with multiple types. + + :param type_hint: The type to check. + :param expected_result: The expected result. + """ + assert TypeHintUtils.is_typing_type(type_hint=type_hint) == expected_result + + +@pytest.mark.parametrize( + "type_string, expected_type", + [ + ("int", int), + ("list", list), + ("tests.package.utils.test_type_hint_utils.SomeClass", SomeClass), + ( + "fail", + "MLRun tried to get the type hint 'fail' but it can't as it is not a valid builtin Python type (one of " + "`list`, `dict`, `str`, `int`, etc.) nor a locally declared type (from the `__main__` module).", + ), + ( + "tests.package.utils.test_type_hint_utils.Fail", + "MLRun tried to get the type hint 'Fail' from the module 'tests.package.utils.test_type_hint_utils' but it " + "seems it doesn't exist.", + ), + ( + "module_not_exist.Fail", + "MLRun tried to get the type hint 'Fail' but the module 'module_not_exist' cannot be imported.", + ), + ], +) +def test_parse_type_hint(type_string: str, expected_type: typing.Union[str, type]): + """ + Test the `TypeHintUtils.parse_type_hint` function with multiple types. + + :param type_string: The type to parse and + :param expected_type: The expected parsed type. 
A string value indicates the parsing should fail with the provided + error message in the variable. + """ + try: + parsed_type = TypeHintUtils.parse_type_hint(type_hint=type_string) + assert parsed_type is expected_type + except MLRunInvalidArgumentError as error: + if isinstance(expected_type, str): + assert expected_type in str(error) + else: + raise error + + +@pytest.mark.parametrize( + "object_type, type_hint, include_subclasses, reduce_type_hint, result", + [ + (int, int, True, False, True), + (int, str, True, True, False), + (typing.Union[int, str], typing.Union[str, int], True, True, True), + (typing.Union[int, str, bool], typing.Union[str, int], True, False, False), + (int, typing.Union[int, str], True, False, False), + (int, typing.Union[int, str], True, True, True), + (AnotherClass, SomeClass, True, False, True), + (AnotherClass, SomeClass, False, False, False), + (SomeClass, AnotherClass, True, False, False), + (AnotherClass, {SomeClass, int, str}, True, False, True), + (AnotherClass, {SomeClass, int, str}, False, False, False), + (SomeClass, {AnotherClass, int, str}, True, False, False), + ], +) +def test_is_matching( + object_type: type, + type_hint: type, + include_subclasses: bool, + reduce_type_hint: bool, + result: bool, +): + """ + Test the `TypeHintUtils.is_matching` function with multiple types. + + :param object_type: The type to match. + :param type_hint: The options to match to (the type hint of an object). + :param include_subclasses: Whether subclasses considered a match. + :param reduce_type_hint: Whether to reduce the type hint to match with its reduced hints. + :param result: Expected test result. 
+ """ + assert ( + TypeHintUtils.is_matching( + object_type=object_type, + type_hint=type_hint, + include_subclasses=include_subclasses, + reduce_type_hint=reduce_type_hint, + ) + == result + ) + + +@pytest.mark.parametrize( + "type_hint, expected_result", + [ + # `typing.TypeVar` usages: + (typing.TypeVar("A", int, str, typing.List[int]), {int, str, typing.List[int]}), + (typing.TypeVar("A"), set()), + (typing.TypeVar, set()), + # `typing.ForwardRef` usage: + (typing.ForwardRef("SomeClass"), set()), + ( + typing.ForwardRef( + "SomeClass", module="tests.package.utils.test_type_hint_utils" + ), + {SomeClass}, + ), + ( + typing.ForwardRef("tests.package.utils.test_type_hint_utils.SomeClass"), + {SomeClass}, + ), + (typing.ForwardRef, set()), + # `typing.Callable` usages: + (typing.Callable, {collections.abc.Callable}), + ( + typing.Callable[[int, int], typing.Tuple[str, str]], + {collections.abc.Callable}, + ), + (collections.abc.Callable, set()), + # `typing.Literal` usages: + (typing.Literal["r", "w", 9], {str, int}), + (typing.Literal, set()), + # `typing.Union` usages: + (typing.Union[int, float], {int, float}), + ( + typing.Union[int, float, typing.Union[str, list]], + {int, float, str, list}, + ), + ( + typing.Union[int, str, typing.List[typing.Tuple[int, str, SomeClass]]], + {int, str, typing.List[typing.Tuple[int, str, SomeClass]]}, + ), + (typing.Union, set()), + # `typing.Optional` usages: + (typing.Optional[int], {type(None), int}), + (typing.Optional[typing.Union[str, list]], {type(None), str, list}), + (typing.Optional, set()), + # `typing.Annotated` usages: + (typing.Annotated[int, 3, 6], {int}), + (typing.Annotated, set()), + # `typing.Final` usages: + ( + typing.Final[typing.List[typing.Tuple[int, str, SomeClass]]], + {typing.List[typing.Tuple[int, str, SomeClass]]}, + ), + (typing.Final, set()), + # `typing.ClassVar` usages: + ( + typing.ClassVar[ + typing.Union[int, str, typing.List[typing.Tuple[int, str, SomeClass]]] + ], + {typing.Union[int, str, 
typing.List[typing.Tuple[int, str, SomeClass]]]}, + ), + (typing.ClassVar, set()), + # Other `typing`: + (typing.List, {list}), + (typing.List[typing.Tuple[int, str, SomeClass]], {list}), + (typing.Tuple[int, str, SomeClass], {tuple}), + # `collections` types: + (typing.OrderedDict[str, int], {collections.OrderedDict}), + (typing.OrderedDict, {collections.OrderedDict}), + (collections.OrderedDict, set()), + # Multiple types to reduce: + ({int, str, typing.List[int]}, {list}), + # TODO: Uncomment once we support Python >= 3.9: + # (list[str], {list}), + # TODO: Uncomment once we support Python >= 3.10: + # (str | int, {str, int}), + ], +) +def test_reduce_type_hint( + type_hint: typing.Type, expected_result: typing.Set[typing.Type] +): + """ + Test the `TypeHintUtils.reduce_type_hint` function with multiple type hints. + + :param type_hint: The type hint to reduce. + :param expected_result: The expected result. + """ + assert TypeHintUtils.reduce_type_hint(type_hint=type_hint) == expected_result diff --git a/tests/projects/assets/proj-setup.zip b/tests/projects/assets/proj-setup.zip new file mode 100644 index 000000000000..7f49f53949ae Binary files /dev/null and b/tests/projects/assets/proj-setup.zip differ diff --git a/tests/projects/test_local_pipeline.py b/tests/projects/test_local_pipeline.py index 669cb6a0675f..3945145768c4 100644 --- a/tests/projects/test_local_pipeline.py +++ b/tests/projects/test_local_pipeline.py @@ -34,7 +34,7 @@ def _set_functions(self): # kind="job" ) - def test_set_artifact(self): + def test_set_artifact(self, rundb_mock): self.project = mlrun.new_project("test-sa", save=False) self.project.set_artifact( "data1", mlrun.artifacts.Artifact(target_path=self.data_url) @@ -51,7 +51,7 @@ def test_set_artifact(self): artifacts = self.project.list_artifacts(tag="x") assert len(artifacts) == 1 - def test_import_artifacts(self): + def test_import_artifacts(self, rundb_mock): results_path = str(pathlib.Path(tests.conftest.results) / "project") 
project = mlrun.new_project( "test-sa2", context=str(self.assets_path), save=False @@ -206,3 +206,30 @@ def test_run_pipeline_artifact_path(self): mlrun.projects.pipeline_context._artifact_path == f"{generic_path}/{run_status.run_id}" ) + + def test_run_pipeline_with_ttl(self): + mlrun.projects.pipeline_context.clear(with_project=True) + self._create_project("localpipettl") + self._set_functions() + workflow_path = str(f"{self.assets_path / self.pipeline_path}") + cleanup_ttl = 1234 + run = self.project.run( + "p4", + workflow_path=workflow_path, + workflow_handler="my_pipe", + arguments={"param1": 7}, + local=True, + cleanup_ttl=cleanup_ttl, + ) + assert run.workflow.cleanup_ttl == cleanup_ttl + + self.project.set_workflow("my-workflow", workflow_path=workflow_path) + + run = self.project.run( + "my-workflow", + workflow_handler="my_pipe", + arguments={"param1": 7}, + local=True, + cleanup_ttl=cleanup_ttl, + ) + assert run.workflow.cleanup_ttl == cleanup_ttl diff --git a/tests/projects/test_project.py b/tests/projects/test_project.py index af903fde1041..7409e570e70b 100644 --- a/tests/projects/test_project.py +++ b/tests/projects/test_project.py @@ -13,10 +13,12 @@ # limitations under the License. 
# import os +import os.path import pathlib import shutil import tempfile import unittest.mock +import warnings import zipfile from contextlib import nullcontext as does_not_raise @@ -45,7 +47,7 @@ def assets_path(): return pathlib.Path(__file__).absolute().parent / "assets" -def test_sync_functions(): +def test_sync_functions(rundb_mock): project_name = "project-name" project = mlrun.new_project(project_name, save=False) project.set_function("hub://describe", "describe") @@ -61,7 +63,7 @@ def test_sync_functions(): assert fn.metadata.name == "describe", "func did not return" # test that functions can be fetched from the DB (w/o set_function) - mlrun.import_function("hub://auto_trainer", new_name="train").save() + mlrun.import_function("hub://auto-trainer", new_name="train").save() fn = project.get_function("train") assert fn.metadata.name == "train", "train func did not return" @@ -268,6 +270,18 @@ def test_build_project_from_minimal_dict(): False, "", ), + ( + "ssh://git@something/something", + "something", + [], + False, + 0, + False, + "", + True, + "Unsupported url scheme, supported schemes are: git://, db:// or " + ".zip/.tar.gz/.yaml file path (could be local or remote) or project name which will be loaded from DB", + ), ], ) def test_load_project( @@ -324,6 +338,21 @@ def test_load_project( assert os.path.exists(os.path.join(context, project_file)) +def test_load_project_with_setup(context): + url = ( + pathlib.Path(tests.conftest.tests_root_directory) + / "projects" + / "assets" + / "proj-setup.zip" + ) + project = mlrun.load_project(context=context, url=url) + assert project.name == "projset" + assert project.spec.context == context + assert project.spec.source == str(url) + assert project.spec.params == {"label_column": "label", "test123": "456"} + print(project.to_yaml()) + + @pytest.mark.parametrize( "sync,expected_num_of_funcs, save", [ @@ -358,12 +387,11 @@ def test_load_project_and_sync_functions( assert len(function_names) == expected_num_of_funcs 
for func in function_names: fn = project.get_function(func) - assert fn.metadata.name == mlrun.utils.helpers.normalize_name( - func - ), "func did not return" + normalized_name = mlrun.utils.helpers.normalize_name(func) + assert fn.metadata.name == normalized_name, "func did not return" - if save: - assert rundb_mock._function is not None + if save: + assert normalized_name in rundb_mock._functions def _assert_project_function_objects(project, expected_function_objects): @@ -382,7 +410,7 @@ def _assert_project_function_objects(project, expected_function_objects): ) -def test_set_func_requirements(): +def test_set_function_requirements(): project = mlrun.projects.project.MlrunProject.from_dict( { "metadata": { @@ -394,27 +422,26 @@ def test_set_func_requirements(): } ) project.set_function("hub://describe", "desc1", requirements=["x"]) - assert project.get_function("desc1", enrich=True).spec.build.commands == [ - "python -m pip install x", - "python -m pip install 'pandas>1, <3'", + assert project.get_function("desc1", enrich=True).spec.build.requirements == [ + "x", + "pandas>1, <3", ] fn = mlrun.import_function("hub://describe") project.set_function(fn, "desc2", requirements=["y"]) - assert project.get_function("desc2", enrich=True).spec.build.commands == [ - "python -m pip install y", - "python -m pip install 'pandas>1, <3'", + assert project.get_function("desc2", enrich=True).spec.build.requirements == [ + "y", + "pandas>1, <3", ] -def test_set_function_underscore_name(): +def test_backwards_compatibility_get_non_normalized_function_name(rundb_mock): project = mlrun.projects.MlrunProject( "project", default_requirements=["pandas>1, <3"] ) func_name = "name_with_underscores" - - # Create a function with a name that includes underscores func_path = str(pathlib.Path(__file__).parent / "assets" / "handler.py") + func = mlrun.code_to_function( name=func_name, kind="job", @@ -422,14 +449,57 @@ def test_set_function_underscore_name(): handler="myhandler", 
filename=func_path, ) + # nuclio also normalizes the name, so we de-normalize the function name before storing it + func.metadata.name = func_name + + # mock the normalize function response in order to insert a non-normalized function name to the db + with unittest.mock.patch("mlrun.utils.normalize_name", return_value=func_name): + project.set_function(name=func_name, func=func) + + # getting the function using the original non-normalized name, and ensure that querying it works + enriched_function = project.get_function(key=func_name) + assert enriched_function.metadata.name == func_name + + enriched_function = project.get_function(key=func_name, sync=True) + assert enriched_function.metadata.name == func_name + + # override the function by sending an update request, + # a new function is created, and the old one is no longer accessible + normalized_function_name = mlrun.utils.normalize_name(func_name) + func.metadata.name = normalized_function_name project.set_function(name=func_name, func=func) - # Attempt to get the function using the original name (with underscores) and ensure that it fails - with pytest.raises(mlrun.errors.MLRunNotFoundError): - project.get_function(key=func_name) + # using both normalized and non-normalized names to query the function + enriched_function = project.get_function(key=normalized_function_name) + assert enriched_function.metadata.name == normalized_function_name + + resp = project.get_function(key=func_name) + assert resp.metadata.name == normalized_function_name + + +def test_set_function_underscore_name(rundb_mock): + project = mlrun.projects.MlrunProject( + "project", default_requirements=["pandas>1, <3"] + ) + func_name = "name_with_underscores" - # Get the function using a normalized name and make sure it works + # create a function with a name that includes underscores + func_path = str(pathlib.Path(__file__).parent / "assets" / "handler.py") + func = mlrun.code_to_function( + name=func_name, + kind="job", + 
image="mlrun/mlrun", + handler="myhandler", + filename=func_path, + ) + project.set_function(name=func_name, func=func) + + # get the function using the original name (with underscores) and ensure that it works and returns normalized name normalized_name = mlrun.utils.normalize_name(func_name) + enriched_function = project.get_function(key=func_name) + assert enriched_function.metadata.name == normalized_name + + # get the function using a normalized name and make sure it works enriched_function = project.get_function(key=normalized_name) assert enriched_function.metadata.name == normalized_name @@ -470,6 +540,57 @@ def test_set_func_with_tag(): assert func.metadata.tag is None +def test_set_function_with_tagged_key(): + project = mlrun.new_project("set-func-tagged-key", save=False) + # create 2 functions with different tags + tag_v1 = "v1" + tag_v2 = "v2" + my_func_v1 = mlrun.code_to_function( + filename=str(pathlib.Path(__file__).parent / "assets" / "handler.py"), + kind="job", + tag=tag_v1, + ) + my_func_v2 = mlrun.code_to_function( + filename=str(pathlib.Path(__file__).parent / "assets" / "handler.py"), + kind="job", + name="my_func", + tag=tag_v2, + ) + + # set the functions + # function key is ("handler") + project.set_function(my_func_v1) + # function key is : ("handler:v1") + project.set_function(my_func_v1, tag=tag_v1) + # function key is "my_func" + project.set_function(my_func_v2, name=my_func_v2.metadata.name) + # function key is "my_func:v2" + project.set_function(my_func_v2, name=f"{my_func_v2.metadata.name}:{tag_v2}") + + assert len(project.spec._function_objects) == 4 + + func = project.get_function(f"{my_func_v1.metadata.name}:{tag_v1}") + assert func.metadata.tag == tag_v1 + + func = project.get_function(my_func_v1.metadata.name, tag=tag_v1) + assert func.metadata.tag == tag_v1 + + func = project.get_function(my_func_v1.metadata.name) + assert func.metadata.tag == tag_v1 + + func = project.get_function(my_func_v2.metadata.name) + assert 
func.metadata.tag == tag_v2 + + func = project.get_function(f"{my_func_v2.metadata.name}:{tag_v2}") + assert func.metadata.tag == tag_v2 + + func = project.get_function(my_func_v2.metadata.name, tag=tag_v2) + assert func.metadata.tag == tag_v2 + + func = project.get_function(f"{my_func_v2.metadata.name}:{tag_v2}", tag=tag_v2) + assert func.metadata.tag == tag_v2 + + def test_set_function_with_relative_path(context): project = mlrun.new_project("inline", context=str(assets_path()), save=False) @@ -774,6 +895,44 @@ def test_project_ops(): assert run.output("y") == 4 # = x * 2 +def test_clear_context(): + proj = mlrun.new_project("proj", save=False) + proj_with_subpath = mlrun.new_project( + "proj", + subpath="test", + context=pathlib.Path(tests.conftest.tests_root_directory), + save=False, + ) + subdir_path = os.path.join( + proj_with_subpath.spec.context, proj_with_subpath.spec.subpath + ) + # when the context is relative, assert no deletion called + with unittest.mock.patch( + "shutil.rmtree", return_value=True + ) as rmtree, warnings.catch_warnings(record=True) as w: + proj.clear_context() + rmtree.assert_not_called() + + assert len(w) == 2 + assert issubclass(w[-2].category, FutureWarning) + assert ( + "This method deletes all files and clears the context directory or subpath (if defined)!" + " Please keep in mind that this method can produce unexpected outcomes and is not recommended," + " it will be deprecated in 1.6.0." 
in str(w[-1].message) + ) + + # when the context is not relative and subdir specified, assert that the subdir is deleted rather than the context + with unittest.mock.patch( + "shutil.rmtree", return_value=True + ) as rmtree, unittest.mock.patch( + "os.path.exists", return_value=True + ), unittest.mock.patch( + "os.path.isdir", return_value=True + ): + proj_with_subpath.clear_context() + rmtree.assert_called_once_with(subdir_path) + + @pytest.mark.parametrize( "parameters,hyperparameters,expectation,run_saved", [ @@ -850,3 +1009,13 @@ def test_remove_owner_name_in_load_project_from_yaml(): imported_project = mlrun.load_project("./", str(project_file_path), save=False) assert project.spec.owner == "some_owner" assert imported_project.spec.owner is None + + +def test_set_secrets_file_not_found(): + # Create project and generate owner name + project_name = "project-name" + file_name = ".env-test" + project = mlrun.new_project(project_name, save=False) + with pytest.raises(mlrun.errors.MLRunNotFoundError) as excinfo: + project.set_secrets(file_path=file_name) + assert f"{file_name} does not exist" in str(excinfo.value) diff --git a/tests/projects/test_remote_pipeline.py b/tests/projects/test_remote_pipeline.py index 04ff486d6e45..3277b7b9a84a 100644 --- a/tests/projects/test_remote_pipeline.py +++ b/tests/projects/test_remote_pipeline.py @@ -26,7 +26,7 @@ import mlrun import tests.projects.assets.remote_pipeline_with_overridden_resources import tests.projects.base_pipeline -from mlrun.api.schemas import SecurityContextEnrichmentModes +from mlrun.common.schemas import SecurityContextEnrichmentModes @pytest.fixture() diff --git a/tests/run/assets/kwargs.py b/tests/run/assets/kwargs.py new file mode 100644 index 000000000000..0753ea2d9a12 --- /dev/null +++ b/tests/run/assets/kwargs.py @@ -0,0 +1,26 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +def func(context, x, **kwargs): + context.logger.info(x) + context.logger.info(kwargs) + return kwargs + + +def func_with_default(context, x=4, **kwargs): + context.logger.info(x) + context.logger.info(kwargs) + if not kwargs: + raise Exception("kwargs is empty") + return kwargs diff --git a/tests/run/test_handler_decorator.py b/tests/run/test_handler_decorator.py deleted file mode 100644 index 00a8f3a56a84..000000000000 --- a/tests/run/test_handler_decorator.py +++ /dev/null @@ -1,1023 +0,0 @@ -# Copyright 2018 Iguazio -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import os -import tempfile -import zipfile -from typing import List, Tuple, Union - -import cloudpickle -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import pytest -from sklearn.impute import SimpleImputer -from sklearn.pipeline import Pipeline -from sklearn.preprocessing import OrdinalEncoder - -import mlrun - - -@mlrun.handler(labels={"a": 1, "b": "a test", "c": [1, 2, 3]}) -def set_labels(arg1, arg2=23): - return arg1 - arg2 - - -def test_set_labels_without_mlrun(): - """ - Run the `set_labels` function without MLRun to see the wrapper is transparent. - """ - returned_result = set_labels(24) - assert returned_result == 1 - - returned_result = set_labels(20, 18) - assert returned_result == 2 - - returned_result = set_labels(arg1=24) - assert returned_result == 1 - - returned_result = set_labels(arg1=20, arg2=18) - assert returned_result == 2 - - -def test_set_labels_with_mlrun(): - """ - Run the `set_labels` function with MLRun to see the wrapper is setting the required labels. - """ - # Create the function and run: - mlrun_function = mlrun.code_to_function(filename=__file__, kind="job") - artifact_path = tempfile.TemporaryDirectory() - run_object = mlrun_function.run( - handler="set_labels", - params={"arg1": 24}, - artifact_path=artifact_path.name, - local=True, - ) - - # Manual validation: - mlrun.utils.logger.info(run_object.metadata.labels) - - # Assertion: - assert run_object.metadata.labels["a"] == "1" - assert run_object.metadata.labels["b"] == "a test" - assert run_object.metadata.labels["c"] == "[1, 2, 3]" - - # Clean the test outputs: - artifact_path.cleanup() - - -@mlrun.handler(labels={"wrapper_label": "2"}) -def set_labels_from_function_and_wrapper(context: mlrun.MLClientCtx = None): - if context: - context.set_label("context_label", 1) - - -def test_set_labels_from_function_and_wrapper_without_mlrun(): - """ - Run the `set_labels_from_function_and_wrapper` function without MLRun to see the wrapper is transparent. 
- """ - returned_result = set_labels_from_function_and_wrapper() - assert returned_result is None - - -def test_set_labels_from_function_and_wrapper_with_mlrun(): - """ - Run the `set_labels_from_function_and_wrapper` function with MLRun to see the wrapper is setting the required - labels without interrupting to the ones set via the context by the user. - """ - # Create the function and run: - mlrun_function = mlrun.code_to_function(filename=__file__, kind="job") - artifact_path = tempfile.TemporaryDirectory() - run_object = mlrun_function.run( - handler="set_labels_from_function_and_wrapper", - artifact_path=artifact_path.name, - local=True, - ) - - # Manual validation: - mlrun.utils.logger.info(run_object.metadata.labels) - - # Assertion: - assert run_object.metadata.labels["context_label"] == "1" - assert run_object.metadata.labels["wrapper_label"] == "2" - - # Clean the test outputs: - artifact_path.cleanup() - - -@mlrun.handler( - outputs=[ - "my_array", - "my_df:dataset", - "my_dict : dataset", - "my_list :dataset", - ] -) -def log_dataset() -> Tuple[np.ndarray, pd.DataFrame, dict, list]: - return ( - np.ones((10, 20)), - pd.DataFrame(np.zeros((20, 10))), - {"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}, - [["A"], ["B"], [""]], - ) - - -def test_log_dataset_without_mlrun(): - """ - Run the `log_dataset` function without MLRun to see the wrapper is transparent. - """ - my_array, my_df, my_dict, my_list = log_dataset() - assert isinstance(my_array, np.ndarray) - assert isinstance(my_df, pd.DataFrame) - assert isinstance(my_dict, dict) - assert isinstance(my_list, list) - - -def test_log_dataset_with_mlrun(): - """ - Run the `log_dataset` function with MLRun to see the wrapper is logging the returned values as datasets artifacts. 
- """ - # Create the function and run: - mlrun_function = mlrun.code_to_function(filename=__file__, kind="job") - artifact_path = tempfile.TemporaryDirectory() - run_object = mlrun_function.run( - handler="log_dataset", - artifact_path=artifact_path.name, - local=True, - ) - - # Manual validation: - mlrun.utils.logger.info(run_object.outputs) - - # Assertion: - assert len(run_object.outputs) == 4 # my_array, my_df, my_dict, my_list - assert run_object.artifact("my_array").as_df().shape == (10, 20) - assert run_object.artifact("my_df").as_df().shape == (20, 10) - assert run_object.artifact("my_dict").as_df().shape == (4, 2) - assert run_object.artifact("my_list").as_df().shape == (3, 1) - - # Clean the test outputs: - artifact_path.cleanup() - - -@mlrun.handler( - outputs=[ - "my_dir: directory", - ] -) -def log_directory(path: str) -> str: - path = os.path.join(path, "my_new_dir") - os.makedirs(path) - open(os.path.join(path, "a.txt"), "a").close() - open(os.path.join(path, "b.txt"), "a").close() - open(os.path.join(path, "c.txt"), "a").close() - return path - - -def test_log_directory_without_mlrun(): - """ - Run the `log_directory` function without MLRun to see the wrapper is transparent. - """ - temp_dir = tempfile.TemporaryDirectory() - my_dir = log_directory(temp_dir.name) - assert isinstance(my_dir, str) - temp_dir.cleanup() - - -def test_log_directory_with_mlrun(): - """ - Run the `log_directory` function with MLRun to see the wrapper is logging the directory as a zip file. 
- """ - # Create the function and run: - mlrun_function = mlrun.code_to_function(filename=__file__, kind="job") - artifact_path = tempfile.TemporaryDirectory() - run_object = mlrun_function.run( - handler="log_directory", - params={"path": artifact_path.name}, - artifact_path=artifact_path.name, - local=True, - ) - - # Manual validation: - mlrun.utils.logger.info(run_object.outputs) - - # Assertion: - assert len(run_object.outputs) == 1 # my_dir - my_dir_zip = run_object.artifact("my_dir").local() - my_dir = os.path.join(artifact_path.name, "extract_here") - with zipfile.ZipFile(my_dir_zip, "r") as zip_ref: - zip_ref.extractall(my_dir) - my_dir_contents = os.listdir(my_dir) - assert len(my_dir_contents) == 3 - assert "a.txt" in my_dir_contents - assert "b.txt" in my_dir_contents - assert "c.txt" in my_dir_contents - - # Clean the test outputs: - artifact_path.cleanup() - - -@mlrun.handler( - outputs=[ - "my_file : file", - ] -) -def log_file(path: str) -> str: - my_file = os.path.join(path, "a.txt") - with open(my_file, "a") as f: - f.write("some text") - return my_file - - -def test_log_file_without_mlrun(): - """ - Run the `log_file` function without MLRun to see the wrapper is transparent. - """ - temp_dir = tempfile.TemporaryDirectory() - my_file = log_file(temp_dir.name) - assert isinstance(my_file, str) - temp_dir.cleanup() - - -def test_log_file_with_mlrun(): - """ - Run the `log_file` function with MLRun to see the wrapper is logging the file. 
- """ - # Create the function and run: - mlrun_function = mlrun.code_to_function(filename=__file__, kind="job") - artifact_path = tempfile.TemporaryDirectory() - run_object = mlrun_function.run( - handler="log_file", - params={"path": artifact_path.name}, - artifact_path=artifact_path.name, - local=True, - ) - - # Manual validation: - mlrun.utils.logger.info(run_object.outputs) - - # Assertion: - assert len(run_object.outputs) == 1 # my_file - with open(run_object.artifact("my_file").local(), "r") as my_file: - assert my_file.read() == "some text" - - # Clean the test outputs: - artifact_path.cleanup() - - -@mlrun.handler(outputs=["my_object : object"]) -def log_object() -> Pipeline: - encoder_to_imputer = Pipeline( - steps=[ - ( - "imputer", - SimpleImputer(missing_values="", strategy="constant", fill_value="C"), - ), - ("encoder", OrdinalEncoder()), - ] - ) - encoder_to_imputer.fit([["A"], ["B"], ["C"]]) - return encoder_to_imputer - - -def test_log_object_without_mlrun(): - """ - Run the `log_object` function without MLRun to see the wrapper is transparent. - """ - temp_dir = tempfile.TemporaryDirectory() - my_object = log_object() - assert isinstance(my_object, Pipeline) - assert my_object.transform([["A"], ["B"], [""]]).tolist() == [[0], [1], [2]] - temp_dir.cleanup() - - -def test_log_object_with_mlrun(): - """ - Run the `log_object` function with MLRun to see the wrapper is logging the object as pickle. 
- """ - # Create the function and run: - mlrun_function = mlrun.code_to_function(filename=__file__, kind="job") - artifact_path = tempfile.TemporaryDirectory() - run_object = mlrun_function.run( - handler="log_object", - params={"path": artifact_path.name}, - artifact_path=artifact_path.name, - local=True, - ) - - # Manual validation: - mlrun.utils.logger.info(run_object.outputs) - - # Assertion: - assert len(run_object.outputs) == 1 # my_file - pickle = run_object.artifact("my_object").local() - assert os.path.basename(pickle) == "my_object.pkl" - with open(pickle, "rb") as pickle_file: - my_object = cloudpickle.load(pickle_file) - assert isinstance(my_object, Pipeline) - assert my_object.transform([["A"], ["B"], [""]]).tolist() == [[0], [1], [2]] - - # Clean the test outputs: - artifact_path.cleanup() - - -@mlrun.handler(outputs=["my_plot: plot"]) -def log_plot() -> plt.Figure: - my_plot, axes = plt.subplots() - axes.plot([1, 2, 3, 4]) - return my_plot - - -def test_log_plot_without_mlrun(): - """ - Run the `log_plot` function without MLRun to see the wrapper is transparent. - """ - my_plot = log_plot() - assert isinstance(my_plot, plt.Figure) - - -def test_log_plot_with_mlrun(): - """ - Run the `log_plot` function with MLRun to see the wrapper is logging the plots as html files. 
- """ - # Create the function and run: - mlrun_function = mlrun.code_to_function(filename=__file__, kind="job") - artifact_path = tempfile.TemporaryDirectory() - run_object = mlrun_function.run( - handler="log_plot", - artifact_path=artifact_path.name, - local=True, - ) - - # Manual validation: - mlrun.utils.logger.info(run_object.outputs) - - # Assertion: - assert len(run_object.outputs) == 1 # my_plot - assert os.path.basename(run_object.artifact("my_plot").local()) == "my_plot.html" - - # Clean the test outputs: - artifact_path.cleanup() - - -@mlrun.handler( - outputs=[ - "my_int", - "my_float", - "my_dict: result", - "my_array:result", - ] -) -def log_result() -> Tuple[int, float, dict, np.ndarray]: - return 1, 1.5, {"a": 1, "b": 2}, np.ones(3) - - -def test_log_result_without_mlrun(): - """ - Run the `log_result` function without MLRun to see the wrapper is transparent. - """ - my_int, my_float, my_dict, my_array = log_result() - assert isinstance(my_int, int) - assert isinstance(my_float, float) - assert isinstance(my_dict, dict) - assert isinstance(my_array, np.ndarray) - - -def test_log_result_with_mlrun(): - """ - Run the `log_result` function with MLRun to see the wrapper is logging the returned values as results. 
- """ - # Create the function and run: - mlrun_function = mlrun.code_to_function(filename=__file__, kind="job") - artifact_path = tempfile.TemporaryDirectory() - run_object = mlrun_function.run( - handler="log_result", - artifact_path=artifact_path.name, - local=True, - ) - - # Manual validation: - mlrun.utils.logger.info(run_object.outputs) - - # Assertion: - assert len(run_object.outputs) == 4 # my_int, my_float, my_dict, my_array - assert run_object.outputs["my_int"] == 1 - assert run_object.outputs["my_float"] == 1.5 - assert run_object.outputs["my_dict"] == {"a": 1, "b": 2} - assert run_object.outputs["my_array"] == [1, 1, 1] - - # Clean the test outputs: - artifact_path.cleanup() - - -@mlrun.handler( - outputs=["my_result", "my_dataset", "my_object", "my_plot", "my_imputer"] -) -def log_as_default_artifact_types(): - my_plot, axes = plt.subplots() - axes.plot([1, 2, 3, 4]) - return ( - 10, - pd.DataFrame(np.ones(10)), - cloudpickle.dumps({"a": 5}), - my_plot, - SimpleImputer(), - ) - - -def test_log_as_default_artifact_types_without_mlrun(): - """ - Run the `log_as_default_artifact_types` function without MLRun to see the wrapper is transparent. - """ - ( - my_result, - my_dataset, - my_object, - my_plot, - my_imputer, - ) = log_as_default_artifact_types() - assert isinstance(my_result, int) - assert isinstance(my_dataset, pd.DataFrame) - assert isinstance(my_object, bytes) - assert isinstance(my_plot, plt.Figure) - assert isinstance(my_imputer, SimpleImputer) - - -def test_log_as_default_artifact_types_with_mlrun(): - """ - Run the `log_as_default_artifact_types` function with MLRun to see the wrapper is logging the returned values - as the correct default artifact types as the artifact types are not provided to the decorator. 
- """ - # Create the function and run: - mlrun_function = mlrun.code_to_function(filename=__file__, kind="job") - artifact_path = tempfile.TemporaryDirectory() - run_object = mlrun_function.run( - handler="log_as_default_artifact_types", - artifact_path=artifact_path.name, - local=True, - ) - - # Manual validation: - mlrun.utils.logger.info(run_object.outputs) - - # Assertion: - assert ( - len(run_object.outputs) == 5 - ) # my_result, my_dataset, my_object, my_plot, my_imputer - assert run_object.outputs["my_result"] == 10 - assert run_object.artifact("my_dataset").as_df().shape == (10, 1) - my_object_pickle = run_object.artifact("my_object").local() - assert os.path.basename(my_object_pickle) == "my_object.pkl" - with open(my_object_pickle, "rb") as pickle_file: - my_object = cloudpickle.load(pickle_file) - assert my_object == {"a": 5} - assert os.path.basename(run_object.artifact("my_plot").local()) == "my_plot.html" - assert isinstance(run_object.outputs["my_imputer"], str) - - # Clean the test outputs: - artifact_path.cleanup() - - -@mlrun.handler(outputs=["dataset: dataset", "result: result", "no_type", None]) -def log_with_none_values( - is_none_dataset: bool = False, - is_none_result: bool = False, - is_none_no_type: bool = False, -): - return ( - None if is_none_dataset else np.zeros(shape=(5, 5)), - None if is_none_result else 5, - None if is_none_no_type else np.ones(shape=(10, 10)), - 10, - ) - - -def test_log_with_none_values_without_mlrun(): - """ - Run the `log_with_none_values` function without MLRun to see the wrapper is transparent. 
- """ - dataset, result, no_type, no_to_log = log_with_none_values() - assert isinstance(dataset, np.ndarray) - assert result == 5 - assert isinstance(no_type, np.ndarray) - assert no_to_log == 10 - - -@pytest.mark.parametrize("is_none_dataset", [True, False]) -@pytest.mark.parametrize("is_none_result", [True, False]) -@pytest.mark.parametrize("is_none_no_type", [True, False]) -def test_log_with_none_values_with_mlrun( - is_none_dataset: bool, is_none_result: bool, is_none_no_type: bool -): - """ - Run the `log_with_none_values` function with MLRun to see the wrapper is logging and ignoring the returned values - as needed. Only result type should be logged as None, the dataset is needed to be ignored (not logged). - """ - # Create the function and run: - mlrun_function = mlrun.code_to_function(filename=__file__, kind="job") - artifact_path = tempfile.TemporaryDirectory() - run_object = mlrun_function.run( - handler="log_with_none_values", - params={ - "is_none_dataset": is_none_dataset, - "is_none_result": is_none_result, - "is_none_no_type": is_none_no_type, - }, - artifact_path=artifact_path.name, - local=True, - ) - - # Manual validation: - mlrun.utils.logger.info(run_object.outputs) - - # Assertion: - assert ( - len(run_object.outputs) == (0 if is_none_dataset else 1) + 1 + 1 - ) # dataset only if True, result, no_type - if not is_none_dataset: - assert run_object.artifact("dataset").as_df().shape == (5, 5) - assert run_object.outputs["result"] == "None" if is_none_result else 5 - if is_none_no_type: - assert run_object.outputs["no_type"] == "None" - else: - assert run_object.artifact("no_type").as_df().shape == (10, 10) - - # Clean the test outputs: - artifact_path.cleanup() - - -@mlrun.handler(outputs=["wrapper_dataset: dataset", "wrapper_result: result"]) -def log_from_function_and_wrapper(context: mlrun.MLClientCtx = None): - if context: - context.log_result(key="context_result", value=1) - context.log_dataset(key="context_dataset", 
df=pd.DataFrame(np.arange(10))) - return [1, 2, 3, 4], "hello" - - -def test_log_from_function_and_wrapper_without_mlrun(): - """ - Run the `log_from_function_and_wrapper` function without MLRun to see the wrapper is transparent. - """ - my_dataset, my_result = log_from_function_and_wrapper() - assert isinstance(my_dataset, list) - assert isinstance(my_result, str) - - -def test_log_from_function_and_wrapper_with_mlrun(): - """ - Run the `log_from_function_and_wrapper` function with MLRun to see the wrapper is logging the returned values - among the other values logged via the context manually inside the function. - """ - # Create the function and run: - mlrun_function = mlrun.code_to_function(filename=__file__, kind="job") - artifact_path = tempfile.TemporaryDirectory() - run_object = mlrun_function.run( - handler="log_from_function_and_wrapper", - artifact_path=artifact_path.name, - local=True, - ) - - # Manual validation: - mlrun.utils.logger.info(run_object.outputs) - - # Assertion: - assert ( - len(run_object.outputs) == 4 - ) # context_dataset, context_result, wrapper_dataset, wrapper_result - assert run_object.artifact("context_dataset").as_df().shape == (10, 1) - assert run_object.outputs["context_result"] == 1 - assert run_object.artifact("wrapper_dataset").as_df().shape == (4, 1) - assert run_object.outputs["wrapper_result"] == "hello" - - # Clean the test outputs: - artifact_path.cleanup() - - -@mlrun.handler() -def parse_inputs_from_type_hints( - my_data: list, - my_encoder: Pipeline, - data_2, - data_3: mlrun.DataItem, - data_4: List[int], - add, - mul: int = 2, -): - assert data_2 is None or isinstance(data_2, mlrun.DataItem) - assert data_3 is None or isinstance(data_3, mlrun.DataItem) - assert data_4 is None or isinstance(data_4, mlrun.DataItem) - - return (my_encoder.transform(my_data) + add * mul).tolist() - - -def test_parse_inputs_from_type_hints_without_mlrun(): - """ - Run the `parse_inputs_from_type_hints` function without MLRun to see the 
wrapper is transparent. - """ - _, _, _, my_data = log_dataset() - my_encoder = log_object() - result = parse_inputs_from_type_hints( - my_data, my_encoder=my_encoder, data_2=None, data_3=None, data_4=None, add=1 - ) - assert isinstance(result, list) - assert result == [[2], [3], [4]] - - -def test_parse_inputs_from_type_hints_with_mlrun(): - """ - Run the `parse_inputs_from_type_hints` function with MLRun to see the wrapper is parsing the given inputs - (`DataItem`s) to the written type hints. - """ - # Create the function and run 2 of the previous functions to create a dataset and encoder objects: - mlrun_function = mlrun.code_to_function(filename=__file__, kind="job") - artifact_path = tempfile.TemporaryDirectory() - log_dataset_run = mlrun_function.run( - handler="log_dataset", - artifact_path=artifact_path.name, - local=True, - ) - log_object_run = mlrun_function.run( - handler="log_object", - artifact_path=artifact_path.name, - local=True, - ) - - # Run the function that will parse the data items: - run_object = mlrun_function.run( - handler="parse_inputs_from_type_hints", - inputs={ - "my_data": log_dataset_run.outputs["my_list"], - "my_encoder": log_object_run.outputs["my_object"], - "data_2": log_dataset_run.outputs["my_array"], - "data_3": log_dataset_run.outputs["my_dict"], - "data_4": log_dataset_run.outputs["my_list"], - }, - params={"add": 1}, - artifact_path=artifact_path.name, - local=True, - ) - - # Manual validation: - mlrun.utils.logger.info(run_object.outputs) - - # Assertion: - assert len(run_object.outputs) == 1 # return - - # Clean the test outputs: - artifact_path.cleanup() - - -@mlrun.handler( - inputs={"my_data": np.ndarray, "my_union": Union[np.ndarray, pd.DataFrame]} -) -def parse_inputs_from_wrapper_using_types( - my_data, my_encoder, my_union, add, mul: int = 2 -): - if my_union is not None: - assert isinstance(my_union, mlrun.DataItem) - my_union = my_union.as_df() - assert my_union.shape == (20, 10) - if isinstance(my_encoder, 
mlrun.DataItem): - my_encoder = my_encoder.local() - with open(my_encoder, "rb") as pickle_file: - my_encoder = cloudpickle.load(pickle_file) - return (my_encoder.transform(my_data) + add * mul).tolist() - - -def test_parse_inputs_from_wrapper_using_types_without_mlrun(): - """ - Run the `parse_inputs_from_wrapper_using_types` function without MLRun to see the wrapper is transparent. - """ - _, _, _, my_data = log_dataset() - my_encoder = log_object() - result = parse_inputs_from_wrapper_using_types( - pd.DataFrame(my_data), my_encoder=my_encoder, my_union=None, add=1 - ) - assert isinstance(result, list) - assert result == [[2], [3], [4]] - - -def test_parse_inputs_from_wrapper_using_types_with_mlrun(): - """ - Run the `parse_inputs_from_wrapper_using_types` function with MLRun to see the wrapper is parsing the given inputs - (`DataItem`s) to the written configuration provided to the wrapper. - """ - # Create the function and run 2 of the previous functions to create a dataset and encoder objects: - mlrun_function = mlrun.code_to_function(filename=__file__, kind="job") - artifact_path = tempfile.TemporaryDirectory() - log_dataset_run = mlrun_function.run( - handler="log_dataset", - artifact_path=artifact_path.name, - local=True, - ) - log_object_run = mlrun_function.run( - handler="log_object", - artifact_path=artifact_path.name, - local=True, - ) - - # Run the function that will parse the data items: - run_object = mlrun_function.run( - handler="parse_inputs_from_wrapper_using_types", - inputs={ - "my_data": log_dataset_run.outputs["my_list"], - "my_encoder": log_object_run.outputs["my_object"], - "my_union": log_dataset_run.outputs["my_df"], - }, - params={"add": 1}, - artifact_path=artifact_path.name, - local=True, - ) - - # Manual validation: - mlrun.utils.logger.info(run_object.outputs) - - # Assertion: - assert len(run_object.outputs) == 1 # return - - # Clean the test outputs: - artifact_path.cleanup() - - -@mlrun.handler( - inputs={ - "my_list": "list", - 
"my_array": "numpy.ndarray", - "my_encoder": "sklearn.pipeline.Pipeline", - "my_union": "typing.Union[numpy.array, pandas.DataFrame]", - }, - outputs=["result"], -) -def parse_inputs_from_wrapper_using_strings( - my_list, my_array, my_df, my_encoder, my_union, add, mul: int = 2 -): - if my_union is not None: - assert isinstance(my_union, mlrun.DataItem) - my_union = my_union.as_df() - assert my_union.shape == (20, 10) - if isinstance(my_df, mlrun.DataItem): - my_df = my_df.as_df() - assert my_list == [["A"], ["B"], [""]] - assert isinstance(my_encoder, Pipeline) - return int((my_df.sum().sum() + my_array.sum() + add) * mul) - - -def test_parse_inputs_from_wrapper_using_strings_without_mlrun(): - """ - Run the `parse_inputs_from_wrapper_using_strings` function without MLRun to see the wrapper is transparent. - """ - my_array, my_df, _, my_list = log_dataset() - my_encoder = log_object() - result = parse_inputs_from_wrapper_using_strings( - my_list, my_array, my_df=my_df, my_encoder=my_encoder, my_union=None, add=1 - ) - assert isinstance(result, int) - assert result == 402 - - -def test_parse_inputs_from_wrapper_using_strings_with_mlrun(): - """ - Run the `parse_inputs_from_wrapper_using_strings` function with MLRun to see the wrapper is parsing the given inputs - (`DataItem`s) to the written configuration provided to the wrapper. 
- """ - # Create the function and run 2 of the previous functions to create a dataset and encoder objects: - mlrun_function = mlrun.code_to_function(filename=__file__, kind="job") - artifact_path = tempfile.TemporaryDirectory() - log_dataset_run = mlrun_function.run( - handler="log_dataset", - artifact_path=artifact_path.name, - local=True, - ) - log_object_run = mlrun_function.run( - handler="log_object", - artifact_path=artifact_path.name, - local=True, - ) - - # Run the function that will parse the data items: - run_object = mlrun_function.run( - handler="parse_inputs_from_wrapper_using_strings", - inputs={ - "my_list": log_dataset_run.outputs["my_list"], - "my_array": log_dataset_run.outputs["my_array"], - "my_df": log_dataset_run.outputs["my_df"], - "my_encoder": log_object_run.outputs["my_object"], - "my_union": log_dataset_run.outputs["my_df"], - }, - params={"add": 1}, - artifact_path=artifact_path.name, - local=True, - ) - - # Manual validation: - mlrun.utils.logger.info(run_object.outputs) - - # Assertion: - assert len(run_object.outputs) == 1 # result - assert run_object.outputs["result"] == 402 - - # Clean the test outputs: - artifact_path.cleanup() - - -@mlrun.handler(outputs=["error_numpy"]) -def raise_error_while_logging(): - return np.ones(shape=(7, 7, 7)) - - -def test_raise_error_while_logging_with_mlrun(): - """ - Run the `raise_error_while_logging` function with MLRun to see the wrapper is raising the relevant error. 
- """ - # Create the function: - mlrun_function = mlrun.code_to_function(filename=__file__, kind="job") - artifact_path = tempfile.TemporaryDirectory() - - # Run and expect an error: - try: - mlrun_function.run( - handler="raise_error_while_logging", - artifact_path=artifact_path.name, - local=True, - ) - assert False - except Exception as e: - mlrun.utils.logger.info(e) - assert "MLRun tried to log 'error_numpy' as " in str(e) - - # Clean the test outputs: - artifact_path.cleanup() - - -def test_raise_error_while_parsing_with_mlrun(): - """ - Run the `parse_inputs_from_wrapper_using_types` function with MLRun and send it wrong types to see the wrapper is - raising the relevant exception. - """ - # Create the function and run 2 of the previous functions to create a dataset and encoder objects: - mlrun_function = mlrun.code_to_function(filename=__file__, kind="job") - artifact_path = tempfile.TemporaryDirectory() - log_dataset_run = mlrun_function.run( - handler="log_dataset", - artifact_path=artifact_path.name, - local=True, - ) - log_object_run = mlrun_function.run( - handler="log_object", - artifact_path=artifact_path.name, - local=True, - ) - - # Run the function that will parse the data items and expect an error: - try: - mlrun_function.run( - handler="parse_inputs_from_wrapper_using_types", - inputs={ - "my_data": log_object_run.outputs["my_object"], - "my_encoder": log_dataset_run.outputs["my_list"], - }, - params={"add": 1}, - artifact_path=artifact_path.name, - local=True, - ) - assert False - except Exception as e: - mlrun.utils.logger.info(e) - assert "MLRun tried to parse a `DataItem` of type " in str(e) - - # Clean the test outputs: - artifact_path.cleanup() - - -class MyClass: - def __init__(self, class_parameter: int): - assert isinstance(class_parameter, int) - self._parameter = class_parameter - - @mlrun.handler( - outputs=[ - "my_array:dataset", - "my_df: dataset", - "my_dict :dataset", - "my_list : dataset", - ] - ) - def log_dataset(self) -> 
Tuple[np.ndarray, pd.DataFrame, dict, list]: - return ( - np.ones((10, 20)), - pd.DataFrame(np.zeros((20, 10))), - {"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}, - [["A"], ["B"], [""]], - ) - - @mlrun.handler(outputs=["my_object: object"]) - def log_object(self) -> Pipeline: - encoder_to_imputer = Pipeline( - steps=[ - ( - "imputer", - SimpleImputer( - missing_values="", strategy="constant", fill_value="C" - ), - ), - ("encoder", OrdinalEncoder()), - ] - ) - encoder_to_imputer.fit([["A"], ["B"], ["C"]]) - return encoder_to_imputer - - @mlrun.handler(outputs=["result"]) - def parse_inputs_from_type_hints( - self, - my_data: list, - my_encoder: Pipeline, - data_2, - data_3: mlrun.DataItem, - mul: int, - ): - assert data_2 is None or isinstance(data_2, mlrun.DataItem) - assert data_3 is None or isinstance(data_3, mlrun.DataItem) - - return int(sum(my_encoder.transform(my_data) + self._parameter * mul)) - - -def test_class_methods_without_mlrun(): - """ - Run the `log_dataset` function without MLRun to see the wrapper is transparent. - """ - temp_dir = tempfile.TemporaryDirectory() - - my_class = MyClass(class_parameter=1) - - my_array, my_df, my_dict, my_list = my_class.log_dataset() - assert isinstance(my_array, np.ndarray) - assert isinstance(my_df, pd.DataFrame) - assert isinstance(my_dict, dict) - assert isinstance(my_list, list) - - my_object = my_class.log_object() - assert isinstance(my_object, Pipeline) - assert my_object.transform([["A"], ["B"], [""]]).tolist() == [[0], [1], [2]] - - result = my_class.parse_inputs_from_type_hints( - my_list, my_encoder=my_object, data_2=None, data_3=None, mul=2 - ) - assert result == 9 - - temp_dir.cleanup() - - -def test_class_methods_with_mlrun(): - """ - Run the `log_dataset` function with MLRun to see the wrapper is logging the returned values as datasets artifacts. 
- """ - # Create the function and run 2 of the previous functions to create a dataset and encoder objects: - mlrun_function = mlrun.code_to_function(filename=__file__, kind="job") - artifact_path = tempfile.TemporaryDirectory() - log_dataset_run = mlrun_function.run( - handler="MyClass::log_dataset", - params={"_init_args": {"class_parameter": 1}}, - artifact_path=artifact_path.name, - local=True, - ) - log_object_run = mlrun_function.run( - handler="MyClass::log_object", - params={"_init_args": {"class_parameter": 1}}, - artifact_path=artifact_path.name, - local=True, - ) - - # Run the function that will parse the data items: - run_object = mlrun_function.run( - handler="MyClass::parse_inputs_from_type_hints", - inputs={ - "my_data": log_dataset_run.outputs["my_list"], - "my_encoder": log_object_run.outputs["my_object"], - "data_2": log_dataset_run.outputs["my_array"], - "data_3": log_dataset_run.outputs["my_dict"], - }, - params={"_init_args": {"class_parameter": 1}, "mul": 2}, - artifact_path=artifact_path.name, - local=True, - ) - - # Manual validation: - mlrun.utils.logger.info(run_object.outputs) - - # Assertion: - assert len(run_object.outputs) == 1 - assert run_object.outputs["result"] == 9 - - # Clean the test outputs: - artifact_path.cleanup() diff --git a/tests/run/test_hyper.py b/tests/run/test_hyper.py index 71ccbe0b9d17..07b8e8944d8b 100644 --- a/tests/run/test_hyper.py +++ b/tests/run/test_hyper.py @@ -188,7 +188,7 @@ def hyper_func2(context, p1=1): context.log_dataset("df2", df=df) -def test_hyper_get_artifact(): +def test_hyper_get_artifact(rundb_mock): fn = mlrun.new_function("test_hyper_get_artifact") run = mlrun.run_function( fn, diff --git a/tests/run/test_main.py b/tests/run/test_main.py deleted file mode 100644 index e9870e85da88..000000000000 --- a/tests/run/test_main.py +++ /dev/null @@ -1,446 +0,0 @@ -# Copyright 2018 Iguazio -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in 
compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import datetime -import os -import pathlib -import sys -import traceback -from base64 import b64encode -from subprocess import PIPE, run -from sys import executable, stderr - -import pytest - -import mlrun -from tests.conftest import examples_path, out_path, tests_root_directory - - -def exec_main(op, args, cwd=examples_path, raise_on_error=True): - cmd = [executable, "-m", "mlrun", op] - if args: - cmd += args - out = run(cmd, stdout=PIPE, stderr=PIPE, cwd=cwd) - if out.returncode != 0: - print(out.stderr.decode("utf-8"), file=stderr) - print(out.stdout.decode("utf-8"), file=stderr) - print(traceback.format_exc()) - if raise_on_error: - raise Exception(out.stderr.decode("utf-8")) - else: - # return out so that we can check the error message on stdout and stderr - return out - - return out.stdout.decode("utf-8") - - -def exec_run(cmd, args, test, raise_on_error=True): - args = args + ["--name", test, "--dump", cmd] - return exec_main("run", args, raise_on_error=raise_on_error) - - -def compose_param_list(params: dict, flag="-p"): - composed_params = [] - for k, v in params.items(): - composed_params += [flag, f"{k}={v}"] - return composed_params - - -def test_main_run_basic(): - out = exec_run( - f"{examples_path}/training.py", - compose_param_list(dict(p1=5, p2='"aaa"')), - "test_main_run_basic", - ) - print(out) - assert out.find("state: completed") != -1, out - - -def test_main_run_wait_for_completion(): - """ - Test that the run command waits for the run to complete before returning - (mainly sanity as this is expected when 
running local function) - """ - path = str(pathlib.Path(__file__).absolute().parent / "assets" / "sleep.py") - time_to_sleep = 10 - start_time = datetime.datetime.now() - out = exec_run( - path, - compose_param_list(dict(time_to_sleep=time_to_sleep)) - + ["--handler", "handler"], - "test_main_run_wait_for_completion", - ) - end_time = datetime.datetime.now() - print(out) - assert out.find("state: completed") != -1, out - assert ( - end_time - start_time - ).seconds >= time_to_sleep, "run did not wait for completion" - - -def test_main_run_hyper(): - out = exec_run( - f"{examples_path}/training.py", - compose_param_list(dict(p2=[4, 5, 6]), "-x"), - "test_main_run_hyper", - ) - print(out) - assert out.find("state: completed") != -1, out - assert out.find("iterations:") != -1, out - - -def test_main_run_args(): - out = exec_run( - f"{tests_root_directory}/no_ctx.py -x " + "{p2}", - ["--uid", "123457"] + compose_param_list(dict(p1=5, p2="aaa")), - "test_main_run_args", - ) - print(out) - assert out.find("state: completed") != -1, out - db = mlrun.get_run_db() - state, log = db.get_log("123457") - print(log) - assert str(log).find(", -x, aaa") != -1, "params not detected in argv" - - -def test_main_run_args_with_url_placeholder_missing_env(): - args = [ - "--name", - "test_main_run_args_with_url_placeholder_missing_env", - "--dump", - "*", - "--arg1", - "value1", - "--arg2", - "value2", - ] - out = exec_main( - "run", - args, - raise_on_error=False, - ) - out_stdout = out.stdout.decode("utf-8") - print(out) - assert ( - out_stdout.find( - "command/url '*' placeholder is not allowed when code is not from env" - ) - != -1 - ), out - - -def test_main_run_args_with_url_placeholder_from_env(): - os.environ["MLRUN_EXEC_CODE"] = b64encode(code.encode("utf-8")).decode("utf-8") - args = [ - "--name", - "test_main_run_args_with_url_placeholder_from_env", - "--uid", - "123456789", - "--from-env", - "--dump", - "*", - "--arg1", - "value1", - "--arg2", - "value2", - ] - exec_main( - 
"run", - args, - raise_on_error=True, - ) - db = mlrun.get_run_db() - _run = db.read_run("123456789") - print(_run) - assert _run["status"]["results"]["my_args"] == [ - "main.py", - "--arg1", - "value1", - "--arg2", - "value2", - ] - assert _run["status"]["state"] == "completed" - - args = [ - "--name", - "test_main_run_args_with_url_placeholder_with_origin_file", - "--uid", - "987654321", - "--from-env", - "--dump", - "*", - "--origin-file", - "my_file.py", - "--arg3", - "value3", - ] - exec_main( - "run", - args, - raise_on_error=True, - ) - db = mlrun.get_run_db() - _run = db.read_run("987654321") - print(_run) - assert _run["status"]["results"]["my_args"] == ["my_file.py", "--arg3", "value3"] - assert _run["status"]["state"] == "completed" - - -def test_main_with_url_placeholder(): - os.environ["MLRUN_EXEC_CODE"] = b64encode(code.encode("utf-8")).decode("utf-8") - args = [ - "--name", - "test_main_with_url_placeholder", - "--uid", - "123456789", - "--from-env", - "*", - ] - exec_main( - "run", - args, - raise_on_error=True, - ) - db = mlrun.get_run_db() - _run = db.read_run("123456789") - print(_run) - assert _run["status"]["results"]["my_args"] == ["main.py"] - assert _run["status"]["state"] == "completed" - - -@pytest.mark.parametrize( - "op,args,raise_on_error,expected_output", - [ - # bad flag before command - [ - "run", - [ - "--bad-flag", - "--name", - "test_main_run_basic", - "--dump", - f"{examples_path}/training.py", - ], - False, - "Error: Invalid value for '[URL]': URL (--bad-flag) cannot start with '-', " - "ensure the command options are typed correctly. Preferably use '--' to separate options and " - "arguments e.g. 'mlrun run --option1 --option2 -- [URL] [--arg1|arg1] [--arg2|arg2]'", - ], - # bad flag with no command - [ - "run", - ["--name", "test_main_run_basic", "--bad-flag"], - False, - "Error: Invalid value for '[URL]': URL (--bad-flag) cannot start with '-', " - "ensure the command options are typed correctly. 
Preferably use '--' to separate options and " - "arguments e.g. 'mlrun run --option1 --option2 -- [URL] [--arg1|arg1] [--arg2|arg2]'", - ], - # bad flag after -- separator - [ - "run", - ["--name", "test_main_run_basic", "--", "-notaflag"], - False, - "Error: Invalid value for '[URL]': URL (-notaflag) cannot start with '-', " - "ensure the command options are typed correctly. Preferably use '--' to separate options and " - "arguments e.g. 'mlrun run --option1 --option2 -- [URL] [--arg1|arg1] [--arg2|arg2]'", - ], - # correct command with -- separator - [ - "run", - [ - "--name", - "test_main_run_basic", - "--", - f"{examples_path}/training.py", - "--some-arg", - ], - True, - "status=completed", - ], - ], -) -def test_main_run_args_validation(op, args, raise_on_error, expected_output): - out = exec_main( - op, - args, - raise_on_error=raise_on_error, - ) - if not raise_on_error: - out = out.stderr.decode("utf-8") - - assert out.find(expected_output) != -1, out - - -code = """ -import mlrun, sys -if __name__ == "__main__": - context = mlrun.get_or_create_ctx("test1") - context.log_result("my_args", sys.argv) - context.commit(completed=True) -""" - - -def test_main_run_args_from_env(): - os.environ["MLRUN_EXEC_CODE"] = b64encode(code.encode("utf-8")).decode("utf-8") - os.environ["MLRUN_EXEC_CONFIG"] = ( - '{"spec":{"parameters":{"x": "bbb"}},' - '"metadata":{"uid":"123459", "name":"tst", "labels": {"kind": "job"}}}' - ) - - out = exec_run( - "'main.py -x {x}'", - ["--from-env"], - "test_main_run_args_from_env", - ) - db = mlrun.get_run_db() - run = db.read_run("123459") - print(out) - assert run["status"]["state"] == "completed", out - assert run["status"]["results"]["my_args"] == [ - "main.py", - "-x", - "bbb", - ], "params not detected in argv" - - -nonpy_code = """ -echo "abc123" $1 -""" - - -@pytest.mark.skipif(sys.platform == "win32", reason="skip on windows") -def test_main_run_nonpy_from_env(): - os.environ["MLRUN_EXEC_CODE"] = 
b64encode(nonpy_code.encode("utf-8")).decode( - "utf-8" - ) - os.environ[ - "MLRUN_EXEC_CONFIG" - ] = '{"spec":{},"metadata":{"uid":"123411", "name":"tst", "labels": {"kind": "job"}}}' - - # --kfp flag will force the logs to print (for the assert) - out = exec_run( - "bash {codefile} xx", - ["--from-env", "--mode", "pass", "--kfp"], - "test_main_run_nonpy_from_env", - ) - db = mlrun.get_run_db() - run = db.read_run("123411") - assert run["status"]["state"] == "completed", out - state, log = db.get_log("123411") - print(state, log) - assert str(log).find("abc123 xx") != -1, "incorrect output" - - -def test_main_run_pass(): - out = exec_run( - "python -c print(56)", - ["--mode", "pass", "--uid", "123458"], - "test_main_run_pass", - ) - print(out) - assert out.find("state: completed") != -1, out - db = mlrun.get_run_db() - state, log = db.get_log("123458") - assert str(log).find("56") != -1, "incorrect output" - - -def test_main_run_pass_args(): - out = exec_run( - "'python -c print({x})'", - ["--mode", "pass", "--uid", "123451", "-p", "x=33"], - "test_main_run_pass", - ) - print(out) - assert out.find("state: completed") != -1, out - db = mlrun.get_run_db() - state, log = db.get_log("123451") - print(log) - assert str(log).find("33") != -1, "incorrect output" - - -def test_main_run_archive(): - args = f"--source {examples_path}/archive.zip --handler handler -p p1=1" - out = exec_run("./myfunc.py", args.split(), "test_main_run_archive") - assert out.find("state: completed") != -1, out - - -def test_main_local_source(): - args = f"--source {examples_path} --handler my_func" - with pytest.raises(Exception) as e: - exec_run("./handler.py", args.split(), "test_main_local_source") - assert ( - "source must be a compressed (tar.gz / zip) file, a git repo, a file path or in the project's context (.)" - in str(e.value) - ) - - -def test_main_run_archive_subdir(): - runtime = '{"spec":{"pythonpath":"./subdir"}}' - args = f"--source {examples_path}/archive.zip -r {runtime}" - 
out = exec_run("./subdir/func2.py", args.split(), "test_main_run_archive_subdir") - print(out) - assert out.find("state: completed") != -1, out - - -def test_main_local_project(): - project_path = str(pathlib.Path(__file__).parent / "assets") - args = "-f simple -p x=2 --dump" - out = exec_main("run", args.split(), cwd=project_path) - assert out.find("state: completed") != -1, out - assert out.find("y: 4") != -1, out # y = x * 2 - - -def test_main_local_flag(): - fn = mlrun.code_to_function( - filename=f"{examples_path}/handler.py", kind="job", handler="my_func" - ) - yaml_path = f"{out_path}/myfunc.yaml" - fn.export(yaml_path) - args = f"-f {yaml_path} --local" - out = exec_run("", args.split(), "test_main_local_flag") - print(out) - assert out.find("state: completed") != -1, out - - -def test_main_run_class(): - function_path = str(pathlib.Path(__file__).parent / "assets" / "handler.py") - - out = exec_run( - function_path, - compose_param_list(dict(x=8)) + ["--handler", "mycls::mtd"], - "test_main_run_class", - ) - assert out.find("state: completed") != -1, out - assert out.find("rx: 8") != -1, out - - -def test_run_from_module(): - args = ["--name", "test1", "--dump", "--handler", "json.dumps", "-p", "obj=[6,7]"] - out = exec_main("run", args) - assert out.find("state: completed") != -1, out - assert out.find("return: '[6, 7]'") != -1, out - - -def test_main_env_file(): - # test run with env vars loaded from a .env file - function_path = str(pathlib.Path(__file__).parent / "assets" / "handler.py") - envfile = str(pathlib.Path(__file__).parent / "assets" / "envfile") - - out = exec_run( - function_path, - ["--handler", "env_file_test", "--env-file", envfile], - "test_main_env_file", - ) - assert out.find("state: completed") != -1, out - assert out.find("ENV_ARG1: '123'") != -1, out - assert out.find("kfp_ttl: 12345") != -1, out diff --git a/tests/run/test_run.py b/tests/run/test_run.py index b1df0ea2ff96..0b227457297f 100644 --- a/tests/run/test_run.py +++ 
b/tests/run/test_run.py @@ -11,15 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import datetime +import contextlib +import io import pathlib +import sys from unittest.mock import MagicMock, Mock import pytest import mlrun import mlrun.errors -from mlrun import MLClientCtx, get_run_db, new_function, new_task +import mlrun.launcher.factory +from mlrun import new_function, new_task from tests.conftest import ( examples_path, has_secrets, @@ -43,7 +46,19 @@ assets_path = str(pathlib.Path(__file__).parent / "assets") -def test_noparams(): +@contextlib.contextmanager +def captured_output(): + new_out, new_err = io.StringIO(), io.StringIO() + old_out, old_err = sys.stdout, sys.stderr + try: + sys.stdout, sys.stderr = new_out, new_err + yield sys.stdout, sys.stderr + finally: + sys.stdout, sys.stderr = old_out, old_err + + +def test_noparams(db): + mlrun.get_or_create_project("default") # Since we're executing the function without inputs, it will try to use the input name as the file path result = new_function().run( params={"input_name": str(input_file_path)}, handler=my_func @@ -60,7 +75,7 @@ def test_noparams(): def test_failed_schedule_not_creating_run(): function = new_function() # mock we're with remote api (only there schedule is relevant) - function._use_remote_api = Mock(return_value=True) + function._is_remote = True # mock failure in submit job (failed schedule) db = MagicMock() function.set_db_connection(db) @@ -87,7 +102,8 @@ def test_invalid_name(): ) -def test_with_params(): +def test_with_params(db): + mlrun.get_or_create_project("default") spec = tag_test(base_spec, "test_with_params") result = new_function().run(spec, handler=my_func) @@ -121,13 +137,47 @@ def test_local_runtime(): verify_state(result) -def test_local_runtime_failure_before_executing_the_function_code(): +def 
test_local_runtime_failure_before_executing_the_function_code(db): function = new_function(command=f"{assets_path}/fail.py") with pytest.raises(mlrun.runtimes.utils.RunError) as exc: function.run(local=True, handler="handler") assert "failed on pre-loading" in str(exc.value) +@pytest.mark.parametrize( + "handler_name,params,kwargs,expected_kwargs", + [ + ("func", {"x": 2}, {"y": 3, "z": 4}, {"y": 3, "z": 4}), + ("func", {"x": 2}, {}, {}), + ("func_with_default", {}, {"y": 3, "z": 4}, {"y": 3, "z": 4}), + ], +) +def test_local_runtime_with_kwargs( + rundb_mock, handler_name, params, kwargs, expected_kwargs +): + params.update(kwargs) + function = new_function(command=f"{assets_path}/kwargs.py") + result = function.run(local=True, params=params, handler=handler_name) + verify_state(result) + assert result.outputs.get("return", {}) == expected_kwargs + + +def test_local_runtime_with_kwargs_with_code_to_function(db): + mlrun.get_or_create_project("default") + function = mlrun.code_to_function( + "kwarg", + filename=f"{assets_path}/kwargs.py", + image="mlrun/mlrun", + kind="job", + handler="func", + ) + kwargs = {"y": 3, "z": 4} + params = {"x": 2} + params.update(kwargs) + result = function.run(local=True, params=params) + assert result.outputs["return"] == kwargs + + def test_local_runtime_hyper(): spec = tag_test(base_spec, "test_local_runtime_hyper") spec.with_hyper_params({"p1": [1, 5, 3]}, selector="max.accuracy") @@ -188,157 +238,20 @@ def test_is_watchable(rundb_mock, kind, watch, expected_watch_count): assert mlrun.RunObject.logs.call_count == expected_watch_count -def test_local_args(): +@pytest.mark.asyncio +async def test_local_args(db, db_session): spec = tag_test(base_spec, "test_local_no_context") spec.spec.parameters = {"xyz": "789"} - result = new_function( - command=f"{tests_root_directory}/no_ctx.py --xyz {{xyz}}" - ).run(spec) - verify_state(result) - - db = get_run_db() - state, log = db.get_log(result.metadata.uid) - log = str(log) - print(state) - 
print(log) - assert log.find(", --xyz, 789") != -1, "params not detected in argv" - - -def test_local_context(rundb_mock): - project_name = "xtst" - mlrun.mlconf.artifact_path = out_path - context = mlrun.get_or_create_ctx("xx", project=project_name, upload_artifacts=True) - db = mlrun.get_run_db() - run = db.read_run(context._uid, project=project_name) - assert run["struct"]["status"]["state"] == "running", "run status not updated in db" - - with context: - context.log_artifact("xx", body="123", local_path="a.txt") - context.log_model("mdl", body="456", model_file="mdl.pkl", artifact_path="+/mm") - context.get_param("p1", 1) - context.get_param("p2", "a string") - context.log_result("accuracy", 16) - context.set_label("label-key", "label-value") - context.set_annotation("annotation-key", "annotation-value") - context._set_input("input-key", "input-url") - - artifact = context.get_cached_artifact("xx") - artifact.format = "z" - context.update_artifact(artifact) - - assert context._state == "completed", "task did not complete" - - run = db.read_run(context._uid, project=project_name) - run = run["struct"] - - # run state should not be updated by the context - assert run["status"]["state"] == "running", "run status was updated in db" - assert ( - run["status"]["artifacts"][0]["metadata"]["key"] == "xx" - ), "artifact not updated in db" - assert ( - run["status"]["artifacts"][0]["spec"]["format"] == "z" - ), "run/artifact attribute not updated in db" - assert run["status"]["artifacts"][1]["spec"]["target_path"].startswith( - out_path - ), "artifact not uploaded to subpath" - - db_artifact = db.read_artifact(artifact.db_key, project=project_name) - assert db_artifact["spec"]["format"] == "z", "artifact attribute not updated in db" - - assert run["spec"]["parameters"]["p1"] == 1, "param not updated in db" - assert run["spec"]["parameters"]["p2"] == "a string", "param not updated in db" - assert run["status"]["results"]["accuracy"] == 16, "result not updated in db" - 
assert run["metadata"]["labels"]["label-key"] == "label-value", "label not updated" - assert ( - run["metadata"]["annotations"]["annotation-key"] == "annotation-value" - ), "annotation not updated" - - assert run["spec"]["inputs"]["input-key"] == "input-url", "input not updated" - - -def test_context_from_dict_when_start_time_is_string(): - context = mlrun.get_or_create_ctx("ctx") - context_dict = context.to_dict() - context = mlrun.MLClientCtx.from_dict(context_dict) - assert isinstance(context._start_time, datetime.datetime) - - -def test_context_from_run_dict(): - run_dict = { - "metadata": { - "name": "test-context-from-run-dict", - "project": "default", - "labels": {"label-key": "label-value"}, - "annotations": {"annotation-key": "annotation-value"}, - }, - "spec": { - "parameters": {"p1": 1, "p2": "a string"}, - "inputs": {"input-key": "input-url"}, - }, - } - runtime = mlrun.runtimes.base.BaseRuntime.from_dict(run_dict) - run = runtime._create_run_object(run_dict) - handler = "my_func" - out_path = "test_artifact_path" - run = runtime._enrich_run( - run, - handler, - run_dict["metadata"]["project"], - run_dict["metadata"]["name"], - run_dict["spec"]["parameters"], - run_dict["spec"]["inputs"], - returns="", - hyperparams=None, - hyper_param_options=None, - verbose=False, - scrape_metrics=None, - out_path=out_path, - artifact_path="", - workdir="", - ) - context = MLClientCtx.from_dict(run.to_dict()) - assert context.name == run_dict["metadata"]["name"] - assert context._project == run_dict["metadata"]["project"] - assert context._labels == run_dict["metadata"]["labels"] - assert context._annotations == run_dict["metadata"]["annotations"] - assert context.get_param("p1") == run_dict["spec"]["parameters"]["p1"] - assert context.get_param("p2") == run_dict["spec"]["parameters"]["p2"] - assert ( - context.get_input("input-key").artifact_url - == run_dict["spec"]["inputs"]["input-key"] - ) - assert context.labels["label-key"] == 
run_dict["metadata"]["labels"]["label-key"] - assert ( - context.annotations["annotation-key"] - == run_dict["metadata"]["annotations"]["annotation-key"] - ) - assert context.artifact_path == out_path + function = new_function(command=f"{tests_root_directory}/no_ctx.py --xyz {{xyz}}") + with captured_output() as (out, err): + result = function.run(spec) -@pytest.mark.parametrize( - "state, error, expected_state", - [ - ("running", None, "completed"), - ("completed", None, "completed"), - (None, "error message", "error"), - (None, "", "error"), - ], -) -def test_context_set_state(rundb_mock, state, error, expected_state): - project_name = "test_context_error" - mlrun.mlconf.artifact_path = out_path - context = mlrun.get_or_create_ctx("xx", project=project_name, upload_artifacts=True) - db = mlrun.get_run_db() - run = db.read_run(context._uid, project=project_name) - assert run["struct"]["status"]["state"] == "running", "run status not updated in db" + output = out.getvalue().strip() - with context: - context.set_state(execution_state=state, error=error, commit=False) - context.commit(completed=True) + verify_state(result) - assert context._state == expected_state, "task state was not set correctly" - assert context._error == error, "task error was not set" + assert output.find(", --xyz, 789") != -1, "params not detected in argv" def test_run_class_code(): @@ -372,15 +285,15 @@ def test_run_from_module(): def test_args_integrity(): spec = tag_test(base_spec, "test_local_no_context") spec.spec.parameters = {"xyz": "789"} - result = new_function( + function = new_function( command=f"{tests_root_directory}/no_ctx.py", args=["It's", "a", "nice", "day!"], - ).run(spec) + ) + + with captured_output() as (out, err): + result = function.run(spec) + + output = out.getvalue().strip() verify_state(result) - db = get_run_db() - state, log = db.get_log(result.metadata.uid) - log = str(log) - print(state) - print(log) - assert log.find("It's, a, nice, day!") != -1, "params not 
detected in argv" + assert output.find("It's, a, nice, day!") != -1, "params not detected in argv" diff --git a/tests/rundb/test_dbs.py b/tests/rundb/test_dbs.py index da37918e26ae..1a8d7542f002 100644 --- a/tests/rundb/test_dbs.py +++ b/tests/rundb/test_dbs.py @@ -23,13 +23,12 @@ from mlrun.api.initial_data import init_data from mlrun.api.utils.singletons.db import initialize_db from mlrun.config import config -from mlrun.db import SQLDB, FileRunDB, sqldb +from mlrun.db import SQLDB, sqldb from mlrun.db.base import RunDBInterface from tests.conftest import new_run, run_now dbs = [ "sql", - "file", # TODO: 'httpdb', ] @@ -42,13 +41,11 @@ def db(request): db_file = f"{path}/mlrun.db" dsn = f"sqlite:///{db_file}?check_same_thread=false" config.httpdb.dsn = dsn - _init_engine(dsn) + _init_engine(dsn=dsn) init_data() initialize_db() db_session = create_session() db = SQLDB(dsn, session=db_session) - elif request.param == "file": - db = FileRunDB(path) else: assert False, f"unknown db type - {request.param}" @@ -139,8 +136,6 @@ def test_artifacts(db: RunDBInterface): def test_list_runs(db: RunDBInterface): - if isinstance(db, FileRunDB): - pytest.skip("FIXME") uid = "u183" run = new_run("s1", {"l1": "v1", "l2": "v2"}, uid, x=1) count = 5 diff --git a/tests/rundb/test_filedb.py b/tests/rundb/test_filedb.py deleted file mode 100644 index 523df8f51a47..000000000000 --- a/tests/rundb/test_filedb.py +++ /dev/null @@ -1,91 +0,0 @@ -# Copyright 2018 Iguazio -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from datetime import datetime, timedelta, timezone -from tempfile import mkdtemp - -import pytest - -from mlrun.db import FileRunDB - - -@pytest.fixture -def db(): - path = mkdtemp(prefix="mlrun-test") - db = FileRunDB(dirpath=path) - db.connect() - return db - - -def test_save_get_function(db: FileRunDB): - func, name, proj = {"x": 1, "y": 2}, "f1", "p2" - db.store_function(func, name, proj) - db_func = db.get_function(name, proj) - - # db methods enriches metadata - del db_func["metadata"] - del func["metadata"] - assert db_func == func, "wrong func" - - -def test_list_functions(db: FileRunDB): - proj = "p4" - count = 5 - for i in range(count): - name = f"func{i}" - func = {"fid": i} - db.store_function(func, name, proj) - db.store_function({}, "f2", "p7") - - out = db.list_functions("", proj) - assert len(out) == count, "bad list" - - -def test_schedules(db: FileRunDB): - count = 7 - for i in range(count): - data = {"i": i} - db.store_schedule(data) - - scheds = list(db.list_schedules()) - assert count == len(scheds), "wrong number of schedules" - assert set(range(count)) == set(s["i"] for s in scheds), "bad scheds" - - -def test_list_artifact_date(db: FileRunDB): - print("dirpath: ", db.dirpath) - t1 = datetime(2020, 2, 16, tzinfo=timezone.utc) - t2 = t1 - timedelta(days=7) - t3 = t2 - timedelta(days=7) - prj = "p7" - - db.store_artifact("k1", {"updated": t1.isoformat()}, "u1", project=prj) - db.store_artifact("k2", {"updated": t2.isoformat()}, "u2", project=prj) - db.store_artifact("k3", {"updated": t3.isoformat()}, "u3", project=prj) - - # FIXME: We get double what we expect since latest is an alias - arts = db.list_artifacts(project=prj, since=t3, tag="*") - assert 6 == len(arts), "since t3" - - arts = db.list_artifacts(project=prj, since=t2, tag="*") - assert 4 == len(arts), "since t2" - - arts = db.list_artifacts(project=prj, since=t1 + 
timedelta(days=1), tag="*") - assert not arts, "since t1+" - - arts = db.list_artifacts(project=prj, until=t2, tag="*") - assert 4 == len(arts), "until t2" - - arts = db.list_artifacts(project=prj, since=t2, until=t2, tag="*") - assert 2 == len(arts), "since/until t2" diff --git a/tests/rundb/test_httpdb.py b/tests/rundb/test_httpdb.py index e2840d88e456..0ba78c6b54f6 100644 --- a/tests/rundb/test_httpdb.py +++ b/tests/rundb/test_httpdb.py @@ -14,6 +14,8 @@ import codecs import io +import sys +import time import unittest.mock from collections import namedtuple from os import environ @@ -30,10 +32,10 @@ import requests_mock import mlrun.artifacts.base +import mlrun.common.schemas import mlrun.errors import mlrun.projects.project from mlrun import RunObject -from mlrun.api import schemas from mlrun.db.httpdb import HTTPRunDB from tests.conftest import tests_root_directory, wait_for_server @@ -51,7 +53,7 @@ def free_port(): def check_server_up(url): health_url = f"{url}/{HTTPRunDB.get_api_path_prefix()}/healthz" - timeout = 30 + timeout = 90 if not wait_for_server(health_url, timeout): raise RuntimeError(f"server did not start after {timeout} sec") @@ -201,6 +203,27 @@ def test_log(create_server): assert data == body, "bad log data" +@pytest.mark.skipif( + sys.platform == "darwin", + reason="We are developing on Apple Silicon Macs," + " which will most likely fail this test due to the qemu being slow," + " but should pass on native architecture", +) +def test_api_boot_speed(create_server): + run_times = 5 + expected_time = 30 + runs = [] + for _ in range(run_times): + start_time = time.perf_counter() + create_server() + end_time = time.perf_counter() + runs.append(end_time - start_time) + avg_run_time = sum(runs) / run_times + assert ( + avg_run_time <= expected_time + ), "Seems like a performance hit on creating api server" + + def test_run(create_server): server: Server = create_server() db = server.conn @@ -633,7 +656,7 @@ def test_feature_vectors(create_server): 
feature_vector_update, project, tag="latest", - patch_mode=schemas.PatchMode.additive, + patch_mode=mlrun.common.schemas.PatchMode.additive, ) feature_vectors = db.list_feature_vectors(project=project) assert len(feature_vectors) == count, "bad list results - wrong number of members" @@ -662,7 +685,10 @@ def test_feature_vectors(create_server): # Perform a replace (vs. additive as done earlier) - now should only have 2 features db.patch_feature_vector( - name, feature_vector_update, project, patch_mode=schemas.PatchMode.replace + name, + feature_vector_update, + project, + patch_mode=mlrun.common.schemas.PatchMode.replace, ) feature_vector = db.get_feature_vector(name, project) assert ( @@ -677,7 +703,7 @@ def test_project_file_db_roundtrip(create_server): project_name = "project-name" description = "project description" goals = "project goals" - desired_state = mlrun.api.schemas.ProjectState.archived + desired_state = mlrun.common.schemas.ProjectState.archived params = {"param_key": "param value"} artifact_path = "/tmp" conda = "conda" diff --git a/tests/rundb/test_nopdb.py b/tests/rundb/test_nopdb.py new file mode 100644 index 000000000000..c8dc848e52b0 --- /dev/null +++ b/tests/rundb/test_nopdb.py @@ -0,0 +1,44 @@ +# Copyright 2022 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import pytest + +import mlrun + + +def test_nopdb(): + # by default we use a nopdb with raise_error = False + assert mlrun.mlconf.httpdb.nop_db.raise_error is False + + rundb = mlrun.get_run_db() + assert isinstance(rundb, mlrun.db.NopDB) + + # not expected to fail as it in the white list + rundb.connect() + + # not expected to fail + rundb.read_run("123") + + # set raise_error to True + mlrun.mlconf.httpdb.nop_db.raise_error = True + + assert mlrun.mlconf.httpdb.nop_db.raise_error is True + + # not expected to fail as it in the white list + rundb.connect() + + # expected to fail + with pytest.raises(mlrun.errors.MLRunBadRequestError): + rundb.read_run("123") diff --git a/tests/rundb/test_sqldb.py b/tests/rundb/test_sqldb.py index 58a3b183cfcb..8535bcbee0ed 100644 --- a/tests/rundb/test_sqldb.py +++ b/tests/rundb/test_sqldb.py @@ -19,7 +19,7 @@ import deepdiff from sqlalchemy.orm import Session -import mlrun.api.schemas +import mlrun.common.schemas from mlrun.api.db.sqldb.db import SQLDB from mlrun.api.db.sqldb.models import Artifact from mlrun.lists import ArtifactList @@ -59,12 +59,12 @@ def test_list_artifact_tags(db: SQLDB, db_session: Session): # filter by category model_tags = db.list_artifact_tags( - db_session, "p1", mlrun.api.schemas.ArtifactCategories.model + db_session, "p1", mlrun.common.schemas.ArtifactCategories.model ) assert [("p1", "k2", "t3"), ("p1", "k2", "latest")] == model_tags model_tags = db.list_artifact_tags( - db_session, "p2", mlrun.api.schemas.ArtifactCategories.dataset + db_session, "p2", mlrun.common.schemas.ArtifactCategories.dataset ) assert [("p2", "k3", "t4"), ("p2", "k3", "latest")] == model_tags @@ -200,10 +200,12 @@ def test_read_and_list_artifacts_with_tags(db: SQLDB, db_session: Session): def test_projects_crud(db: SQLDB, db_session: Session): - project = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name="p1"), - spec=mlrun.api.schemas.ProjectSpec(description="banana", other_field="value"), - 
status=mlrun.api.schemas.ObjectStatus(state="active"), + project = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name="p1"), + spec=mlrun.common.schemas.ProjectSpec( + description="banana", other_field="value" + ), + status=mlrun.common.schemas.ObjectStatus(state="active"), ) db.create_project(db_session, project) project_output = db.get_project(db_session, name=project.metadata.name) @@ -221,12 +223,12 @@ def test_projects_crud(db: SQLDB, db_session: Session): project_output = db.get_project(db_session, name=project.metadata.name) assert project_output.spec.description == project_patch["spec"]["description"] - project_2 = mlrun.api.schemas.Project( - metadata=mlrun.api.schemas.ProjectMetadata(name="p2"), + project_2 = mlrun.common.schemas.Project( + metadata=mlrun.common.schemas.ProjectMetadata(name="p2"), ) db.create_project(db_session, project_2) projects_output = db.list_projects( - db_session, format_=mlrun.api.schemas.ProjectsFormat.name_only + db_session, format_=mlrun.common.schemas.ProjectsFormat.name_only ) assert [project.metadata.name, project_2.metadata.name] == projects_output.projects diff --git a/tests/runtimes/arc.txt b/tests/runtimes/arc.txt index 81f3f251f205..750176621555 100644 --- a/tests/runtimes/arc.txt +++ b/tests/runtimes/arc.txt @@ -2,7 +2,7 @@ def arc_to_parquet( context: MLClientCtx, - archive_url: Union[str, DataItem], + archive_url: Optional[Union[str, DataItem]], header: Optional[List[str]] = None, chunksize: int = 10_000, dtype=None, @@ -10,6 +10,9 @@ def arc_to_parquet( key: str = "data", dataset: Optional[str] = None, part_cols = [], + str_list: List[str] = [], + full_import: mlrun.run.RunObject = [], + full_import_with_slice: typing.Union[typing.List[str], mlrun.run.RunObject] = [], ) -> None: """Open a file/object archive and save as a parquet file. Partitioning requires precise specification of column types. 
diff --git a/tests/runtimes/assets/verbose_stderr_return_code_0.py b/tests/runtimes/assets/verbose_stderr_return_code_0.py new file mode 100644 index 000000000000..2d88e2b40ab9 --- /dev/null +++ b/tests/runtimes/assets/verbose_stderr_return_code_0.py @@ -0,0 +1,22 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +print("some output") + +for i in range(10000): + print("123456789", file=sys.stderr) + +sys.exit(0) diff --git a/tests/runtimes/assets/verbose_stderr_return_code_1.py b/tests/runtimes/assets/verbose_stderr_return_code_1.py new file mode 100644 index 000000000000..46d58a43bad5 --- /dev/null +++ b/tests/runtimes/assets/verbose_stderr_return_code_1.py @@ -0,0 +1,22 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import sys + +print("some output") + +for i in range(10000): + print("123456789", file=sys.stderr) + +sys.exit(1) diff --git a/tests/runtimes/info_cases.yml b/tests/runtimes/info_cases.yml index 04505b4191a6..e4e08b0822bb 100644 --- a/tests/runtimes/info_cases.yml +++ b/tests/runtimes/info_cases.yml @@ -30,6 +30,8 @@ doc: "" default: "" lineno: 1 + has_varargs: false + has_kwargs: false id: inc_ann - code: | def inc(n): @@ -48,6 +50,8 @@ doc: "" default: "" lineno: 1 + has_varargs: false + has_kwargs: false id: inc_no_ann - code: | def inc(n: int) -> int: @@ -72,6 +76,8 @@ doc: number to increment default: "" lineno: 1 + has_varargs: false + has_kwargs: false id: inc_ann_doc - code: | def inc(n: int, delta: int = 1) -> int: @@ -95,6 +101,8 @@ doc: "" default: "1" lineno: 1 + has_varargs: false + has_kwargs: false id: inc_ann_default - code: | def open_archive(context, @@ -129,4 +137,6 @@ doc: "source archive path/url" default: "''" lineno: 1 + has_varargs: false + has_kwargs: false id: undocumented param diff --git a/tests/runtimes/test_base.py b/tests/runtimes/test_base.py index 6a770b00f123..ee8ff07c2bc9 100644 --- a/tests/runtimes/test_base.py +++ b/tests/runtimes/test_base.py @@ -82,47 +82,107 @@ def test_auto_mount_v3io(self, cred_only, rundb_mock): "requirements,encoded_requirements", [ # strip spaces - (["pandas==1.0.0", "numpy==1.0.0 "], "pandas==1.0.0 numpy==1.0.0"), + (["pandas==1.0.0", "numpy==1.0.0 "], ["pandas==1.0.0", "numpy==1.0.0"]), # handle ranges - (["pandas>=1.0.0, <2"], "'pandas>=1.0.0, <2'"), - (["pandas>=1.0.0,<2"], "'pandas>=1.0.0,<2'"), + (["pandas>=1.0.0, <2"], ["pandas>=1.0.0, <2"]), + (["pandas>=1.0.0,<2"], ["pandas>=1.0.0,<2"]), # handle flags - (["-r somewhere/requirements.txt"], "-r somewhere/requirements.txt"), + (["-r somewhere/requirements.txt"], ["-r somewhere/requirements.txt"]), # handle flags and specific # handle escaping within specific ( ["-r somewhere/requirements.txt", "pandas>=1.0.0, <2"], - "-r 
somewhere/requirements.txt 'pandas>=1.0.0, <2'", + ["-r somewhere/requirements.txt", "pandas>=1.0.0, <2"], ), # handle from git ( ["something @ git+https://somewhere.com/a/b.git@v0.0.0#egg=something"], - "'something @ git+https://somewhere.com/a/b.git@v0.0.0#egg=something'", + ["something @ git+https://somewhere.com/a/b.git@v0.0.0#egg=something"], ), # handle comments - (["# dont care", "faker"], "faker"), - (["faker # inline dontcare"], "faker"), - (["faker #inline dontcare2"], "faker"), + (["# dont care", "faker"], ["faker"]), + (["faker # inline dontcare"], ["faker"]), + (["faker #inline dontcare2"], ["faker"]), + ( + [ + "numpy==1.0.0 ", + "pandas>=1.0.0, <2", + "# dont care", + "pandas2>=1.0.0,<2 # just an inline comment", + "-r somewhere/requirements.txt", + "something @ git+https://somewhere.com/a/b.git@v0.0.0#egg=something", + ], + [ + "numpy==1.0.0", + "pandas>=1.0.0, <2", + "pandas2>=1.0.0,<2", + "-r somewhere/requirements.txt", + "something @ git+https://somewhere.com/a/b.git@v0.0.0#egg=something", + ], + ), ], ) - def test_encode_requirements(self, requirements, encoded_requirements): + def test_resolve_requirements(self, requirements, encoded_requirements): for requirements_as_file in [True, False]: if requirements_as_file: # create a temporary file with the requirements - with tempfile.NamedTemporaryFile( - delete=False, dir=self._temp_dir - ) as temp_file: - with open(temp_file.name, "w") as f: - for requirement in requirements: - f.write(requirement + "\n") - requirements = temp_file.name - - encoded = self._generate_runtime()._encode_requirements(requirements) + requirements = self._create_temp_requirements_file(requirements) + + encoded = self._generate_runtime().spec.build._resolve_requirements( + requirements + ) assert ( encoded == encoded_requirements ), f"Failed to encode {requirements} as file {requirements_as_file}" + @pytest.mark.parametrize( + "requirements,requirements_in_file,encoded_requirements,requirements_as_file", + [ + ( + 
["pandas==1.0.0", "numpy==1.0.0"], + ["something==1.0.0", "otherthing==1.0.0"], + [ + "something==1.0.0", + "otherthing==1.0.0", + "pandas==1.0.0", + "numpy==1.0.0", + ], + False, + ), + ( + ["pandas==1.0.0", "numpy==1.0.0"], + ["something==1.0.0", "otherthing==1.0.0"], + [ + "something==1.0.0", + "otherthing==1.0.0", + "pandas==1.0.0", + "numpy==1.0.0", + ], + True, + ), + ], + ) + def test_resolve_requirements_file( + self, + requirements, + requirements_in_file, + encoded_requirements, + requirements_as_file, + ): + # create requirements file + requirements_file = self._create_temp_requirements_file(requirements_in_file) + + if requirements_as_file: + requirements = self._create_temp_requirements_file(requirements) + + encoded = self._generate_runtime().spec.build._resolve_requirements( + requirements, requirements_file + ) + assert ( + encoded == encoded_requirements + ), f"Failed to encode {requirements.extend(requirements_in_file)} as file {requirements_file}" + def test_fill_credentials(self, rundb_mock): """ expects to set the generate access key so that the API will enrich with the auth session that is being passed @@ -273,3 +333,12 @@ def test_auto_mount_env(self, rundb_mock): rundb_mock.reset() self._execute_run(runtime) rundb_mock.assert_env_variables(expected_env) + + def _create_temp_requirements_file(self, requirements): + with tempfile.NamedTemporaryFile( + delete=False, dir=self._temp_dir, suffix=".txt" + ) as temp_file: + with open(temp_file.name, "w") as f: + for requirement in requirements: + f.write(requirement + "\n") + return temp_file.name diff --git a/tests/runtimes/test_funcdoc.py b/tests/runtimes/test_funcdoc.py index 05b72facd26f..ca04938ccf7c 100644 --- a/tests/runtimes/test_funcdoc.py +++ b/tests/runtimes/test_funcdoc.py @@ -101,13 +101,15 @@ def inc(n): "return": funcdoc.param_dict(), "params": [funcdoc.param_dict("n")], "lineno": 6, + "has_varargs": False, + "has_kwargs": False, }, ] def test_find_handlers(): funcs = 
funcdoc.find_handlers(find_handlers_code) - assert find_handlers_expected == funcs + assert funcs == find_handlers_expected ast_code_cases = [ @@ -139,10 +141,64 @@ def test_ast_none(): def fn() -> None: pass """ - fn = ast.parse(dedent(code)).body[0] + fn: ast.FunctionDef = ast.parse(dedent(code)).body[0] funcdoc.ast_func_info(fn) +@pytest.mark.parametrize( + "func_code,expected_has_varargs,expected_has_kwargs", + [ + ( + """ + def fn(p1,p2,*args,**kwargs) -> None: + pass + """, + True, + True, + ), + ( + """ + def fn(p1,p2,*args) -> None: + pass + """, + True, + False, + ), + ( + """ + def fn(p1,p2,**kwargs) -> None: + pass + """, + False, + True, + ), + ( + """ + def fn(p1,p2) -> None: + pass + """, + False, + False, + ), + ( + """ + def fn(p1,p2,**something) -> None: + pass + """, + False, + True, + ), + ], +) +def test_ast_func_info_with_kwargs_and_args( + func_code, expected_has_varargs, expected_has_kwargs +): + fn: ast.FunctionDef = ast.parse(dedent(func_code)).body[0] + func_info = funcdoc.ast_func_info(fn) + assert func_info["has_varargs"] == expected_has_varargs + assert func_info["has_kwargs"] == expected_has_kwargs + + def test_ast_compound(): param_types = [] with open(f"{tests_root_directory}/runtimes/arc.txt") as fp: @@ -150,7 +206,7 @@ def test_ast_compound(): # collect the types of the function parameters # assumes each param is in a new line for simplicity - for line in code.splitlines()[3:12]: + for line in code.splitlines()[3:15]: if ":" not in line: param_types.append(None) continue diff --git a/tests/runtimes/test_function.py b/tests/runtimes/test_function.py index f7a5829e947c..24c73659792b 100644 --- a/tests/runtimes/test_function.py +++ b/tests/runtimes/test_function.py @@ -20,10 +20,6 @@ import mlrun from mlrun import code_to_function -from mlrun.runtimes.function import ( - _resolve_nuclio_runtime_python_image, - _resolve_work_dir_and_handler, -) from mlrun.utils.helpers import resolve_git_reference_from_source from 
tests.runtimes.test_base import TestAutoMount @@ -154,41 +150,6 @@ def test_v3io_stream_trigger(): assert trigger["attributes"]["ackWindowSize"] == 10 -def test_resolve_work_dir_and_handler(): - cases = [ - (None, ("", "main:handler")), - ("x", ("", "x:handler")), - ("x:y", ("", "x:y")), - ("dir#", ("dir", "main:handler")), - ("dir#x", ("dir", "x:handler")), - ("dir#x:y", ("dir", "x:y")), - ] - for handler, expected in cases: - assert expected == _resolve_work_dir_and_handler(handler) - - -@pytest.mark.parametrize( - "mlrun_client_version,python_version,expected_runtime", - [ - ("1.3.0", "3.9.16", "python:3.9"), - ("1.3.0", "3.7.16", "python:3.7"), - (None, None, "python:3.7"), - (None, "3.9.16", "python:3.7"), - ("1.3.0", None, "python:3.7"), - ("0.0.0-unstable", "3.9.16", "python:3.9"), - ("0.0.0-unstable", "3.7.16", "python:3.7"), - ("1.2.0", "3.9.16", "python:3.7"), - ("1.2.0", "3.7.16", "python:3.7"), - ], -) -def test_resolve_nuclio_runtime_python_image( - mlrun_client_version, python_version, expected_runtime -): - assert expected_runtime == _resolve_nuclio_runtime_python_image( - mlrun_client_version, python_version - ) - - def test_resolve_git_reference_from_source(): cases = [ # source, (repo, refs, branch) diff --git a/tests/runtimes/test_local.py b/tests/runtimes/test_local.py new file mode 100644 index 000000000000..902ea21fda39 --- /dev/null +++ b/tests/runtimes/test_local.py @@ -0,0 +1,38 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import pathlib + +import pytest + +from mlrun.runtimes.local import run_exec + + +def test_run_exec_basic(): + out, err = run_exec(["echo"], ["hello"]) + assert out == "hello\n" + assert err == "" + + +# ML-3710 +@pytest.mark.parametrize("return_code", [0, 1]) +def test_run_exec_verbose_stderr(return_code): + script_path = str( + pathlib.Path(__file__).parent + / "assets" + / f"verbose_stderr_return_code_{return_code}.py" + ) + out, err = run_exec(["python"], [script_path]) + assert out == "some output\n" + expected_err_length = 100000 if return_code else 0 + assert len(err) == expected_err_length diff --git a/tests/runtimes/test_logging_and_parsing.py b/tests/runtimes/test_logging_and_parsing.py deleted file mode 100644 index ee2e6ec172d7..000000000000 --- a/tests/runtimes/test_logging_and_parsing.py +++ /dev/null @@ -1,126 +0,0 @@ -# Copyright 2018 Iguazio -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -import tempfile -from typing import Tuple - -import numpy as np -import pandas as pd -from sklearn.impute import SimpleImputer -from sklearn.pipeline import Pipeline -from sklearn.preprocessing import OrdinalEncoder - -import mlrun - - -def log_artifacts_and_results() -> Tuple[ - np.ndarray, pd.DataFrame, dict, list, int, str, Pipeline -]: - encoder_to_imputer = Pipeline( - steps=[ - ( - "imputer", - SimpleImputer(missing_values="", strategy="constant", fill_value="C"), - ), - ("encoder", OrdinalEncoder()), - ] - ) - encoder_to_imputer.fit([["A"], ["B"], ["C"]]) - return ( - np.ones((10, 20)), - pd.DataFrame(np.zeros((20, 10))), - {"a": [1, 2, 3, 4], "b": [5, 6, 7, 8]}, - [["A"], ["B"], [""]], - 3, - "hello", - encoder_to_imputer, - ) - - -def parse_inputs(my_array, my_df, my_dict: dict, my_list, my_object, my_int, my_str): - assert isinstance(my_array, np.ndarray) - assert np.all(my_array == np.ones((10, 20))) - - assert isinstance(my_df, mlrun.DataItem) - my_df = my_df.as_df() - assert my_df.shape == (20, 10) - assert my_df.sum().sum() == 0 - - assert isinstance(my_dict, dict) - assert my_dict == {"a": {0: 1, 1: 2, 2: 3, 3: 4}, "b": {0: 5, 1: 6, 2: 7, 3: 8}} - - assert isinstance(my_list, list) - assert my_list == [["A"], ["B"], [""]] - - assert isinstance(my_object, Pipeline) - assert my_object.transform(my_list).tolist() == [[0], [1], [2]] - - return [my_str] * my_int - - -def test_parse_inputs_from_mlrun_function(): - """ - Run the `parse_inputs_from_mlrun_function` function with MLRun to see the wrapper is parsing the given inputs - (`DataItem`s) to the written configuration provided to the wrapper. 
- """ - # Create the function and run 2 of the previous functions to create a dataset and encoder objects: - mlrun_function = mlrun.code_to_function(filename=__file__, kind="job") - artifact_path = tempfile.TemporaryDirectory() - log_artifacts_and_results_run = mlrun_function.run( - handler="log_artifacts_and_results", - returns=[ - "my_array", - "my_df:dataset", - {"key": "my_dict", "artifact_type": "dataset"}, - "my_list: dataset", - "my_int", - "my_str : result", - "my_object: object", - ], - artifact_path=artifact_path.name, - local=True, - ) - - # Run the function that will parse the data items: - parse_inputs_run = mlrun_function.run( - handler="parse_inputs", - inputs={ - "my_list:list": log_artifacts_and_results_run.outputs["my_list"], - "my_array : numpy.ndarray": log_artifacts_and_results_run.outputs[ - "my_array" - ], - "my_df": log_artifacts_and_results_run.outputs["my_df"], - "my_object: sklearn.pipeline.Pipeline": log_artifacts_and_results_run.outputs[ - "my_object" - ], - "my_dict: dict": log_artifacts_and_results_run.outputs["my_dict"], - }, - returns=["result_list: result"], - params={ - "my_int": log_artifacts_and_results_run.outputs["my_int"], - "my_str": log_artifacts_and_results_run.outputs["my_str"], - }, - artifact_path=artifact_path.name, - local=True, - ) - - # Manual validation: - mlrun.utils.logger.info(parse_inputs_run.outputs) - - # Assertion: - assert len(parse_inputs_run.outputs) == 1 # result - assert parse_inputs_run.outputs["result_list"] == ["hello", "hello", "hello"] - - # Clean the test outputs: - artifact_path.cleanup() diff --git a/tests/runtimes/test_pod.py b/tests/runtimes/test_pod.py index 421089f9a93a..02016dd70047 100644 --- a/tests/runtimes/test_pod.py +++ b/tests/runtimes/test_pod.py @@ -205,3 +205,29 @@ def test_volume_mounts_addition(): sanitized_dict_volume_mount, ] assert len(function.spec.volume_mounts) == 1 + + +def test_build_config_with_multiple_commands(): + image = "mlrun/ml-models" + fn = mlrun.new_function( 
+ "some-function", "some-project", "some-tag", image=image, kind="job" + ) + fn.build_config(commands=["pip install pandas", "pip install numpy"]) + assert len(fn.spec.build.commands) == 2 + + fn.build_config(commands=["pip install pandas"]) + assert len(fn.spec.build.commands) == 2 + + +def test_build_config_preserve_order(): + function = mlrun.new_function("some-function", kind="job") + # run a lot of times as order change + commands = [] + for index in range(10): + commands.append(str(index)) + # when using un-stable (doesn't preserve order) methods to make a list unique (like list(set(x))) it's random + # whether the order will be preserved, therefore run in a loop + for _ in range(100): + function.spec.build.commands = [] + function.build_config(commands=commands) + assert function.spec.build.commands == commands diff --git a/tests/runtimes/test_run.py b/tests/runtimes/test_run.py index 86693594cb88..12b149950ed7 100644 --- a/tests/runtimes/test_run.py +++ b/tests/runtimes/test_run.py @@ -13,7 +13,9 @@ # limitations under the License. 
# import copy +import pathlib +import pytest from deepdiff import DeepDiff import mlrun @@ -44,12 +46,13 @@ def _get_runtime(): "volume_mounts": [], "env": [], "description": "", - "build": {"commands": []}, + "build": {"commands": [], "requirements": []}, "affinity": None, "disable_auto_mount": False, "priority_class_name": "", "tolerations": None, "security_context": None, + "clone_target_dir": "", }, "verbose": False, } @@ -311,3 +314,22 @@ def test_new_function_invalid_characters(): invalid_function_name = "invalid_name with_spaces" function = mlrun.new_function(name=invalid_function_name, runtime=runtime) assert function.metadata.name == "invalid-name-with-spaces" + + +def test_set_envs(): + assets_path = pathlib.Path(__file__).parent.parent / "assets" + env_path = str(assets_path / "envfile") + runtime = _get_runtime() + function = mlrun.new_function(runtime=runtime) + function.set_envs(file_path=env_path) + assert function.get_env("ENV_ARG1") == "123" + assert function.get_env("ENV_ARG2") == "abc" + + +def test_set_envs_file_not_find(): + runtime = _get_runtime() + function = mlrun.new_function(runtime=runtime) + file_name = ".env-test" + with pytest.raises(mlrun.errors.MLRunNotFoundError) as excinfo: + function.set_envs(file_path=file_name) + assert f"{file_name} does not exist" in str(excinfo.value) diff --git a/tests/serving/test_async_flow.py b/tests/serving/test_async_flow.py index 66f0b748f9da..09761fbfb017 100644 --- a/tests/serving/test_async_flow.py +++ b/tests/serving/test_async_flow.py @@ -74,7 +74,6 @@ def test_async_nested(): graph.add_step(name="final", class_name="Echo", after="ensemble").respond() - logger.info(graph.to_yaml()) server = function.to_mock_server() # plot the graph for test & debug @@ -89,22 +88,25 @@ def test_on_error(): function = mlrun.new_function("tests", kind="serving") graph = function.set_topology("flow", engine="async") chain = graph.to("Chain", name="s1") - chain.to("Raiser").error_handler("catch").to("Chain", 
name="s3") + chain.to("Raiser").error_handler( + name="catch", class_name="EchoError", full_event=True + ).to("Chain", name="s3") - graph.add_step( - name="catch", class_name="EchoError", after="" - ).respond().full_event = True function.verbose = True server = function.to_mock_server() - logger.info(graph.to_yaml()) # plot the graph for test & debug graph.plot(f"{results}/serving/on_error.png") resp = server.test(body=[]) server.wait_for_completion() - assert ( - resp["error"] and resp["origin_state"] == "Raiser" - ), f"error wasnt caught, resp={resp}" + if isinstance(resp, dict): + assert ( + resp["error"] and resp["origin_state"] == "Raiser" + ), f"error wasn't caught, resp={resp}" + else: + assert ( + resp.error and resp.origin_state == "Raiser" + ), f"error wasn't caught, resp={resp}" def test_push_error(): @@ -118,7 +120,6 @@ def test_push_error(): server.error_stream = "dummy:///nothing" # Force an error inside push_error itself server._error_stream_object = _DummyStreamRaiser() - logger.info(graph.to_yaml()) server.test(body=[]) server.wait_for_completion() diff --git a/tests/serving/test_flow.py b/tests/serving/test_flow.py index 94a65196eac0..3caa895d87db 100644 --- a/tests/serving/test_flow.py +++ b/tests/serving/test_flow.py @@ -19,7 +19,6 @@ import mlrun from mlrun.serving import GraphContext, V2ModelServer from mlrun.serving.states import TaskStep -from mlrun.utils import logger from .demo_states import * # noqa @@ -70,7 +69,6 @@ def test_basic_flow(): server = fn.to_mock_server() # graph.plot("flow.png") - print("\nFlow1:\n", graph.to_yaml()) resp = server.test(body=[]) assert resp == ["s1", "s2", "s3"], "flow1 result is incorrect" @@ -82,7 +80,6 @@ def test_basic_flow(): graph.add_step(name="s3", class_name="Chain", after="s2") server = fn.to_mock_server() - logger.info(f"flow: {graph.to_yaml()}") resp = server.test(body=[]) assert resp == ["s1", "s2", "s3"], "flow2 result is incorrect" @@ -92,7 +89,6 @@ def test_basic_flow(): 
graph.add_step(name="s2", class_name="Chain", after="s1", before="s3") server = fn.to_mock_server() - logger.info(f"flow: {graph.to_yaml()}") resp = server.test(body=[]) assert resp == ["s1", "s2", "s3"], "flow3 result is incorrect" assert server.context.project == "x", "context.project was not set" @@ -122,7 +118,7 @@ def test_handler_with_context(): ) server = fn.to_mock_server() resp = server.test(body=5) - # expext 5 * 2 * 2 * 2 = 40 + # expect 5 * 2 * 2 * 2 = 40 assert resp == 40, f"got unexpected result {resp}" @@ -141,15 +137,16 @@ def test_on_error(): graph = fn.set_topology("flow", engine="sync") graph.add_step(name="s1", class_name="Chain") graph.add_step(name="raiser", class_name="Raiser", after="$prev").error_handler( - "catch" + name="catch", class_name="EchoError", full_event=True ) graph.add_step(name="s3", class_name="Chain", after="$prev") - graph.add_step(name="catch", class_name="EchoError").full_event = True server = fn.to_mock_server() - logger.info(f"flow: {graph.to_yaml()}") resp = server.test(body=[]) - assert resp["error"] and resp["origin_state"] == "raiser", "error wasnt caught" + if isinstance(resp, dict): + assert resp["error"] and resp["origin_state"] == "raiser", "error wasn't caught" + else: + assert resp.error and resp.origin_state == "raiser", "error wasn't caught" def return_type(event): @@ -205,7 +202,6 @@ def test_add_model(): graph = fn.set_topology("flow", engine="sync") graph.to("Echo", "e1").to("*", "router").to("Echo", "e2") fn.add_model("m1", class_name="ModelTestingClass", model_path=".") - print(graph.to_yaml()) assert "m1" in graph["router"].routes, "model was not added to router" @@ -214,7 +210,6 @@ def test_add_model(): graph = fn.set_topology("flow", engine="sync") graph.to("Echo", "e1").to("*", "r1").to("Echo", "e2").to("*", "r2") fn.add_model("m1", class_name="ModelTestingClass", model_path=".", router_step="r2") - print(graph.to_yaml()) assert "m1" in graph["r2"].routes, "model was not added to proper router" @@ 
-273,7 +268,6 @@ def test_path_control_routers(): "*", name="r1", input_path="x", result_path="y" ).to(name="s3", class_name="Echo").respond() function.add_model("m1", class_name="ModelClass", model_path=".") - logger.info(graph.to_yaml()) server = function.to_mock_server() resp = server.test("/v2/models/m1/infer", body={"x": {"inputs": [5]}}) @@ -292,7 +286,6 @@ def test_path_control_routers(): ).to(name="s3", class_name="Echo").respond() function.add_model("m1", class_name="ModelClassList", model_path=".", multiplier=10) function.add_model("m2", class_name="ModelClassList", model_path=".", multiplier=20) - logger.info(graph.to_yaml()) server = function.to_mock_server() resp = server.test("/v2/models/infer", body={"x": {"inputs": [[5]]}}) diff --git a/tests/serving/test_remote.py b/tests/serving/test_remote.py index 06ffdd89681f..e8e8b804402b 100644 --- a/tests/serving/test_remote.py +++ b/tests/serving/test_remote.py @@ -70,18 +70,25 @@ def test_remote_step(httpserver, engine): {"post": "ok"} ) url = httpserver.url_for("/") - for params, request, expected in tests_map: - print(f"test params: {params}") + print(f"test params: {params}, request: {request}, expected: {expected}") server = _new_server(url, engine, **params) - resp = server.test(**request) - server.wait_for_completion() + try: + resp = server.test(**request) + except Exception as e: + raise e + finally: + server.wait_for_completion() assert resp == expected # test with url generated with expression (from the event) server = _new_server(None, engine, method="GET", url_expression="event['myurl']") - resp = server.test(body={"myurl": httpserver.url_for("/foo")}) - server.wait_for_completion() + try: + resp = server.test(body={"myurl": httpserver.url_for("/foo")}) + except Exception as e: + raise e + finally: + server.wait_for_completion() assert resp == {"foo": "ok"} @@ -106,8 +113,12 @@ def test_remote_step_bad_status_code(httpserver, engine): for params, request, expected in tests_map: print(f"test 
params: {params}") server = _new_server(url, engine, **params) - resp = server.test(**request) - server.wait_for_completion() + try: + resp = server.test(**request) + except Exception as e: + raise e + finally: + server.wait_for_completion() assert resp == expected # test with url generated with expression (from the event) @@ -136,8 +147,12 @@ def test_remote_class(httpserver, engine): ).to(name="s3", handler="echo").respond() server = function.to_mock_server() - resp = server.test(body={"req": {"x": 5}}) - server.wait_for_completion() + try: + resp = server.test(body={"req": {"x": 5}}) + except Exception as e: + raise e + finally: + server.wait_for_completion() assert resp == {"req": {"x": 5}, "resp": {"cat": "ok"}} @@ -225,8 +240,12 @@ def test_remote_advance(httpserver, engine): ).to(name="s3", handler="echo").respond() server = function.to_mock_server() - resp = server.test(body={"req": {"url": "/dog", "data": {"x": 5}}}) - server.wait_for_completion() + try: + resp = server.test(body={"req": {"url": "/dog", "data": {"x": 5}}}) + except Exception as e: + raise e + finally: + server.wait_for_completion() assert resp == {"req": {"url": "/dog", "data": {"x": 5}}, "resp": {"post": "ok"}} diff --git a/tests/serving/test_serving.py b/tests/serving/test_serving.py index 114988eb30e7..c1bef0ce8236 100644 --- a/tests/serving/test_serving.py +++ b/tests/serving/test_serving.py @@ -16,6 +16,7 @@ import os import pathlib import time +import unittest.mock import pandas as pd import pytest @@ -239,7 +240,6 @@ def test_ensemble_get_models(): ) graph.routes = generate_test_routes("EnsembleModelTestingClass") server = fn.to_mock_server() - logger.info(f"flow: {graph.to_yaml()}") resp = server.test("/v2/models/") # expected: {"models": ["m1", "m2", "m3:v1", "m3:v2", "VotingEnsemble"], # "weights": None} @@ -256,7 +256,6 @@ def test_ensemble_get_metadata_of_models(): ) graph.routes = generate_test_routes("EnsembleModelTestingClass") server = fn.to_mock_server() - 
logger.info(f"flow: {graph.to_yaml()}") resp = server.test("/v2/models/m1") expected = {"name": "m1", "version": "", "inputs": [], "outputs": []} assert resp == expected, f"wrong get models response {resp}" @@ -499,7 +498,7 @@ def test_v2_explain(): assert data["outputs"]["explained"] == 5, f"wrong explain response {resp.body}" -def test_v2_get_modelmeta(): +def test_v2_get_modelmeta(rundb_mock): project = mlrun.new_project("tstsrv", save=False) fn = mlrun.new_function("tst", kind="serving") model_uri = _log_model(project) @@ -556,6 +555,8 @@ def test_v2_model_ready(): event = MockEvent("", path="/v2/models/m1/ready", method="GET") resp = context.mlrun_handler(context, event) assert resp.status_code == 200, f"didnt get proper ready resp {resp.body}" + resp_body = resp.body.decode("utf-8") + assert resp_body == f"Model m1 is ready (event_id = {event.id})" def test_v2_health(): @@ -588,12 +589,11 @@ def test_v2_mock(): def test_function(): fn = mlrun.new_function("tests", kind="serving") - graph = fn.set_topology("router") + fn.set_topology("router") fn.add_model("my", ".", class_name=ModelTestingClass(multiplier=100)) fn.set_tracking("dummy://") # track using the _DummyStream server = fn.to_mock_server() - logger.info(f"flow: {graph.to_yaml()}") resp = server.test("/v2/models/my/infer", testdata) # expected: source (5) * multiplier (100) assert resp["outputs"] == 5 * 100, f"wrong data response {resp}" @@ -714,3 +714,30 @@ def test_mock_invoke(): # return config valued mlrun.mlconf.mock_nuclio_deployment = mock_nuclio_config + + +def test_deploy_with_dashboard_argument(): + fn = mlrun.new_function("tests", kind="serving") + fn.add_model("my", ".", class_name=ModelTestingClass(multiplier=100)) + db_instance = fn._get_db() + db_instance.remote_builder = unittest.mock.Mock( + return_value={ + "data": { + "metadata": { + "name": "test", + }, + "status": { + "state": "ready", + "external_invocation_urls": ["http://test-url.com"], + }, + }, + }, + ) + 
db_instance.get_builder_status = unittest.mock.Mock( + return_value=(None, None), + ) + + mlrun.deploy_function(fn, dashboard="bad-address") + + # test that the remote builder was called even with dashboard argument + assert db_instance.remote_builder.call_count == 1 diff --git a/tests/system/api/assets/function.py b/tests/system/api/assets/function.py index ede3e5d724dd..c56fa92f8062 100644 --- a/tests/system/api/assets/function.py +++ b/tests/system/api/assets/function.py @@ -12,15 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. # -def secret_test_function(context, secrets: list = []): + + +def secret_test_function(context, secrets: list = None): """Validate that given secrets exists :param context: the MLRun context :param secrets: name of the secrets that we want to look at """ context.logger.info("running function") + secrets = secrets or [] for sec_name in secrets: sec_value = context.get_secret(sec_name) context.logger.info("Secret: {} ==> {}".format(sec_name, sec_value)) context.log_result(sec_name, sec_value) return True + + +def log_artifact_test_function(context, body_size: int = 1000, inline: bool = True): + """Logs artifact given its event body + :param context: the MLRun context + :param body_size: size of the artifact body + :param inline: whether to log the artifact body inline or not + """ + context.logger.info("running function") + body = b"a" * body_size + context.log_artifact("test", body=body, is_inline=inline) + context.logger.info("run complete!", body_len=len(body)) + return True diff --git a/tests/system/api/test_artifacts.py b/tests/system/api/test_artifacts.py new file mode 100644 index 000000000000..09ccac7b686f --- /dev/null +++ b/tests/system/api/test_artifacts.py @@ -0,0 +1,60 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pathlib + +import pytest + +import mlrun.common.schemas +import mlrun.errors +from tests.system.base import TestMLRunSystem + + +@TestMLRunSystem.skip_test_if_env_not_configured +class TestAPIArtifacts(TestMLRunSystem): + project_name = "db-system-test-project" + + @pytest.mark.enterprise + def test_fail_overflowing_artifact(self): + """ + Test that we fail when trying to (inline) log an artifact that is too big + This is done to ensure that we don't corrupt the DB while truncating the data + """ + filename = str(pathlib.Path(__file__).parent / "assets" / "function.py") + function = mlrun.code_to_function( + name="test-func", + project=self.project_name, + filename=filename, + handler="log_artifact_test_function", + kind="job", + image="mlrun/mlrun", + ) + task = mlrun.new_task() + + # run artifact field is MEDIUMBLOB which is limited to 16MB by mysql + # overflow and expect it to fail execution and not allow db to truncate the data + # to avoid data corruption + with pytest.raises(mlrun.runtimes.utils.RunError): + function.run( + task, params={"body_size": 16 * 1024 * 1024 + 1, "inline": True} + ) + + runs = mlrun.get_run_db().list_runs() + assert len(runs) == 1, "run should not be created" + run = runs[0] + assert run["status"]["state"] == "error", "run should fail" + assert ( + "Failed committing changes to DB" in run["status"]["error"] + ), "run should fail with a reason" diff --git a/tests/system/api/test_secrets.py b/tests/system/api/test_secrets.py index 6220d821acbc..33f0a0fee026 100644 --- a/tests/system/api/test_secrets.py +++ 
b/tests/system/api/test_secrets.py @@ -18,7 +18,7 @@ import deepdiff import pytest -import mlrun.api.schemas +import mlrun.common.schemas import mlrun.errors from tests.system.base import TestMLRunSystem @@ -96,7 +96,7 @@ def test_k8s_project_secrets_using_api(self): def test_k8s_project_secrets_using_httpdb(self): secrets = {"secret1": "value1", "secret2": "value2"} - expected_results = mlrun.api.schemas.SecretKeysData( + expected_results = mlrun.common.schemas.SecretKeysData( provider="kubernetes", secret_keys=list(secrets.keys()) ) diff --git a/tests/system/base.py b/tests/system/base.py index c371902966a0..9aeef5a17076 100644 --- a/tests/system/base.py +++ b/tests/system/base.py @@ -21,7 +21,7 @@ import yaml from deepdiff import DeepDiff -import mlrun.api.schemas +import mlrun.common.schemas from mlrun import get_run_db, mlconf, set_environment from mlrun.utils import create_logger @@ -57,6 +57,7 @@ def setup_class(cls): cls._setup_env(cls._get_env_from_file()) cls._run_db = get_run_db() cls.custom_setup_class() + cls._logger = logger.get_child(cls.__name__.lower()) # the dbpath is already configured on the test startup before this stage # so even though we set the env var, we still need to directly configure @@ -68,7 +69,9 @@ def custom_setup_class(cls): pass def setup_method(self, method): - logger.info(f"Setting up test {self.__class__.__name__}::{method.__name__}") + self._logger.info( + f"Setting up test {self.__class__.__name__}::{method.__name__}" + ) self._setup_env(self._get_env_from_file()) self._run_db = get_run_db() @@ -79,7 +82,7 @@ def setup_method(self, method): self.custom_setup() - logger.info( + self._logger.info( f"Finished setting up test {self.__class__.__name__}::{method.__name__}" ) @@ -91,13 +94,15 @@ def _delete_test_project(self, name=None): if self._should_clean_resources(): self._run_db.delete_project( name or self.project_name, - deletion_strategy=mlrun.api.schemas.DeletionStrategy.cascading, + 
deletion_strategy=mlrun.common.schemas.DeletionStrategy.cascading, ) def teardown_method(self, method): - logger.info(f"Tearing down test {self.__class__.__name__}::{method.__name__}") + self._logger.info( + f"Tearing down test {self.__class__.__name__}::{method.__name__}" + ) - logger.debug("Removing test data from database") + self._logger.debug("Removing test data from database") if self._should_clean_resources(): fsets = self._run_db.list_feature_sets() if fsets: @@ -108,7 +113,7 @@ def teardown_method(self, method): self.custom_teardown() - logger.info( + self._logger.info( f"Finished tearing down test {self.__class__.__name__}::{method.__name__}" ) @@ -187,7 +192,7 @@ def _get_env_from_file(cls) -> dict: @classmethod def _setup_env(cls, env: dict): - logger.debug("Setting up test environment") + cls._logger.debug("Setting up test environment") cls._test_env.update(env) # save old env vars for returning them on teardown @@ -203,7 +208,7 @@ def _setup_env(cls, env: dict): @classmethod def _teardown_env(cls): - logger.debug("Tearing down test environment") + cls._logger.debug("Tearing down test environment") for env_var in cls._test_env: if env_var in os.environ: del os.environ[env_var] @@ -232,7 +237,7 @@ def _verify_run_spec( data_stores: list = None, scrape_metrics: bool = None, ): - logger.debug("Verifying run spec", spec=run_spec) + self._logger.debug("Verifying run spec", spec=run_spec) if parameters: self._assert_with_deepdiff(parameters, run_spec["parameters"]) if inputs: @@ -259,7 +264,7 @@ def _verify_run_metadata( labels: dict = None, iteration: int = None, ): - logger.debug("Verifying run metadata", spec=run_metadata) + self._logger.debug("Verifying run metadata", spec=run_metadata) if uid: assert run_metadata["uid"] == uid if name: @@ -285,11 +290,14 @@ def _verify_run_outputs( best_iteration: int = None, iteration_results: bool = False, ): - logger.debug("Verifying run outputs", spec=run_outputs) - assert 
run_outputs["model"].startswith(str(output_path)) - assert run_outputs["html_result"].startswith(str(output_path)) + self._logger.debug("Verifying run outputs", spec=run_outputs) assert run_outputs["chart"].startswith(str(output_path)) assert run_outputs["mydf"] == f"store://artifacts/{project}/{name}_mydf:{uid}" + assert run_outputs["model"] == f"store://artifacts/{project}/{name}_model:{uid}" + assert ( + run_outputs["html_result"] + == f"store://artifacts/{project}/{name}_html_result:{uid}" + ) if accuracy: assert run_outputs["accuracy"] == accuracy if loss: diff --git a/tests/system/conftest.py b/tests/system/conftest.py index fe7435fb9bc3..76f22b35739f 100644 --- a/tests/system/conftest.py +++ b/tests/system/conftest.py @@ -64,13 +64,18 @@ def post_report_session_finish_to_slack( session: Session, exitstatus: ExitCode, slack_webhook_url ): mlrun_version = os.getenv("MLRUN_VERSION", "") + mlrun_current_branch = os.getenv("MLRUN_SYSTEM_TESTS_BRANCH", "") mlrun_system_tests_component = os.getenv("MLRUN_SYSTEM_TESTS_COMPONENT", "") total_executed_tests = session.testscollected total_failed_tests = session.testsfailed + text = "" + if mlrun_current_branch: + text += f"[{mlrun_current_branch}] " + if exitstatus == ExitCode.OK: - text = f"All {total_executed_tests} tests passed successfully" + text += f"All {total_executed_tests} tests passed successfully" else: - text = f"{total_failed_tests} out of {total_executed_tests} tests failed" + text += f"{total_failed_tests} out of {total_executed_tests} tests failed" test_session_info = "" if mlrun_system_tests_component: diff --git a/tests/system/datastore/test_http.py b/tests/system/datastore/test_http.py new file mode 100644 index 000000000000..69a9291adf7e --- /dev/null +++ b/tests/system/datastore/test_http.py @@ -0,0 +1,41 @@ +# Copyright 2022 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import mlrun.datastore +from tests.system.base import TestMLRunSystem + + +class TestHttpDataStore(TestMLRunSystem): + def test_https_auth_token_with_env(self): + mlrun.mlconf.hub_url = ( + "https://raw.githubusercontent.com/mlrun/private-system-tests/" + ) + os.environ["HTTPS_AUTH_TOKEN"] = os.environ["MLRUN_SYSTEM_TESTS_GIT_TOKEN"] + func = mlrun.import_function( + "hub://support_private_hub_repo/func:main", + secrets=None, + ) + assert func.metadata.name == "func" + + def test_https_auth_token_with_secrets_flag(self): + mlrun.mlconf.hub_url = ( + "https://raw.githubusercontent.com/mlrun/private-system-tests/" + ) + secrets = {"HTTPS_AUTH_TOKEN": os.environ["MLRUN_SYSTEM_TESTS_GIT_TOKEN"]} + func = mlrun.import_function( + "hub://support_private_hub_repo/func:main", secrets=secrets + ) + assert func.metadata.name == "func" diff --git a/tests/system/demos/churn/assets/data_clean_function.py b/tests/system/demos/churn/assets/data_clean_function.py index f32b631b7be2..45cc2da4920b 100644 --- a/tests/system/demos/churn/assets/data_clean_function.py +++ b/tests/system/demos/churn/assets/data_clean_function.py @@ -63,7 +63,7 @@ def data_clean( TODO: * parallelize where possible * more abstraction (more parameters, chain sklearn transformers) - * convert to marketplace function + * convert to hub function :param context: the function execution context :param src: an artifact or file path diff --git a/tests/system/demos/churn/test_churn.py b/tests/system/demos/churn/test_churn.py index 06df66360295..9bfba810abc2 100644 --- 
a/tests/system/demos/churn/test_churn.py +++ b/tests/system/demos/churn/test_churn.py @@ -59,11 +59,11 @@ def create_demo_project(self) -> mlrun.projects.MlrunProject: self._logger.debug("Setting project functions") demo_project.set_function(clean_data_function) demo_project.set_function("hub://describe", "describe") - demo_project.set_function("hub://xgb_trainer", "classify") - demo_project.set_function("hub://xgb_test", "xgbtest") - demo_project.set_function("hub://coxph_trainer", "survive") - demo_project.set_function("hub://coxph_test", "coxtest") - demo_project.set_function("hub://churn_server", "server") + demo_project.set_function("hub://xgb-trainer", "classify") + demo_project.set_function("hub://xgb-test", "xgbtest") + demo_project.set_function("hub://coxph-trainer", "survive") + demo_project.set_function("hub://coxph-test", "coxtest") + demo_project.set_function("hub://churn-server", "server") self._logger.debug("Setting project workflow") demo_project.set_workflow( diff --git a/tests/system/demos/horovod/test_horovod.py b/tests/system/demos/horovod/test_horovod.py index 785a6ac17663..c19669d239b6 100644 --- a/tests/system/demos/horovod/test_horovod.py +++ b/tests/system/demos/horovod/test_horovod.py @@ -73,7 +73,7 @@ def create_demo_project(self) -> mlrun.projects.MlrunProject: trainer.spec.service_type = "NodePort" demo_project.set_function(trainer) - demo_project.set_function("hub://tf2_serving", "serving") + demo_project.set_function("hub://tf2-serving", "serving") demo_project.log_artifact( "images", diff --git a/tests/system/demos/sklearn/test_sklearn.py b/tests/system/demos/sklearn/test_sklearn.py index 1909a2101e3e..a71292b47163 100644 --- a/tests/system/demos/sklearn/test_sklearn.py +++ b/tests/system/demos/sklearn/test_sklearn.py @@ -52,9 +52,9 @@ def create_demo_project(self) -> mlrun.projects.MlrunProject: self._logger.debug("Setting project functions") demo_project.set_function(iris_generator_function) 
demo_project.set_function("hub://describe", "describe") - demo_project.set_function("hub://auto_trainer", "auto_trainer") - demo_project.set_function("hub://model_server", "serving") - demo_project.set_function("hub://model_server_tester", "live_tester") + demo_project.set_function("hub://auto-trainer", "auto-trainer") + demo_project.set_function("hub://model-server", "serving") + demo_project.set_function("hub://model-server-tester", "live-tester") self._logger.debug("Setting project workflow") demo_project.set_workflow( diff --git a/tests/system/feature_store/assets/fields_with_space.csv b/tests/system/feature_store/assets/fields_with_space.csv new file mode 100644 index 000000000000..47b0819acdb1 --- /dev/null +++ b/tests/system/feature_store/assets/fields_with_space.csv @@ -0,0 +1,6 @@ +name,city of birth +John,New York +Emma,London +Michael,Los Angeles +Sophia,Paris +David,Sydney diff --git a/tests/system/feature_store/assets/testdata_short.csv b/tests/system/feature_store/assets/testdata_short.csv new file mode 100644 index 000000000000..0690a58d7e72 --- /dev/null +++ b/tests/system/feature_store/assets/testdata_short.csv @@ -0,0 +1,4 @@ +id,name,number,float_number,date_of_birth +1,John,10,1.5,1990-01-01 +2,Jane,20,2.5,1995-05-10 +3,Bob,30,3.5,1985-12-15 diff --git a/tests/system/feature_store/expected_stats.py b/tests/system/feature_store/expected_stats.py index e1194ae931c1..8af9ba7d8800 100644 --- a/tests/system/feature_store/expected_stats.py +++ b/tests/system/feature_store/expected_stats.py @@ -50,7 +50,7 @@ "rr_is_error": 0.015789473684210527, "spo2": 98.77894736842106, "spo2_is_error": 0.015789473684210527, - "timestamp": "2020-12-01T17:28:31.695824", + "timestamp": "2020-12-01T17:28:31.695824+00:00", "turn_count": 1.3398340922970073, "turn_count_is_error": 0.015789473684210527, }, @@ -88,7 +88,7 @@ "rr_is_error": False, "spo2": 85.0, "spo2_is_error": False, - "timestamp": "2020-12-01T17:24:15.906352", + "timestamp": 
"2020-12-01T17:24:15.906352+00:00", "turn_count": 0.0, "turn_count_is_error": False, }, @@ -107,7 +107,7 @@ "rr_is_error": False, "spo2": 99.0, "spo2_is_error": False, - "timestamp": "2020-12-01T17:26:15.906352", + "timestamp": "2020-12-01T17:26:15.906352+00:00", "turn_count": 0.0, "turn_count_is_error": False, }, @@ -126,7 +126,7 @@ "rr_is_error": False, "spo2": 99.0, "spo2_is_error": False, - "timestamp": "2020-12-01T17:28:15.906352", + "timestamp": "2020-12-01T17:28:15.906352+00:00", "turn_count": 1.1724099011618052, "turn_count_is_error": False, }, @@ -145,7 +145,7 @@ "rr_is_error": False, "spo2": 99.0, "spo2_is_error": False, - "timestamp": "2020-12-01T17:31:15.906352", + "timestamp": "2020-12-01T17:31:15.906352+00:00", "turn_count": 2.951729964062169, "turn_count_is_error": False, }, @@ -164,7 +164,7 @@ "rr_is_error": True, "spo2": 99.0, "spo2_is_error": True, - "timestamp": "2020-12-01T17:33:15.906352", + "timestamp": "2020-12-01T17:33:15.906352+00:00", "turn_count": 3.0, "turn_count_is_error": True, }, @@ -474,26 +474,26 @@ "timestamp": [ [20, 0, 20, 0, 20, 0, 20, 0, 20, 0, 0, 20, 0, 20, 0, 20, 0, 20, 0, 10], [ - "2020-12-01T17:24:15.910000", - "2020-12-01T17:24:42.910000", - "2020-12-01T17:25:09.910000", - "2020-12-01T17:25:36.910000", - "2020-12-01T17:26:03.910000", - "2020-12-01T17:26:30.910000", - "2020-12-01T17:26:57.910000", - "2020-12-01T17:27:24.910000", - "2020-12-01T17:27:51.910000", - "2020-12-01T17:28:18.910000", - "2020-12-01T17:28:45.910000", - "2020-12-01T17:29:12.910000", - "2020-12-01T17:29:39.910000", - "2020-12-01T17:30:06.910000", - "2020-12-01T17:30:33.910000", - "2020-12-01T17:31:00.910000", - "2020-12-01T17:31:27.910000", - "2020-12-01T17:31:54.910000", - "2020-12-01T17:32:21.910000", - "2020-12-01T17:32:48.910000", + "2020-12-01T17:24:15.910000+00:00", + "2020-12-01T17:24:42.910000+00:00", + "2020-12-01T17:25:09.910000+00:00", + "2020-12-01T17:25:36.910000+00:00", + "2020-12-01T17:26:03.910000+00:00", + 
"2020-12-01T17:26:30.910000+00:00", + "2020-12-01T17:26:57.910000+00:00", + "2020-12-01T17:27:24.910000+00:00", + "2020-12-01T17:27:51.910000+00:00", + "2020-12-01T17:28:18.910000+00:00", + "2020-12-01T17:28:45.910000+00:00", + "2020-12-01T17:29:12.910000+00:00", + "2020-12-01T17:29:39.910000+00:00", + "2020-12-01T17:30:06.910000+00:00", + "2020-12-01T17:30:33.910000+00:00", + "2020-12-01T17:31:00.910000+00:00", + "2020-12-01T17:31:27.910000+00:00", + "2020-12-01T17:31:54.910000+00:00", + "2020-12-01T17:32:21.910000+00:00", + "2020-12-01T17:32:48.910000+00:00", ], ], "turn_count": [ diff --git a/tests/system/feature_store/test_feature_store.py b/tests/system/feature_store/test_feature_store.py index da2a3ffd20ef..1f4eb6f2b083 100644 --- a/tests/system/feature_store/test_feature_store.py +++ b/tests/system/feature_store/test_feature_store.py @@ -26,8 +26,10 @@ import fsspec import numpy as np import pandas as pd +import pyarrow import pyarrow.parquet as pq import pytest +import pytz import requests from pandas.util.testing import assert_frame_equal from storey import MapClass @@ -369,17 +371,13 @@ def test_get_offline_features_with_or_without_indexes(self): # with_indexes = False, entity_timestamp_column = None default_df = fstore.get_offline_features(vector).to_dataframe() - assert isinstance( - default_df.index, pd.core.indexes.range.RangeIndex - ), "index column is not of default type" - assert default_df.index.name is None, "index column is not of default type" - assert "time" not in default_df.columns, "'time' column shouldn't be present" - assert ( - "ticker" not in default_df.columns - ), "'ticker' column shouldn't be present" + assert isinstance(default_df.index, pd.core.indexes.range.RangeIndex) + assert default_df.index.name is None + assert "time" not in default_df.columns + assert "ticker" not in default_df.columns # with_indexes = False, entity_timestamp_column = "time" - resp = fstore.get_offline_features(vector, entity_timestamp_column="time") + resp 
= fstore.get_offline_features(vector) df_no_time = resp.to_dataframe() tmpdir = tempfile.mkdtemp() @@ -392,33 +390,21 @@ def test_get_offline_features_with_or_without_indexes(self): read_back_df = pd.read_csv(csv_path, parse_dates=[2]) assert read_back_df.equals(df_no_time) - assert isinstance( - df_no_time.index, pd.core.indexes.range.RangeIndex - ), "index column is not of default type" - assert df_no_time.index.name is None, "index column is not of default type" - assert "time" not in df_no_time.columns, "'time' column should not be present" - assert ( - "ticker" not in df_no_time.columns - ), "'ticker' column shouldn't be present" - assert ( - "another_time" in df_no_time.columns - ), "'another_time' column should be present" + assert isinstance(df_no_time.index, pd.core.indexes.range.RangeIndex) + assert df_no_time.index.name is None + assert "time" not in df_no_time.columns + assert "ticker" not in df_no_time.columns + assert "another_time" in df_no_time.columns # with_indexes = False, entity_timestamp_column = "invalid" - should return the timestamp column - df_with_time = fstore.get_offline_features( - vector, entity_timestamp_column="another_time" - ).to_dataframe() + df_without_time_and_indexes = fstore.get_offline_features(vector).to_dataframe() assert isinstance( - df_with_time.index, pd.core.indexes.range.RangeIndex - ), "index column is not of default type" - assert df_with_time.index.name is None, "index column is not of default type" - assert ( - "ticker" not in df_with_time.columns - ), "'ticker' column shouldn't be present" - assert "time" in df_with_time.columns, "'time' column should be present" - assert ( - "another_time" not in df_with_time.columns - ), "'another_time' column should not be present" + df_without_time_and_indexes.index, pd.core.indexes.range.RangeIndex + ) + assert df_without_time_and_indexes.index.name is None + assert "ticker" not in df_without_time_and_indexes.columns + assert "time" not in df_without_time_and_indexes.columns 
+ assert "another_time" in df_without_time_and_indexes.columns vector.spec.with_indexes = True df_with_index = fstore.get_offline_features(vector).to_dataframe() @@ -778,7 +764,8 @@ def test_featureset_column_types(self): verify_ingest(data, key, targets=[TargetTypes.nosql]) verify_ingest(data, key, targets=[TargetTypes.nosql], infer=True) - def test_filtering_parquet_by_time(self): + @pytest.mark.parametrize("with_tz", [False, True]) + def test_filtering_parquet_by_time(self, with_tz): key = "patient_id" measurements = fstore.FeatureSet( "measurements", entities=[Entity(key)], timestamp_key="timestamp" @@ -786,8 +773,10 @@ def test_filtering_parquet_by_time(self): source = ParquetSource( "myparquet", path=os.path.relpath(str(self.assets_path / "testdata.parquet")), - start_time=datetime(2020, 12, 1, 17, 33, 15), - end_time="2020-12-01 17:33:16", + start_time=datetime( + 2020, 12, 1, 17, 33, 15, tzinfo=pytz.UTC if with_tz else None + ), + end_time="2020-12-01 17:33:16" + ("+00:00" if with_tz else ""), ) resp = fstore.ingest( @@ -801,8 +790,10 @@ def test_filtering_parquet_by_time(self): source = ParquetSource( "myparquet", path=os.path.relpath(str(self.assets_path / "testdata.parquet")), - start_time=datetime(2022, 12, 1, 17, 33, 15), - end_time="2022-12-01 17:33:16", + start_time=datetime( + 2022, 12, 1, 17, 33, 15, tzinfo=pytz.UTC if with_tz else None + ), + end_time="2022-12-01 17:33:16" + ("+00:00" if with_tz else ""), ) resp = fstore.ingest( @@ -846,20 +837,22 @@ def test_ingest_partitioned_by_key_and_time( f"{name}.*", ] vector = fstore.FeatureVector("myvector", features) - resp2 = fstore.get_offline_features( - vector, entity_timestamp_column="timestamp", with_indexes=True - ) + resp2 = fstore.get_offline_features(vector, with_indexes=True) resp2 = resp2.to_dataframe().to_dict() assert resp1 == resp2 + major_pyarrow_version = int(pyarrow.__version__.split(".")[0]) file_system = fsspec.filesystem("v3io") path = measurements.get_target_path("parquet") dataset 
= pq.ParquetDataset( - path, + path if major_pyarrow_version < 11 else path[len("v3io://") :], filesystem=file_system, ) - partitions = [key for key, _ in dataset.pieces[0].partition_keys] + if major_pyarrow_version < 11: + partitions = [key for key, _ in dataset.pieces[0].partition_keys] + else: + partitions = dataset.partitioning.schema.names if key_bucketing_number is None: expected_partitions = [] @@ -891,7 +884,7 @@ def test_ingest_partitioned_by_key_and_time( vector, start_time=datetime(2020, 12, 1, 17, 33, 15), end_time="2020-12-01 17:33:16", - entity_timestamp_column="timestamp", + timestamp_for_filtering="timestamp", ) resp2 = resp.to_dataframe() assert len(resp2) == 10 @@ -918,6 +911,9 @@ def test_passthrough_feature_set(self, engine): expected = source.to_dataframe().set_index("patient_id") + # The file is sorted by time. 10 is just an arbitrary number. + start_time = expected["timestamp"][10] + if engine != "pandas": # pandas engine does not support preview (ML-2694) preview_pd = fstore.preview( measurements_set, @@ -936,11 +932,12 @@ def test_passthrough_feature_set(self, engine): # verify that get_offline (and preview) equals the source vector = fstore.FeatureVector("myvector", features=[f"{name}.*"]) resp = fstore.get_offline_features( - vector, entity_timestamp_column="timestamp", with_indexes=True + vector, with_indexes=True, start_time=start_time ) get_offline_pd = resp.to_dataframe() - get_offline_pd["timestamp"] = pd.to_datetime(get_offline_pd["timestamp"]) + # check time filter with passthrough + expected = expected[(expected["timestamp"] > start_time)] assert_frame_equal(expected, get_offline_pd, check_like=True, check_dtype=False) # assert get_online correctness @@ -1047,9 +1044,7 @@ def test_ordered_pandas_asof_merge(self): feature_vector = fstore.FeatureVector( "test_fv", features, description="test FV" ) - res = fstore.get_offline_features( - feature_vector, entity_timestamp_column="time" - ) + res = 
fstore.get_offline_features(feature_vector) res = res.to_dataframe() assert res.shape[0] == left.shape[0] @@ -1067,9 +1062,7 @@ def test_left_not_ordered_pandas_asof_merge(self): feature_vector = fstore.FeatureVector( "test_fv", features, description="test FV" ) - res = fstore.get_offline_features( - feature_vector, entity_timestamp_column="time" - ) + res = fstore.get_offline_features(feature_vector) res = res.to_dataframe() assert res.shape[0] == left.shape[0] @@ -1087,44 +1080,38 @@ def test_right_not_ordered_pandas_asof_merge(self): feature_vector = fstore.FeatureVector( "test_fv", features, description="test FV" ) - res = fstore.get_offline_features( - feature_vector, entity_timestamp_column="time" - ) + res = fstore.get_offline_features(feature_vector) res = res.to_dataframe() assert res.shape[0] == left.shape[0] def test_read_csv(self): - from storey import CSVSource, ReduceToDataFrame, build_flow - - csv_path = str(self.results_path / _generate_random_name() / ".csv") - targets = [CSVTarget("mycsv", path=csv_path)] + source = CSVSource( + "mycsv", + path=os.path.relpath(str(self.assets_path / "testdata_short.csv")), + parse_dates=["date_of_birth"], + ) stocks_set = fstore.FeatureSet( - "tests", entities=[Entity("ticker", ValueType.STRING)] + "tests", entities=[Entity("id", ValueType.INT64)] ) - fstore.ingest( + result = fstore.ingest( stocks_set, - stocks, + source=source, infer_options=fstore.InferOptions.default(), - targets=targets, ) - - # reading csv file - final_path = stocks_set.get_target_path("mycsv") - controller = build_flow([CSVSource(final_path), ReduceToDataFrame()]).run() - termination_result = controller.await_termination() - expected = pd.DataFrame( { - 0: ["ticker", "MSFT", "GOOG", "AAPL"], - 1: ["name", "Microsoft Corporation", "Alphabet Inc", "Apple Inc"], - 2: ["exchange", "NASDAQ", "NASDAQ", "NASDAQ"], - } + "name": ["John", "Jane", "Bob"], + "number": [10, 20, 30], + "float_number": [1.5, 2.5, 3.5], + "date_of_birth": [ + 
datetime(1990, 1, 1), + datetime(1995, 5, 10), + datetime(1985, 12, 15), + ], + }, + index=pd.Index([1, 2, 3], name="id"), ) - - assert termination_result.equals( - expected - ), f"{termination_result}\n!=\n{expected}" - os.remove(final_path) + assert result.equals(expected) def test_multiple_entities(self): name = f"measurements_{uuid.uuid4()}" @@ -1223,7 +1210,7 @@ def test_offline_features_filter_non_partitioned(self): resp = fstore.get_offline_features( vector, - entity_timestamp_column="time_stamp", + timestamp_for_filtering="time_stamp", start_time="2021-06-09 09:30", end_time=datetime(2021, 6, 9, 10, 30), ) @@ -1284,7 +1271,7 @@ def test_filter_offline_multiple_featuresets(self): vector = fstore.FeatureVector("vector", features) resp = fstore.get_offline_features( vector, - entity_timestamp_column="time_stamp", + timestamp_for_filtering="time_stamp", start_time=datetime(2021, 6, 9, 9, 30), end_time=None, # will translate to now() ) @@ -1848,7 +1835,7 @@ def test_sync_pipeline_chunks(self, with_graph): self._logger.info(f"output df:\n{df}") reference_df = pd.read_csv(csv_file) - reference_df = reference_df[0:chunksize].set_index("patient_id") + reference_df = reference_df.set_index("patient_id") # patient_id (index) and timestamp (timestamp_key) are not in features list assert features + ["timestamp"] == list(reference_df.columns) @@ -2380,7 +2367,8 @@ def test_join_with_table(self): inner_join=True, ) df = fstore.ingest( - fset, df, targets=[], infer_options=fstore.InferOptions.default() + fset, + df, ) assert df.to_dict() == { "foreignkey1": {"mykey1": "AB", "mykey2": "DE"}, @@ -2423,7 +2411,7 @@ def test_directional_graph(self): attributes=["aug"], inner_join=True, ) - df = fstore.ingest(fset, df, targets=[]) + df = fstore.ingest(fset, df) assert df.to_dict() == { "foreignkey1": { "mykey1_1": "AB", @@ -2743,9 +2731,7 @@ def test_map_with_state_with_table(self): group_by_key=True, _fn="map_with_state_test_function", ) - df = fstore.ingest( - fset, df, 
targets=[], infer_options=fstore.InferOptions.default() - ) + df = fstore.ingest(fset, df) assert df.to_dict() == { "name": {"a": "a", "b": "b"}, "sum": {"a": 16, "b": 26}, @@ -3246,6 +3232,48 @@ def test_pandas_write_parquet(self): expected_df = pd.DataFrame({"number": [11, 22]}, index=["a", "b"]) assert read_back_df.equals(expected_df) + def test_pandas_write_partitioned_parquet(self): + prediction_set = fstore.FeatureSet( + name="myset", + entities=[fstore.Entity("id")], + timestamp_key="time", + engine="pandas", + ) + + df = pd.DataFrame( + { + "id": ["a", "b"], + "number": [11, 22], + "time": [pd.Timestamp(2022, 1, 1, 1), pd.Timestamp(2022, 1, 1, 1, 1)], + } + ) + + with tempfile.TemporaryDirectory() as tempdir: + outdir = f"{tempdir}/test_pandas_write_partitioned_parquet/" + prediction_set.set_targets( + with_defaults=False, targets=[(ParquetTarget(path=outdir))] + ) + + returned_df = fstore.ingest(prediction_set, df) + # check that partitions are created as expected (ML-3404) + read_back_df = pd.read_parquet( + f"{prediction_set.get_target_path()}year=2022/month=01/day=01/hour=01/" + ) + + assert read_back_df.equals(returned_df) + + expected_df = pd.DataFrame( + { + "number": [11, 22], + "time": [ + pd.Timestamp(2022, 1, 1, 1), + pd.Timestamp(2022, 1, 1, 1, 1), + ], + }, + index=["a", "b"], + ) + assert read_back_df.equals(expected_df) + # regression test for #2557 @pytest.mark.parametrize( ["index_columns"], @@ -3350,8 +3378,7 @@ def test_pandas_stats_include_index(self, index_columns): @pytest.mark.parametrize("with_indexes", [True, False]) @pytest.mark.parametrize("engine", ["local", "dask"]) - @pytest.mark.parametrize("join_type", ["inner", "outer"]) - def test_relation_join(self, engine, join_type, with_indexes): + def test_relation_join(self, engine, with_indexes): """Test 3 option of using get offline feature with relations""" engine_args = {} if engine == "dask": @@ -3409,7 +3436,6 @@ def test_relation_join(self, engine, join_type, with_indexes): 
join_employee_department = pd.merge( employees_with_department, departments, - how=join_type, left_on=["department_id"], right_on=["d_id"], suffixes=("_employees", "_departments"), @@ -3418,7 +3444,6 @@ def test_relation_join(self, engine, join_type, with_indexes): join_employee_managers = pd.merge( join_employee_department, managers, - how=join_type, left_on=["manager_id"], right_on=["m_id"], suffixes=("_manage", "_"), @@ -3427,7 +3452,6 @@ def test_relation_join(self, engine, join_type, with_indexes): join_employee_sets = pd.merge( employees_with_department, employees_with_class, - how=join_type, left_on=["id"], right_on=["id"], suffixes=("_employees", "_e_mini"), @@ -3436,7 +3460,6 @@ def test_relation_join(self, engine, join_type, with_indexes): _merge_step = pd.merge( join_employee_department, employees_with_class, - how=join_type, left_on=["id"], right_on=["id"], suffixes=("_", "_e_mini"), @@ -3445,7 +3468,6 @@ def test_relation_join(self, engine, join_type, with_indexes): join_all = pd.merge( _merge_step, classes, - how=join_type, left_on=["class_id"], right_on=["c_id"], suffixes=("_e_mini", "_cls"), @@ -3561,7 +3583,6 @@ def test_relation_join(self, engine, join_type, with_indexes): with_indexes=with_indexes, engine=engine, engine_args=engine_args, - join_type=join_type, order_by="name", ) if with_indexes: @@ -3586,7 +3607,6 @@ def test_relation_join(self, engine, join_type, with_indexes): with_indexes=with_indexes, engine=engine, engine_args=engine_args, - join_type=join_type, order_by="n", ) assert_frame_equal(join_employee_department, resp_1.to_dataframe()) @@ -3607,7 +3627,6 @@ def test_relation_join(self, engine, join_type, with_indexes): with_indexes=with_indexes, engine=engine, engine_args=engine_args, - join_type=join_type, order_by=["n"], ) assert_frame_equal(join_employee_managers, resp_2.to_dataframe()) @@ -3624,7 +3643,6 @@ def test_relation_join(self, engine, join_type, with_indexes): with_indexes=with_indexes, engine=engine, 
engine_args=engine_args, - join_type=join_type, order_by="name", ) assert_frame_equal(join_employee_sets, resp_3.to_dataframe()) @@ -3646,15 +3664,13 @@ def test_relation_join(self, engine, join_type, with_indexes): with_indexes=with_indexes, engine=engine, engine_args=engine_args, - join_type=join_type, order_by="n", ) assert_frame_equal(join_all, resp_4.to_dataframe()) @pytest.mark.parametrize("with_indexes", [True, False]) @pytest.mark.parametrize("engine", ["local", "dask"]) - @pytest.mark.parametrize("join_type", ["inner", "outer"]) - def test_relation_join_multi_entities(self, engine, join_type, with_indexes): + def test_relation_join_multi_entities(self, engine, with_indexes): engine_args = {} if engine == "dask": dask_cluster = mlrun.new_function( @@ -3690,7 +3706,6 @@ def test_relation_join_multi_entities(self, engine, join_type, with_indexes): join_employee_department = pd.merge( employees_with_department, departments, - how=join_type, left_on=["department_id", "department_name"], right_on=["d_id", "name"], suffixes=("_employees", "_departments"), @@ -3743,7 +3758,6 @@ def test_relation_join_multi_entities(self, engine, join_type, with_indexes): with_indexes=with_indexes, engine=engine, engine_args=engine_args, - join_type=join_type, order_by="n", ) assert_frame_equal(join_employee_department, resp_1.to_dataframe()) @@ -4021,6 +4035,202 @@ def test_ingest_with_steps_drop_features(self): ): fstore.ingest(measurements, source) + @pytest.mark.parametrize("engine", ["local", "dask"]) + def test_as_of_join_different_ts(self, engine): + engine_args = {} + if engine == "dask": + dask_cluster = mlrun.new_function( + "dask_tests", kind="dask", image="mlrun/ml-models" + ) + dask_cluster.apply(mlrun.mount_v3io()) + dask_cluster.spec.remote = True + dask_cluster.with_worker_requests(mem="2G") + dask_cluster.save() + engine_args = { + "dask_client": dask_cluster, + "dask_cluster_uri": dask_cluster.uri, + } + test_base_time = 
datetime.fromisoformat("2020-07-21T12:00:00+00:00") + + df_left = pd.DataFrame( + { + "ent": ["a", "b"], + "f1": ["a-val", "b-val"], + "ts_l": [test_base_time, test_base_time], + } + ) + + df_right = pd.DataFrame( + { + "ent": ["a", "a", "a", "b"], + "ts_r": [ + test_base_time - pd.Timedelta(minutes=1), + test_base_time - pd.Timedelta(minutes=2), + test_base_time - pd.Timedelta(minutes=3), + test_base_time - pd.Timedelta(minutes=2), + ], + "f2": ["newest", "middle", "oldest", "only-value"], + } + ) + + expected_df = pd.DataFrame( + { + "f1": ["a-val", "b-val"], + "f2": ["newest", "only-value"], + } + ) + + fset1 = fstore.FeatureSet("fs1-as-of", entities=["ent"], timestamp_key="ts_l") + fset2 = fstore.FeatureSet("fs2-as-of", entities=["ent"], timestamp_key="ts_r") + + fstore.ingest(fset1, df_left) + fstore.ingest(fset2, df_right) + + vec = fstore.FeatureVector("vec1", ["fs1-as-of.*", "fs2-as-of.*"]) + + resp = fstore.get_offline_features(vec, engine=engine, engine_args=engine_args) + res_df = resp.to_dataframe().sort_index(axis=1) + + assert_frame_equal(expected_df, res_df) + + @pytest.mark.parametrize("engine", ["local", "dask"]) + @pytest.mark.parametrize( + "timestamp_for_filtering", + [None, "other_ts", "bad_ts", {"fs1": "other_ts"}, {"fs1": "bad_ts"}], + ) + def test_time_and_columns_filter(self, engine, timestamp_for_filtering): + engine_args = {} + if engine == "dask": + dask_cluster = mlrun.new_function( + "dask_tests", kind="dask", image="mlrun/ml-models" + ) + dask_cluster.apply(mlrun.mount_v3io()) + dask_cluster.spec.remote = True + dask_cluster.with_worker_requests(mem="2G") + dask_cluster.save() + engine_args = { + "dask_client": dask_cluster, + "dask_cluster_uri": dask_cluster.uri, + } + test_base_time = datetime.fromisoformat("2020-07-21T12:00:00") + + df = pd.DataFrame( + { + "ent": ["a", "b", "c", "d"], + "ts_key": [ + test_base_time - pd.Timedelta(minutes=1), + test_base_time - pd.Timedelta(minutes=2), + test_base_time - pd.Timedelta(minutes=3), + 
test_base_time - pd.Timedelta(minutes=4), + ], + "other_ts": [ + test_base_time - pd.Timedelta(minutes=4), + test_base_time - pd.Timedelta(minutes=3), + test_base_time - pd.Timedelta(minutes=2), + test_base_time - pd.Timedelta(minutes=1), + ], + "val": [1, 2, 3, 4], + } + ) + + fset1 = fstore.FeatureSet("fs1", entities=["ent"], timestamp_key="ts_key") + + fstore.ingest(fset1, df) + + vec = fstore.FeatureVector("vec1", ["fs1.val"]) + if isinstance(timestamp_for_filtering, dict): + timestamp_for_filtering_str = timestamp_for_filtering["fs1"] + else: + timestamp_for_filtering_str = timestamp_for_filtering + if timestamp_for_filtering_str != "bad_ts": + resp = fstore.get_offline_features( + vec, + start_time=test_base_time - pd.Timedelta(minutes=3), + end_time=test_base_time, + timestamp_for_filtering=timestamp_for_filtering, + engine=engine, + engine_args=engine_args, + ) + res_df = resp.to_dataframe().sort_index(axis=1) + + if not timestamp_for_filtering_str: + assert res_df["val"].tolist() == [1, 2] + elif timestamp_for_filtering_str == "other_ts": + assert res_df["val"].tolist() == [3, 4] + assert res_df.columns == ["val"] + else: + with pytest.raises( + mlrun.errors.MLRunInvalidArgumentError, + match="Feature set `fs1` does not have a column named `bad_ts` to filter on.", + ): + fstore.get_offline_features( + vec, + start_time=test_base_time - pd.Timedelta(minutes=3), + end_time=test_base_time, + timestamp_for_filtering=timestamp_for_filtering, + engine=engine, + engine_args=engine_args, + ) + + # ML-3900 + def test_get_online_features_after_ingest_without_inference(self): + feature_set = fstore.FeatureSet( + "my-fset", + entities=[ + fstore.Entity("fn0"), + fstore.Entity( + "fn1", + value_type=mlrun.data_types.data_types.ValueType.STRING, + ), + ], + ) + + df = pd.DataFrame( + { + "fn0": [1, 2, 3, 4], + "fn1": [1, 2, 3, 4], + "fn2": [1, 1, 1, 1], + "fn3": [2, 2, 2, 2], + } + ) + + fstore.ingest(feature_set, df, infer_options=InferOptions.Null) + + features = 
["my-fset.*"] + vector = fstore.FeatureVector("my-vector", features) + vector.save() + + with pytest.raises( + mlrun.errors.MLRunRuntimeError, + match="No features found for feature vector 'my-vector'", + ): + fstore.get_online_feature_service( + f"store://feature-vectors/{self.project_name}/my-vector:latest" + ) + + def test_ingest_with_rename_columns(self): + csv_path = str(self.assets_path / "fields_with_space.csv") + name = f"test_ingest_with_rename_columns_{uuid.uuid4()}" + data = pd.read_csv(csv_path) + expected_result = data.copy().rename(columns={"city of birth": "city_of_birth"}) + expected_result.set_index("name", inplace=True) + feature_set = fstore.FeatureSet( + name=name, + entities=[fstore.Entity("name")], + ) + fstore.preview( + feature_set, + data, + ) + inspect_result = fstore.ingest(feature_set, data) + feature_vector = fstore.FeatureVector( + name=name, features=[f"{self.project_name}/{name}.*"] + ) + feature_vector.spec.with_indexes = True + offline_features_df = fstore.get_offline_features(feature_vector).to_dataframe() + assert offline_features_df.equals(inspect_result) + assert offline_features_df.equals(expected_result) + def verify_purge(fset, targets): fset.reload(update_spec=False) diff --git a/tests/system/feature_store/test_google_big_query.py b/tests/system/feature_store/test_google_big_query.py index f955b7553991..9f269abd52a9 100644 --- a/tests/system/feature_store/test_google_big_query.py +++ b/tests/system/feature_store/test_google_big_query.py @@ -30,7 +30,7 @@ ) -def _resolve_google_credentials_json_path() -> typing.Optional[pathlib.Path]: +def resolve_google_credentials_json_path() -> typing.Optional[pathlib.Path]: default_path = pathlib.Path(CREDENTIALS_JSON_DEFAULT_PATH) if os.getenv(CREDENTIALS_ENV): return pathlib.Path(os.getenv(CREDENTIALS_ENV)) @@ -39,71 +39,26 @@ def _resolve_google_credentials_json_path() -> typing.Optional[pathlib.Path]: return None -def _are_google_credentials_not_set() -> bool: - # credentials_path = 
_resolve_google_credentials_json_path() - # return not credentials_path - - # Once issues with installation of packages - 'google-cloud-bigquery' and 'six' - will be fixed - # uncomment the above and let the tests run. - return True +def are_google_credentials_not_set() -> bool: + credentials_path = resolve_google_credentials_json_path() + return not credentials_path # Marked as enterprise because of v3io mount and pipelines @TestMLRunSystem.skip_test_if_env_not_configured @pytest.mark.skipif( - _are_google_credentials_not_set(), + are_google_credentials_not_set(), reason=f"Environment variable {CREDENTIALS_ENV} is not defined, and credentials file not in default path" f" {CREDENTIALS_JSON_DEFAULT_PATH}, skipping...", ) @pytest.mark.enterprise class TestFeatureStoreGoogleBigQuery(TestMLRunSystem): project_name = "fs-system-test-google-big-query" + max_results = 100 - def test_big_query_source_query(self): - max_results = 100 - query_string = f"select *\nfrom `bigquery-public-data.chicago_taxi_trips.taxi_trips`\nlimit {max_results}" - source = BigQuerySource( - "BigQuerySource", - query=query_string, - materialization_dataset="chicago_taxi_trips", - ) - self._test_big_query_source("query", source, max_results) - - def test_big_query_source_query_with_chunk_size(self): - max_results = 100 - query_string = f"select *\nfrom `bigquery-public-data.chicago_taxi_trips.taxi_trips`\nlimit {max_results * 2}" - source = BigQuerySource( - "BigQuerySource", - query=query_string, - materialization_dataset="chicago_taxi_trips", - chunksize=max_results, - ) - self._test_big_query_source("query_c", source, max_results) - - def test_big_query_source_table(self): - max_results = 100 - source = BigQuerySource( - "BigQuerySource", - table="bigquery-public-data.chicago_taxi_trips.taxi_trips", - max_results_for_table=max_results, - materialization_dataset="chicago_taxi_trips", - ) - self._test_big_query_source("table", source, max_results) - - def 
test_big_query_source_table_with_chunk_size(self): - max_results = 100 - source = BigQuerySource( - "BigQuerySource", - table="bigquery-public-data.chicago_taxi_trips.taxi_trips", - max_results_for_table=max_results * 2, - materialization_dataset="chicago_taxi_trips", - chunksize=max_results, - ) - self._test_big_query_source("table_c", source, max_results) - - @staticmethod - def _test_big_query_source(name: str, source: BigQuerySource, max_results: int): - credentials_path = _resolve_google_credentials_json_path() + @classmethod + def ingest_and_assert(cls, name: str, source: BigQuerySource): + credentials_path = resolve_google_credentials_json_path() os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = str(credentials_path) targets = [ @@ -120,8 +75,30 @@ def _test_big_query_source(name: str, source: BigQuerySource, max_results: int): timestamp_key="trip_start_timestamp", engine="pandas", ) - ingest_df = fstore.ingest(feature_set, source, targets, return_df=False) + ingest_df = fstore.ingest(feature_set, source, targets) assert ingest_df is not None - assert len(ingest_df) == max_results + assert len(ingest_df) == cls.max_results assert ingest_df.dtypes["pickup_latitude"] == "float64" assert ingest_df.dtypes["trip_seconds"] == pd.Int64Dtype() + + @pytest.mark.parametrize("chunksize", [None, 30]) + def test_big_query_source_query(self, chunksize): + query_string = f"select *\nfrom `bigquery-public-data.chicago_taxi_trips.taxi_trips`\nlimit {self.max_results}" + source = BigQuerySource( + "BigQuerySource", + query=query_string, + materialization_dataset="chicago_taxi_trips", + chunksize=chunksize, + ) + self.ingest_and_assert("query", source) + + @pytest.mark.parametrize("chunksize", [None, 50]) + def test_big_query_source_table(self, chunksize): + source = BigQuerySource( + "BigQuerySource", + table="bigquery-public-data.chicago_taxi_trips.taxi_trips", + max_results_for_table=self.max_results, + materialization_dataset="chicago_taxi_trips", + chunksize=chunksize, + ) + 
self.ingest_and_assert("table_c", source) diff --git a/tests/system/feature_store/test_spark_engine.py b/tests/system/feature_store/test_spark_engine.py index ebe347483f5a..2476c64121ee 100644 --- a/tests/system/feature_store/test_spark_engine.py +++ b/tests/system/feature_store/test_spark_engine.py @@ -15,6 +15,7 @@ import os import pathlib import sys +import tempfile import uuid from datetime import datetime @@ -53,11 +54,26 @@ # Marked as enterprise because of v3io mount and remote spark @pytest.mark.enterprise class TestFeatureStoreSparkEngine(TestMLRunSystem): + """ + This suite tests feature store functionality with the remote spark runtime (spark service). It does not test spark + operator. Make sure that, in env.yml, MLRUN_SYSTEM_TESTS_DEFAULT_SPARK_SERVICE is set to the name of a spark service + that exists on the remote system, or alternative set spark_service (below) to that name. + + To run the tests against code other than mlrun/mlrun@development, set test_branch below. + + After any tests have already run at least once, you may want to set spark_image_deployed=True (below) to avoid + rebuilding the image on subsequent runs, as it takes several minutes. + + It is also possible to run most tests in this suite locally if you have pyspark installed. To run locally, set + run_local=True. This can be very useful for debugging. 
+ """ + project_name = "fs-system-spark-engine" spark_service = "" pq_source = "testdata.parquet" pq_target = "testdata_target" csv_source = "testdata.csv" + run_local = False spark_image_deployed = ( False # Set to True if you want to avoid the image building phase ) @@ -66,7 +82,10 @@ class TestFeatureStoreSparkEngine(TestMLRunSystem): @classmethod def _init_env_from_file(cls): env = cls._get_env_from_file() - cls.spark_service = env["MLRUN_SYSTEM_TESTS_DEFAULT_SPARK_SERVICE"] + if cls.run_local: + cls.spark_service = None + else: + cls.spark_service = env["MLRUN_SYSTEM_TESTS_DEFAULT_SPARK_SERVICE"] @classmethod def get_local_pq_source_path(cls): @@ -80,22 +99,18 @@ def get_remote_pq_source_path(cls, without_prefix=False): path += "/bigdata/" + cls.pq_source return path - def _print_full_df(self, df: pd.DataFrame, df_name: str, passthrough: str) -> None: + @classmethod + def get_pq_source_path(cls): + if cls.run_local: + return cls.get_local_pq_source_path() + else: + return cls.get_remote_pq_source_path() + + def _print_full_df(self, df: pd.DataFrame, df_name: str, passthrough: bool) -> None: with pd.option_context("display.max_rows", None, "display.max_columns", None): self._logger.info(f"{df_name}-passthrough_{passthrough}:") self._logger.info(df) - def get_remote_pq_target_path(self, without_prefix=False, clean_up=True): - path = "v3io://" - if without_prefix: - path = "" - path += "/bigdata/" + self.pq_target - if clean_up: - fsys = fsspec.filesystem(v3iofs.fs.V3ioFS.protocol) - for f in fsys.listdir(path): - fsys._rm(f["name"]) - return path - @classmethod def get_local_csv_source_path(cls): return os.path.relpath(str(cls.get_assets_path() / cls.csv_source)) @@ -108,14 +123,26 @@ def get_remote_csv_source_path(cls, without_prefix=False): path += "/bigdata/" + cls.csv_source return path + @classmethod + def get_csv_source_path(cls): + if cls.run_local: + return cls.get_local_csv_source_path() + else: + return cls.get_remote_csv_source_path() + @classmethod 
def custom_setup_class(cls): + cls._init_env_from_file() + + if not cls.run_local: + cls._setup_remote_run() + + @classmethod + def _setup_remote_run(cls): from mlrun import get_run_db from mlrun.run import new_function from mlrun.runtimes import RemoteSparkRuntime - cls._init_env_from_file() - store, _ = store_manager.get_or_create_store(cls.get_remote_pq_source_path()) store.upload( cls.get_remote_pq_source_path(without_prefix=True), @@ -143,20 +170,24 @@ def custom_setup_class(cls): cls.spark_image_deployed = True @staticmethod - def read_parquet_and_assert(out_path_spark, out_path_storey): + def is_path_spark_metadata(path): + return path.endswith("/_SUCCESS") or path.endswith(".crc") + + @classmethod + def read_parquet_and_assert(cls, out_path_spark, out_path_storey): read_back_df_spark = None - file_system = fsspec.filesystem("v3io") + file_system = fsspec.filesystem("file" if cls.run_local else "v3io") for file_entry in file_system.ls(out_path_spark): - filepath = file_entry["name"] - if not filepath.endswith("/_SUCCESS"): - read_back_df_spark = pd.read_parquet(f"v3io://{filepath}") + filepath = file_entry if cls.run_local else f'v3io://{file_entry["name"]}' + if not cls.is_path_spark_metadata(filepath): + read_back_df_spark = pd.read_parquet(filepath) break assert read_back_df_spark is not None read_back_df_storey = None for file_entry in file_system.ls(out_path_storey): - filepath = file_entry["name"] - read_back_df_storey = pd.read_parquet(f"v3io://{filepath}") + filepath = file_entry if cls.run_local else f'v3io://{file_entry["name"]}' + read_back_df_storey = pd.read_parquet(filepath) break assert read_back_df_storey is not None @@ -166,34 +197,87 @@ def read_parquet_and_assert(out_path_spark, out_path_storey): # spark does not support indexes, so we need to reset the storey result to match it read_back_df_storey.reset_index(inplace=True) - assert read_back_df_spark.sort_index(axis=1).equals( - read_back_df_storey.sort_index(axis=1) + 
pd.testing.assert_frame_equal( + read_back_df_spark, + read_back_df_storey, + check_categorical=False, + check_like=True, ) + @classmethod + def read_csv(cls, csv_path: str) -> pd.DataFrame: + file_system = fsspec.filesystem("file" if cls.run_local else "v3io") + if file_system.isdir(csv_path): + for file_entry in file_system.ls(csv_path): + filepath = ( + file_entry if cls.run_local else f'v3io://{file_entry["name"]}' + ) + if not cls.is_path_spark_metadata(filepath): + return pd.read_csv(filepath) + else: + return pd.read_csv(csv_path) + raise AssertionError(f"No files found in {csv_path}") + @staticmethod def read_csv_and_assert(csv_path_spark, csv_path_storey): - read_back_df_spark = None - file_system = fsspec.filesystem("v3io") - for file_entry in file_system.ls(csv_path_spark): - filepath = file_entry["name"] - if not filepath.endswith("/_SUCCESS"): - read_back_df_spark = pd.read_csv(f"v3io://{filepath}") - break - assert read_back_df_spark is not None - - read_back_df_storey = None - for file_entry in file_system.ls(csv_path_storey): - filepath = file_entry["name"] - read_back_df_storey = pd.read_csv(f"v3io://{filepath}") - break - assert read_back_df_storey is not None + read_back_df_spark = TestFeatureStoreSparkEngine.read_csv( + csv_path=csv_path_spark + ) + read_back_df_storey = TestFeatureStoreSparkEngine.read_csv( + csv_path=csv_path_storey + ) read_back_df_storey = read_back_df_storey.dropna(axis=1, how="all") read_back_df_spark = read_back_df_spark.dropna(axis=1, how="all") - assert read_back_df_spark.sort_index(axis=1).equals( - read_back_df_storey.sort_index(axis=1) - ) + pd.testing.assert_frame_equal( + read_back_df_storey, + read_back_df_spark, + check_categorical=False, + check_like=True, + ) + + def setup_method(self, method): + super().setup_method(method) + if self.run_local: + self._tmpdir = tempfile.TemporaryDirectory() + + def teardown_method(self, method): + super().teardown_method(method) + if self.run_local: + self._tmpdir.cleanup() + 
+ def output_dir(self, url=True): + if self.run_local: + prefix = "file://" if url else "" + base_dir = f"{prefix}{self._tmpdir.name}" + else: + base_dir = f"v3io:///projects/{self.project_name}" + result = f"{base_dir}/spark-tests-output" + if self.run_local: + os.makedirs(result, exist_ok=True) + return result + + @staticmethod + def test_name(): + return ( + os.environ.get("PYTEST_CURRENT_TEST") + .split(":")[-1] + .split(" ")[0] + .replace("[", "__") + .replace("]", "") + ) + + def test_output_subdir_path(self, url=True): + return f"{self.output_dir(url=url)}/{self.test_name()}" + + def set_targets(self, feature_set, also_in_remote=False): + dir_name = self.test_name() + if self.run_local or also_in_remote: + target_path = f"{self.output_dir(url=False)}/{dir_name}" + feature_set.set_targets( + [ParquetTarget(path=target_path)], with_defaults=False + ) def test_basic_remote_spark_ingest(self): key = "patient_id" @@ -203,13 +287,14 @@ def test_basic_remote_spark_ingest(self): timestamp_key="timestamp", engine="spark", ) - source = ParquetSource("myparquet", path=self.get_remote_pq_source_path()) + source = ParquetSource("myparquet", path=self.get_pq_source_path()) + self.set_targets(measurements) fstore.ingest( measurements, source, return_df=True, spark_context=self.spark_service, - run_config=fstore.RunConfig(local=False), + run_config=fstore.RunConfig(local=self.run_local), ) assert measurements.status.targets[0].run_id is not None @@ -232,7 +317,7 @@ def test_basic_remote_spark_ingest_csv(self): measurements.graph.to(name="rename_column", handler="rename_column") source = CSVSource( "mycsv", - path=self.get_remote_csv_source_path(), + path=self.get_csv_source_path(), ) filename = str( pathlib.Path(sys.modules[self.__module__].__file__).absolute().parent @@ -240,8 +325,9 @@ def test_basic_remote_spark_ingest_csv(self): ) func = code_to_function("func", kind="remote-spark", filename=filename) run_config = fstore.RunConfig( - local=False, function=func, 
handler="ingest_handler" + local=self.run_local, function=func, handler="ingest_handler" ) + self.set_targets(measurements) fstore.ingest( measurements, source, @@ -277,13 +363,14 @@ def test_error_flow(self): df, return_df=True, spark_context=self.spark_service, - run_config=fstore.RunConfig(local=False), + run_config=fstore.RunConfig(local=self.run_local), ) def test_ingest_to_csv(self): key = "patient_id" - csv_path_spark = "v3io:///bigdata/test_ingest_to_csv_spark" - csv_path_storey = "v3io:///bigdata/test_ingest_to_csv_storey.csv" + base_path = self.test_output_subdir_path() + csv_path_spark = f"{base_path}_spark" + csv_path_storey = f"{base_path}_storey.csv" measurements = fstore.FeatureSet( "measurements_spark", @@ -291,14 +378,14 @@ def test_ingest_to_csv(self): timestamp_key="timestamp", engine="spark", ) - source = ParquetSource("myparquet", path=self.get_remote_pq_source_path()) + source = ParquetSource("myparquet", path=self.get_pq_source_path()) targets = [CSVTarget(name="csv", path=csv_path_spark)] fstore.ingest( measurements, source, targets, spark_context=self.spark_service, - run_config=fstore.RunConfig(local=False), + run_config=fstore.RunConfig(local=self.run_local), ) csv_path_spark = measurements.get_target_path(name="csv") @@ -307,7 +394,7 @@ def test_ingest_to_csv(self): entities=[fstore.Entity(key)], timestamp_key="timestamp", ) - source = ParquetSource("myparquet", path=self.get_remote_pq_source_path()) + source = ParquetSource("myparquet", path=self.get_pq_source_path()) targets = [CSVTarget(name="csv", path=csv_path_storey)] fstore.ingest( measurements, @@ -317,19 +404,17 @@ def test_ingest_to_csv(self): csv_path_storey = measurements.get_target_path(name="csv") read_back_df_spark = None - file_system = fsspec.filesystem("v3io") + file_system = fsspec.filesystem("file" if self.run_local else "v3io") for file_entry in file_system.ls(csv_path_spark): - filepath = file_entry["name"] - if not filepath.endswith("/_SUCCESS"): - 
read_back_df_spark = pd.read_csv(f"v3io://{filepath}") + filepath = file_entry if self.run_local else f'v3io://{file_entry["name"]}' + if not self.is_path_spark_metadata(filepath): + read_back_df_spark = pd.read_csv(filepath) break assert read_back_df_spark is not None - read_back_df_storey = None - for file_entry in file_system.ls(csv_path_storey): - filepath = file_entry["name"] - read_back_df_storey = pd.read_csv(f"v3io://{filepath}") - break + filepath = csv_path_storey if self.run_local else f"v3io://{csv_path_storey}" + read_back_df_storey = pd.read_csv(filepath) + assert read_back_df_storey is not None assert read_back_df_spark.sort_index(axis=1).equals( @@ -350,14 +435,14 @@ def test_ingest_to_redis(self): timestamp_key="timestamp", engine="spark", ) - source = ParquetSource("myparquet", path=self.get_remote_pq_source_path()) + source = ParquetSource("myparquet", path=self.get_pq_source_path()) targets = [RedisNoSqlTarget()] measurements.set_targets(targets, with_defaults=False) fstore.ingest( measurements, source, spark_context=self.spark_service, - run_config=fstore.RunConfig(False), + run_config=fstore.RunConfig(local=self.run_local), overwrite=True, ) # read the dataframe from the redis back @@ -384,6 +469,70 @@ def test_ingest_to_redis(self): } ] + @pytest.mark.skipif( + run_local, + reason="We don't normally have redis or v3io jars when running locally", + ) + @pytest.mark.parametrize( + "target_kind", + ["Redis", "v3io"] if mlrun.mlconf.redis.url else ["v3io"], + ) + def test_ingest_multiple_entities(self, target_kind): + key1 = "patient_id" + key2 = "bad" + key3 = "department" + name = "measurements_spark" + + measurements = fstore.FeatureSet( + name, + entities=[fstore.Entity(key1), fstore.Entity(key2), fstore.Entity(key3)], + timestamp_key="timestamp", + engine="spark", + ) + source = ParquetSource("myparquet", path=self.get_pq_source_path()) + if target_kind == "Redis": + targets = [RedisNoSqlTarget()] + else: + targets = [NoSqlTarget()] + 
measurements.set_targets(targets, with_defaults=False) + + fstore.ingest( + measurements, + source, + spark_context=self.spark_service, + run_config=fstore.RunConfig(local=self.run_local), + overwrite=True, + ) + # read the dataframe + vector = fstore.FeatureVector("myvector", features=[f"{name}.*"]) + with fstore.get_online_feature_service(vector) as svc: + resp = svc.get( + [ + { + "patient_id": "305-90-1613", + "bad": 95, + "department": "01e9fe31-76de-45f0-9aed-0f94cc97bca0", + } + ] + ) + assert resp == [ + { + "room": 2, + "hr": 220.0, + "hr_is_error": False, + "rr": 25, + "rr_is_error": False, + "spo2": 99, + "spo2_is_error": False, + "movements": 4.614601941071927, + "movements_is_error": False, + "turn_count": 0.3582583538239813, + "turn_count_is_error": False, + "is_in_bed": 1, + "is_in_bed_is_error": False, + } + ] + @pytest.mark.skipif( not mlrun.mlconf.redis.url, reason="mlrun.mlconf.redis.url is not set, skipping until testing against real redis", @@ -398,14 +547,14 @@ def test_ingest_to_redis_numeric_index(self): timestamp_key="timestamp", engine="spark", ) - source = ParquetSource("myparquet", path=self.get_remote_pq_source_path()) + source = ParquetSource("myparquet", path=self.get_pq_source_path()) targets = [RedisNoSqlTarget()] measurements.set_targets(targets, with_defaults=False) fstore.ingest( measurements, source, spark_context=self.spark_service, - run_config=fstore.RunConfig(False), + run_config=fstore.RunConfig(local=self.run_local), overwrite=True, ) # read the dataframe from the redis back @@ -433,14 +582,17 @@ def test_ingest_to_redis_numeric_index(self): ] # tests that data is filtered by time in scheduled jobs + @pytest.mark.skipif(run_local, reason="Local scheduling is not supported") @pytest.mark.parametrize("partitioned", [True, False]) def test_schedule_on_filtered_by_time(self, partitioned): name = f"sched-time-{str(partitioned)}" now = datetime.now() - path = "v3io:///bigdata/bla.parquet" - fsys = 
fsspec.filesystem(v3iofs.fs.V3ioFS.protocol) + path = f"{self.output_dir()}/bla.parquet" + fsys = fsspec.filesystem( + "file" if self.run_local else v3iofs.fs.V3ioFS.protocol + ) pd.DataFrame( { "time": [ @@ -470,7 +622,7 @@ def test_schedule_on_filtered_by_time(self, partitioned): NoSqlTarget(), ParquetTarget( name="tar1", - path="v3io:///bigdata/fs1/", + path=f"{self.output_dir()}/fs1/", partitioned=True, partition_cols=["time"], ), @@ -478,7 +630,7 @@ def test_schedule_on_filtered_by_time(self, partitioned): else: targets = [ ParquetTarget( - name="tar2", path="v3io:///bigdata/fs2/", partitioned=False + name="tar2", path=f"{self.output_dir()}/fs2/", partitioned=False ), NoSqlTarget(), ] @@ -486,7 +638,7 @@ def test_schedule_on_filtered_by_time(self, partitioned): fstore.ingest( feature_set, source, - run_config=fstore.RunConfig(local=False), + run_config=fstore.RunConfig(local=self.run_local), targets=targets, spark_context=self.spark_service, ) @@ -516,7 +668,7 @@ def test_schedule_on_filtered_by_time(self, partitioned): fstore.ingest( feature_set, source, - run_config=fstore.RunConfig(local=False), + run_config=fstore.RunConfig(local=self.run_local), targets=targets, spark_context=self.spark_service, ) @@ -562,8 +714,10 @@ def test_aggregations(self): } ) - path = "v3io:///bigdata/test_aggregations.parquet" - fsys = fsspec.filesystem(v3iofs.fs.V3ioFS.protocol) + path = f"{self.output_dir(url=False)}/test_aggregations.parquet" + fsys = fsspec.filesystem( + "file" if self.run_local else v3iofs.fs.V3ioFS.protocol + ) df.to_parquet(path=path, filesystem=fsys) source = ParquetSource("myparquet", path=path) @@ -655,12 +809,12 @@ def test_aggregations(self): windows="1h", period="10m", ) - + self.set_targets(data_set) fstore.ingest( data_set, source, spark_context=self.spark_service, - run_config=fstore.RunConfig(local=False), + run_config=fstore.RunConfig(local=self.run_local), ) features = [ @@ -831,8 +985,12 @@ def test_aggregations_emit_every_event(self): } ) - 
path = "v3io:///bigdata/test_aggregations_emit_every_event.parquet" - fsys = fsspec.filesystem(v3iofs.fs.V3ioFS.protocol) + path = ( + f"{self.output_dir(url=False)}/test_aggregations_emit_every_event.parquet" + ) + fsys = fsspec.filesystem( + "file" if self.run_local else v3iofs.fs.V3ioFS.protocol + ) df.to_parquet(path=path, filesystem=fsys) source = ParquetSource("myparquet", path=path) @@ -852,12 +1010,12 @@ def test_aggregations_emit_every_event(self): period="10m", emit_policy=EmitEveryEvent(), ) - + self.set_targets(data_set) fstore.ingest( data_set, source, spark_context=self.spark_service, - run_config=fstore.RunConfig(local=False), + run_config=fstore.RunConfig(local=self.run_local), ) print(f"Results:\n{data_set.to_dataframe().sort_values('time').to_string()}\n") @@ -917,8 +1075,11 @@ def test_aggregations_emit_every_event(self): def test_mix_of_partitioned_and_nonpartitioned_targets(self): name = "test_mix_of_partitioned_and_nonpartitioned_targets" - path = "v3io:///bigdata/bla.parquet" - fsys = fsspec.filesystem(v3iofs.fs.V3ioFS.protocol) + path = f"{self.output_dir(url=False)}/bla.parquet" + url = f"{self.output_dir()}/bla.parquet" + fsys = fsspec.filesystem( + "file" if self.run_local else v3iofs.fs.V3ioFS.protocol + ) pd.DataFrame( { "time": [ @@ -932,7 +1093,7 @@ def test_mix_of_partitioned_and_nonpartitioned_targets(self): source = ParquetSource( "myparquet", - path=path, + path=url, ) feature_set = fstore.FeatureSet( @@ -942,8 +1103,8 @@ def test_mix_of_partitioned_and_nonpartitioned_targets(self): engine="spark", ) - partitioned_output_path = "v3io:///bigdata/partitioned/" - nonpartitioned_output_path = "v3io:///bigdata/nonpartitioned/" + partitioned_output_path = f"{self.output_dir()}/partitioned/" + nonpartitioned_output_path = f"{self.output_dir()}/nonpartitioned/" targets = [ ParquetTarget( name="tar1", @@ -958,7 +1119,7 @@ def test_mix_of_partitioned_and_nonpartitioned_targets(self): fstore.ingest( feature_set, source, - 
run_config=fstore.RunConfig(local=False), + run_config=fstore.RunConfig(local=self.run_local), targets=targets, spark_context=self.spark_service, ) @@ -979,8 +1140,10 @@ def test_mix_of_partitioned_and_nonpartitioned_targets(self): def test_write_empty_dataframe_overwrite_false(self): name = "test_write_empty_dataframe_overwrite_false" - path = "v3io:///bigdata/test_write_empty_dataframe_overwrite_false.parquet" - fsys = fsspec.filesystem(v3iofs.fs.V3ioFS.protocol) + path = f"{self.output_dir(url=False)}/test_write_empty_dataframe_overwrite_false.parquet" + fsys = fsspec.filesystem( + "file" if self.run_local else v3iofs.fs.V3ioFS.protocol + ) empty_df = pd.DataFrame( { "time": [ @@ -1006,14 +1169,14 @@ def test_write_empty_dataframe_overwrite_false(self): target = ParquetTarget( name="pq", - path="v3io:///bigdata/test_write_empty_dataframe_overwrite_false/", + path=f"{self.output_dir()}/{self.test_name()}/", partitioned=False, ) fstore.ingest( feature_set, source, - run_config=fstore.RunConfig(local=False), + run_config=fstore.RunConfig(local=self.run_local), targets=[ target, ], @@ -1028,8 +1191,12 @@ def test_write_empty_dataframe_overwrite_false(self): def test_write_dataframe_overwrite_false(self): name = "test_write_dataframe_overwrite_false" - path = "v3io:///bigdata/test_write_dataframe_overwrite_false.parquet" - fsys = fsspec.filesystem(v3iofs.fs.V3ioFS.protocol) + path = ( + f"{self.output_dir(url=False)}/test_write_dataframe_overwrite_false.parquet" + ) + fsys = fsspec.filesystem( + "file" if self.run_local else v3iofs.fs.V3ioFS.protocol + ) df = pd.DataFrame( { "time": [ @@ -1055,14 +1222,14 @@ def test_write_dataframe_overwrite_false(self): target = ParquetTarget( name="pq", - path="v3io:///bigdata/test_write_dataframe_overwrite_false/", + path=f"{self.output_dir()}/{self.test_name()}/", partitioned=False, ) fstore.ingest( feature_set, source, - run_config=fstore.RunConfig(local=False), + run_config=fstore.RunConfig(local=self.run_local), targets=[ 
target, ], @@ -1081,24 +1248,26 @@ def test_write_dataframe_overwrite_false(self): "should_succeed, is_parquet, is_partitioned, target_path", [ # spark - csv - fail for single file - (True, False, None, "v3io:///bigdata/dif-eng/csv"), - (False, False, None, "v3io:///bigdata/dif-eng/file.csv"), + (True, False, None, "dif-eng/csv"), + (False, False, None, "dif-eng/file.csv"), # spark - parquet - fail for single file - (True, True, True, "v3io:///bigdata/dif-eng/pq"), - (False, True, True, "v3io:///bigdata/dif-eng/file.pq"), - (True, True, False, "v3io:///bigdata/dif-eng/pq"), - (False, True, False, "v3io:///bigdata/dif-eng/file.pq"), + (True, True, True, "dif-eng/pq"), + (False, True, True, "dif-eng/file.pq"), + (True, True, False, "dif-eng/pq"), + (False, True, False, "dif-eng/file.pq"), ], ) def test_different_paths_for_ingest_on_spark_engines( self, should_succeed, is_parquet, is_partitioned, target_path ): + target_path = f"{self.output_dir()}/{target_path}" + fset = FeatureSet("fsname", entities=[Entity("ticker")], engine="spark") - source = ( - "v3io:///bigdata/test_different_paths_for_ingest_on_spark_engines.parquet" + source = f"{self.output_dir(url=False)}/test_different_paths_for_ingest_on_spark_engines.parquet" + fsys = fsspec.filesystem( + "file" if self.run_local else v3iofs.fs.V3ioFS.protocol ) - fsys = fsspec.filesystem(v3iofs.fs.V3ioFS.protocol) stocks.to_parquet(path=source, filesystem=fsys) source = ParquetSource( "myparquet", @@ -1114,7 +1283,7 @@ def test_different_paths_for_ingest_on_spark_engines( if should_succeed: fstore.ingest( fset, - run_config=fstore.RunConfig(local=False), + run_config=fstore.RunConfig(local=self.run_local), spark_context=self.spark_service, source=source, targets=[target], @@ -1131,6 +1300,13 @@ def test_different_paths_for_ingest_on_spark_engines( fstore.ingest(fset, source=source, targets=[target]) def test_error_is_properly_propagated(self): + if self.run_local: + import pyspark.sql.utils + + expected_error = 
pyspark.sql.utils.AnalysisException + else: + expected_error = mlrun.runtimes.utils.RunError + key = "patient_id" measurements = fstore.FeatureSet( "measurements", @@ -1139,13 +1315,13 @@ def test_error_is_properly_propagated(self): engine="spark", ) source = ParquetSource("myparquet", path="wrong-path.pq") - with pytest.raises(mlrun.runtimes.utils.RunError): + with pytest.raises(expected_error): fstore.ingest( measurements, source, return_df=True, spark_context=self.spark_service, - run_config=fstore.RunConfig(local=False), + run_config=fstore.RunConfig(local=self.run_local), ) # ML-3092 @@ -1158,12 +1334,13 @@ def test_get_offline_features_with_filter_and_indexes(self, timestamp_key): timestamp_key=timestamp_key, engine="spark", ) - source = ParquetSource("myparquet", path=self.get_remote_pq_source_path()) + source = ParquetSource("myparquet", path=self.get_pq_source_path()) + self.set_targets(measurements) fstore.ingest( measurements, source, spark_context=self.spark_service, - run_config=fstore.RunConfig(local=False), + run_config=fstore.RunConfig(local=self.run_local), ) assert measurements.status.targets[0].run_id is not None fv_name = "measurements-fv" @@ -1177,13 +1354,15 @@ def test_get_offline_features_with_filter_and_indexes(self, timestamp_key): ) my_fv.spec.with_indexes = True my_fv.save() - target = ParquetTarget("mytarget", path=self.get_remote_pq_target_path()) + target = ParquetTarget( + "mytarget", path=f"{self.output_dir()}-get_offline_features" + ) resp = fstore.get_offline_features( fv_name, target=target, query="bad>6 and bad<8", engine="spark", - run_config=fstore.RunConfig(local=False, kind="remote-spark"), + run_config=fstore.RunConfig(local=self.run_local, kind="remote-spark"), spark_service=self.spark_service, ) resp_df = resp.to_dataframe() @@ -1219,17 +1398,19 @@ def test_get_offline_features_with_spark_engine(self, passthrough, target_type): engine="spark", passthrough=passthrough, ) - source = ParquetSource("myparquet", 
path=self.get_remote_pq_source_path()) + source = ParquetSource("myparquet", path=self.get_pq_source_path()) + self.set_targets(measurements) fstore.ingest( measurements, source, spark_context=self.spark_service, - run_config=fstore.RunConfig(local=False), + run_config=fstore.RunConfig(local=self.run_local), ) - assert measurements.status.targets[0].run_id is not None + if not self.run_local: + assert measurements.status.targets[0].run_id is not None # assert that online target exist (nosql) and offline target does not (parquet) - if passthrough: + if passthrough and not self.run_local: assert len(measurements.status.targets) == 1 assert isinstance(measurements.status.targets["nosql"], DataTarget) @@ -1245,13 +1426,13 @@ def test_get_offline_features_with_spark_engine(self, passthrough, target_type): my_fv.save() target = target_type( "mytarget", - path="v3io:///bigdata/test_get_offline_features_with_spark_engine_testdata_target/", + path=f"{self.output_dir()}-get_offline_features", ) resp = fstore.get_offline_features( fv_name, target=target, query="bad>6 and bad<8", - run_config=fstore.RunConfig(local=False, kind="remote-spark"), + run_config=fstore.RunConfig(local=self.run_local, kind="remote-spark"), engine="spark", spark_service=self.spark_service, ) @@ -1270,8 +1451,9 @@ def test_get_offline_features_with_spark_engine(self, passthrough, target_type): def test_ingest_with_steps_drop_features(self): key = "patient_id" - csv_path_spark = "v3io:///bigdata/test_ingest_to_csv_spark" - csv_path_storey = "v3io:///bigdata/test_ingest_to_csv_storey.csv" + base_path = self.test_output_subdir_path() + csv_path_spark = f"{base_path}_spark" + csv_path_storey = f"{base_path}_storey.csv" measurements = fstore.FeatureSet( "measurements_spark", @@ -1280,14 +1462,14 @@ def test_ingest_with_steps_drop_features(self): engine="spark", ) measurements.graph.to(DropFeatures(features=["bad"])) - source = ParquetSource("myparquet", path=self.get_remote_pq_source_path()) + source = 
ParquetSource("myparquet", path=self.get_pq_source_path()) targets = [CSVTarget(name="csv", path=csv_path_spark)] fstore.ingest( measurements, source, targets, spark_context=self.spark_service, - run_config=fstore.RunConfig(local=False), + run_config=fstore.RunConfig(local=self.run_local), ) csv_path_spark = measurements.get_target_path(name="csv") @@ -1297,7 +1479,7 @@ def test_ingest_with_steps_drop_features(self): timestamp_key="timestamp", ) measurements.graph.to(DropFeatures(features=["bad"])) - source = ParquetSource("myparquet", path=self.get_remote_pq_source_path()) + source = ParquetSource("myparquet", path=self.get_pq_source_path()) targets = [CSVTarget(name="csv", path=csv_path_storey)] fstore.ingest( measurements, @@ -1314,7 +1496,7 @@ def test_ingest_with_steps_drop_features(self): engine="spark", ) measurements.graph.to(DropFeatures(features=[key])) - source = ParquetSource("myparquet", path=self.get_remote_pq_source_path()) + source = ParquetSource("myparquet", path=self.get_pq_source_path()) key_as_set = {key} with pytest.raises( mlrun.errors.MLRunInvalidArgumentError, @@ -1324,13 +1506,14 @@ def test_ingest_with_steps_drop_features(self): measurements, source, spark_context=self.spark_service, - run_config=fstore.RunConfig(local=False), + run_config=fstore.RunConfig(local=self.run_local), ) def test_ingest_with_steps_onehot(self): key = "patient_id" - csv_path_spark = "v3io:///bigdata/test_ingest_to_csv_spark" - csv_path_storey = "v3io:///bigdata/test_ingest_to_csv_storey.csv" + base_path = self.test_output_subdir_path() + csv_path_spark = f"{base_path}_spark" + csv_path_storey = f"{base_path}_storey.csv" measurements = fstore.FeatureSet( "measurements_spark", @@ -1339,14 +1522,14 @@ def test_ingest_with_steps_onehot(self): engine="spark", ) measurements.graph.to(OneHotEncoder(mapping={"is_in_bed": [0, 1]})) - source = ParquetSource("myparquet", path=self.get_remote_pq_source_path()) + source = ParquetSource("myparquet", 
path=self.get_pq_source_path()) targets = [CSVTarget(name="csv", path=csv_path_spark)] fstore.ingest( measurements, source, targets, spark_context=self.spark_service, - run_config=fstore.RunConfig(local=False), + run_config=fstore.RunConfig(local=self.run_local), ) csv_path_spark = measurements.get_target_path(name="csv") @@ -1356,7 +1539,7 @@ def test_ingest_with_steps_onehot(self): timestamp_key="timestamp", ) measurements.graph.to(OneHotEncoder(mapping={"is_in_bed": [0, 1]})) - source = ParquetSource("myparquet", path=self.get_remote_pq_source_path()) + source = ParquetSource("myparquet", path=self.get_pq_source_path()) targets = [CSVTarget(name="csv", path=csv_path_storey)] fstore.ingest( measurements, @@ -1367,10 +1550,11 @@ def test_ingest_with_steps_onehot(self): self.read_csv_and_assert(csv_path_spark, csv_path_storey) @pytest.mark.parametrize("with_original_features", [True, False]) - def test_ingest_with_steps_mapval(self, with_original_features): + def test_ingest_with_steps_mapvalues(self, with_original_features): key = "patient_id" - csv_path_spark = "v3io:///bigdata/test_ingest_to_csv_spark" - csv_path_storey = "v3io:///bigdata/test_ingest_to_csv_storey.csv" + base_path = self.test_output_subdir_path() + csv_path_spark = f"{base_path}_spark" + csv_path_storey = f"{base_path}_storey.csv" measurements = fstore.FeatureSet( "measurements_spark", @@ -1387,14 +1571,14 @@ def test_ingest_with_steps_mapval(self, with_original_features): with_original_features=with_original_features, ) ) - source = ParquetSource("myparquet", path=self.get_remote_pq_source_path()) + source = ParquetSource("myparquet", path=self.get_pq_source_path()) targets = [CSVTarget(name="csv", path=csv_path_spark)] fstore.ingest( measurements, source, targets, spark_context=self.spark_service, - run_config=fstore.RunConfig(local=False), + run_config=fstore.RunConfig(local=self.run_local), ) csv_path_spark = measurements.get_target_path(name="csv") @@ -1412,7 +1596,7 @@ def 
test_ingest_with_steps_mapval(self, with_original_features): with_original_features=with_original_features, ) ) - source = ParquetSource("myparquet", path=self.get_remote_pq_source_path()) + source = ParquetSource("myparquet", path=self.get_pq_source_path()) targets = [CSVTarget(name="csv", path=csv_path_storey)] fstore.ingest( measurements, @@ -1422,11 +1606,83 @@ def test_ingest_with_steps_mapval(self, with_original_features): csv_path_storey = measurements.get_target_path(name="csv") self.read_csv_and_assert(csv_path_spark, csv_path_storey) + def test_mapvalues_with_partial_mapping(self): + # checks partial mapping -> only part of the values in field are replaced. + key = "patient_id" + csv_path_spark = self.test_output_subdir_path() + original_df = pd.read_parquet(self.get_pq_source_path()) + measurements = fstore.FeatureSet( + "measurements_spark", + entities=[fstore.Entity(key)], + timestamp_key="timestamp", + engine="spark", + ) + measurements.graph.to( + MapValues( + mapping={ + "bad": {17: -1}, + }, + with_original_features=True, + ) + ) + source = ParquetSource("myparquet", path=self.get_pq_source_path()) + targets = [CSVTarget(name="csv", path=csv_path_spark)] + fstore.ingest( + measurements, + source, + targets, + spark_context=self.spark_service, + run_config=fstore.RunConfig(local=self.run_local), + ) + csv_path_spark = measurements.get_target_path(name="csv") + df = self.read_csv(csv_path=csv_path_spark) + assert not df.empty + assert not df["bad_mapped"].isna().any() + assert not df["bad_mapped"].isnull().any() + assert not (df["bad_mapped"] == 17).any() + # Note that there are no occurrences of -1 in the "bad" field of the original DataFrame. 
+ assert len(df[df["bad_mapped"] == -1]) == len( + original_df[original_df["bad"] == 17] + ) + + def test_mapvalues_with_mixed_types(self): + key = "patient_id" + csv_path_spark = self.test_output_subdir_path() + measurements = fstore.FeatureSet( + "measurements_spark", + entities=[fstore.Entity(key)], + timestamp_key="timestamp", + engine="spark", + ) + measurements.graph.to( + MapValues( + mapping={ + "hr_is_error": {True: "1"}, + }, + with_original_features=True, + ) + ) + source = ParquetSource("myparquet", path=self.get_pq_source_path()) + targets = [CSVTarget(name="csv", path=csv_path_spark)] + with pytest.raises( + mlrun.runtimes.utils.RunError, + match="^MapValues - mapping that changes column type must change all values accordingly," + " which is not the case for column 'hr_is_error'$", + ): + fstore.ingest( + measurements, + source, + targets, + spark_context=self.spark_service, + run_config=fstore.RunConfig(local=self.run_local), + ) + @pytest.mark.parametrize("timestamp_col", [None, "timestamp"]) def test_ingest_with_steps_extractor(self, timestamp_col): key = "patient_id" - out_path_spark = "v3io:///bigdata/test_ingest_with_steps_extractor_spark" - out_path_storey = "v3io:///bigdata/test_ingest_with_steps_extractor_storey" + base_path = self.test_output_subdir_path() + out_path_spark = f"{base_path}_spark" + out_path_storey = f"{base_path}_storey" measurements = fstore.FeatureSet( "measurements_spark", @@ -1440,14 +1696,14 @@ def test_ingest_with_steps_extractor(self, timestamp_col): timestamp_col=timestamp_col, ) ) - source = ParquetSource("myparquet", path=self.get_remote_pq_source_path()) + source = ParquetSource("myparquet", path=self.get_pq_source_path()) targets = [ParquetTarget(path=out_path_spark)] fstore.ingest( measurements, source, targets, spark_context=self.spark_service, - run_config=fstore.RunConfig(local=False), + run_config=fstore.RunConfig(local=self.run_local), ) out_path_spark = measurements.get_target_path() @@ -1462,7 +1718,7 @@ 
def test_ingest_with_steps_extractor(self, timestamp_col): timestamp_col=timestamp_col, ) ) - source = ParquetSource("myparquet", path=self.get_remote_pq_source_path()) + source = ParquetSource("myparquet", path=self.get_pq_source_path()) targets = [ParquetTarget(path=out_path_storey)] fstore.ingest( measurements, @@ -1474,8 +1730,7 @@ def test_ingest_with_steps_extractor(self, timestamp_col): self.read_parquet_and_assert(out_path_spark, out_path_storey) @pytest.mark.parametrize("with_indexes", [True, False]) - @pytest.mark.parametrize("join_type", ["inner", "outer"]) - def test_relation_join(self, join_type, with_indexes): + def test_relation_join(self, with_indexes): """Test 3 option of using get offline feature with relations""" departments = pd.DataFrame( { @@ -1519,7 +1774,6 @@ def test_relation_join(self, join_type, with_indexes): join_employee_department = pd.merge( employees_with_department, departments, - how=join_type, left_on=["department_id"], right_on=["d_id"], suffixes=("_employees", "_departments"), @@ -1528,7 +1782,6 @@ def test_relation_join(self, join_type, with_indexes): join_employee_managers = pd.merge( join_employee_department, managers, - how=join_type, left_on=["manager_id"], right_on=["m_id"], suffixes=("_manage", "_"), @@ -1537,7 +1790,6 @@ def test_relation_join(self, join_type, with_indexes): join_employee_sets = pd.merge( employees_with_department, employees_with_class, - how=join_type, left_on=["id"], right_on=["id"], suffixes=("_employees", "_e_mini"), @@ -1546,7 +1798,6 @@ def test_relation_join(self, join_type, with_indexes): _merge_step = pd.merge( join_employee_department, employees_with_class, - how=join_type, left_on=["id"], right_on=["id"], suffixes=("_", "_e_mini"), @@ -1555,7 +1806,6 @@ def test_relation_join(self, join_type, with_indexes): join_all = pd.merge( _merge_step, classes, - how=join_type, left_on=["class_id"], right_on=["c_id"], suffixes=("_e_mini", "_cls"), @@ -1619,7 +1869,7 @@ def test_relation_join(self, 
join_type, with_indexes): "managers", entities=[managers_set_entity], ) - managers_set.set_targets(targets=["parquet"], with_defaults=False) + self.set_targets(managers_set, also_in_remote=True) fstore.ingest(managers_set, managers) classes_set_entity = fstore.Entity("c_id") @@ -1627,7 +1877,7 @@ def test_relation_join(self, join_type, with_indexes): "classes", entities=[classes_set_entity], ) - managers_set.set_targets(targets=["parquet"], with_defaults=False) + self.set_targets(classes_set, also_in_remote=True) fstore.ingest(classes_set, classes) departments_set_entity = fstore.Entity("d_id") @@ -1636,7 +1886,7 @@ def test_relation_join(self, join_type, with_indexes): entities=[departments_set_entity], relations={"manager_id": managers_set_entity}, ) - departments_set.set_targets(targets=["parquet"], with_defaults=False) + self.set_targets(departments_set, also_in_remote=True) fstore.ingest(departments_set, departments) employees_set_entity = fstore.Entity("id") @@ -1645,7 +1895,7 @@ def test_relation_join(self, join_type, with_indexes): entities=[employees_set_entity], relations={"department_id": departments_set_entity}, ) - employees_set.set_targets(targets=["parquet"], with_defaults=False) + self.set_targets(employees_set, also_in_remote=True) fstore.ingest(employees_set, employees_with_department) mini_employees_set = fstore.FeatureSet( @@ -1656,7 +1906,7 @@ def test_relation_join(self, join_type, with_indexes): "class_id": classes_set_entity, }, ) - mini_employees_set.set_targets(targets=["parquet"], with_defaults=False) + self.set_targets(mini_employees_set, also_in_remote=True) fstore.ingest(mini_employees_set, employees_with_class) features = ["employees.name"] @@ -1666,15 +1916,16 @@ def test_relation_join(self, join_type, with_indexes): ) vector.save() - target = ParquetTarget("mytarget", path=self.get_remote_pq_target_path()) + target = ParquetTarget( + "mytarget", path=f"{self.output_dir()}-get_offline_features" + ) resp = fstore.get_offline_features( 
vector, target=target, with_indexes=with_indexes, - run_config=fstore.RunConfig(local=False, kind="remote-spark"), + run_config=fstore.RunConfig(local=self.run_local, kind="remote-spark"), engine="spark", spark_service=self.spark_service, - join_type=join_type, order_by="name", ) if with_indexes: @@ -1694,15 +1945,16 @@ def test_relation_join(self, join_type, with_indexes): ) vector.save() - target = ParquetTarget("mytarget", path=self.get_remote_pq_target_path()) + target = ParquetTarget( + "mytarget", path=f"{self.output_dir()}-get_offline_features" + ) resp_1 = fstore.get_offline_features( vector, target=target, with_indexes=with_indexes, - run_config=fstore.RunConfig(local=False, kind="remote-spark"), + run_config=fstore.RunConfig(local=self.run_local, kind="remote-spark"), engine="spark", spark_service=self.spark_service, - join_type=join_type, order_by="n", ) assert_frame_equal(join_employee_department, resp_1.to_dataframe()) @@ -1718,15 +1970,16 @@ def test_relation_join(self, join_type, with_indexes): ) vector.save() - target = ParquetTarget("mytarget", path=self.get_remote_pq_target_path()) + target = ParquetTarget( + "mytarget", path=f"{self.output_dir()}-get_offline_features" + ) resp_2 = fstore.get_offline_features( vector, target=target, with_indexes=with_indexes, - run_config=fstore.RunConfig(local=False, kind="remote-spark"), + run_config=fstore.RunConfig(local=self.run_local, kind="remote-spark"), engine="spark", spark_service=self.spark_service, - join_type=join_type, order_by=["n"], ) assert_frame_equal(join_employee_managers, resp_2.to_dataframe()) @@ -1738,15 +1991,16 @@ def test_relation_join(self, join_type, with_indexes): ) vector.save() - target = ParquetTarget("mytarget", path=self.get_remote_pq_target_path()) + target = ParquetTarget( + "mytarget", path=f"{self.output_dir()}-get_offline_features" + ) resp_3 = fstore.get_offline_features( vector, target=target, with_indexes=with_indexes, - run_config=fstore.RunConfig(local=False, 
kind="remote-spark"), + run_config=fstore.RunConfig(local=self.run_local, kind="remote-spark"), engine="spark", spark_service=self.spark_service, - join_type=join_type, order_by="name", ) assert_frame_equal(join_employee_sets, resp_3.to_dataframe()) @@ -1763,15 +2017,16 @@ def test_relation_join(self, join_type, with_indexes): ) vector.save() - target = ParquetTarget("mytarget", path=self.get_remote_pq_target_path()) + target = ParquetTarget( + "mytarget", path=f"{self.output_dir()}-get_offline_features" + ) resp_4 = fstore.get_offline_features( vector, target=target, with_indexes=with_indexes, - run_config=fstore.RunConfig(local=False, kind="remote-spark"), + run_config=fstore.RunConfig(local=self.run_local, kind="remote-spark"), engine="spark", spark_service=self.spark_service, - join_type=join_type, order_by="n", ) assert_frame_equal(join_all, resp_4.to_dataframe()) @@ -1825,7 +2080,7 @@ def test_relation_asof_join(self, with_indexes): departments_set = fstore.FeatureSet( "departments", entities=[departments_set_entity], timestamp_key="time" ) - departments_set.set_targets(targets=["parquet"], with_defaults=False) + self.set_targets(departments_set, also_in_remote=True) fstore.ingest(departments_set, departments) employees_set_entity = fstore.Entity("id") @@ -1835,7 +2090,7 @@ def test_relation_asof_join(self, with_indexes): relations={"department_id": departments_set_entity}, timestamp_key="time", ) - employees_set.set_targets(targets=["parquet"], with_defaults=False) + self.set_targets(employees_set, also_in_remote=True) fstore.ingest(employees_set, employees_with_department) features = ["employees.name as n", "departments.name as n2"] @@ -1844,12 +2099,14 @@ def test_relation_asof_join(self, with_indexes): "employees-vec", features, description="Employees feature vector" ) vector.save() - target = ParquetTarget("mytarget", path=self.get_remote_pq_target_path()) + target = ParquetTarget( + "mytarget", path=f"{self.output_dir()}-get_offline_features" + ) resp_1 
= fstore.get_offline_features( vector, target=target, with_indexes=with_indexes, - run_config=fstore.RunConfig(local=False, kind="remote-spark"), + run_config=fstore.RunConfig(local=self.run_local, kind="remote-spark"), engine="spark", spark_service=self.spark_service, order_by=["n"], @@ -1860,7 +2117,8 @@ def test_relation_asof_join(self, with_indexes): resp_1.to_dataframe().sort_index(axis=1), ) - def test_as_of_join_result(self): + @pytest.mark.parametrize("ts_r", ["ts", "ts_r"]) + def test_as_of_join_result(self, ts_r): test_base_time = datetime.fromisoformat("2020-07-21T12:00:00+00:00") df_left = pd.DataFrame( @@ -1874,7 +2132,7 @@ def test_as_of_join_result(self): df_right = pd.DataFrame( { "ent": ["a", "a", "a", "b"], - "ts": [ + ts_r: [ test_base_time - pd.Timedelta(minutes=1), test_base_time - pd.Timedelta(minutes=2), test_base_time - pd.Timedelta(minutes=3), @@ -1884,54 +2142,147 @@ def test_as_of_join_result(self): } ) - left_path = "v3io:///bigdata/asof_join/df_left.parquet" - right_path = "v3io:///bigdata/asof_join/df_right.parquet" + expected_df = pd.DataFrame( + { + "f1": ["a-val", "b-val"], + "f2": ["newest", "only-value"], + } + ) + base_path = self.test_output_subdir_path(url=False) + left_path = f"{base_path}/df_left.parquet" + right_path = f"{base_path}/df_right.parquet" - fsys = fsspec.filesystem(v3iofs.fs.V3ioFS.protocol) + fsys = fsspec.filesystem( + "file" if self.run_local else v3iofs.fs.V3ioFS.protocol + ) + fsys.makedirs(base_path, exist_ok=True) df_left.to_parquet(path=left_path, filesystem=fsys) df_right.to_parquet(path=right_path, filesystem=fsys) fset1 = fstore.FeatureSet("fs1-as-of", entities=["ent"], timestamp_key="ts") - fset1.set_targets(["parquet"], with_defaults=False) - fset2 = fstore.FeatureSet("fs2-as-of", entities=["ent"], timestamp_key="ts") - fset2.set_targets(["parquet"], with_defaults=False) + self.set_targets(fset1, also_in_remote=True) + fset2 = fstore.FeatureSet("fs2-as-of", entities=["ent"], timestamp_key=ts_r) + 
self.set_targets(fset2, also_in_remote=True) - source_left = ParquetSource("pq1", path=left_path) - source_right = ParquetSource("pq2", path=right_path) + base_url = self.test_output_subdir_path() + left_url = f"{base_url}/df_left.parquet" + right_url = f"{base_url}/df_right.parquet" + + source_left = ParquetSource("pq1", path=left_url) + source_right = ParquetSource("pq2", path=right_url) fstore.ingest(fset1, source_left) fstore.ingest(fset2, source_right) - self._logger.info( - f"fset1 BEFORE LOCAL engine merger:\n {fset1.to_dataframe()}" - ) - self._logger.info( - f"fset2 BEFORE LOCAL engine merger:\n {fset2.to_dataframe()}" - ) - - vec = fstore.FeatureVector("vec1", ["fs1-as-of.*", "fs2-as-of.*"]) - - resp = fstore.get_offline_features(vec, engine="local") - local_engine_res = resp.to_dataframe().sort_index(axis=1) - - self._logger.info(f"fset1 AFTER LOCAL engine merger:\n {fset1.to_dataframe()}") - self._logger.info(f"fset2 AFTER LOCAL engine merger:\n {fset2.to_dataframe()}") - vec_for_spark = fstore.FeatureVector( "vec1-spark", ["fs1-as-of.*", "fs2-as-of.*"] ) - target = ParquetTarget("mytarget", path=self.get_remote_pq_target_path()) + target = ParquetTarget( + "mytarget", path=f"{self.output_dir()}-get_offline_features" + ) resp = fstore.get_offline_features( vec_for_spark, engine="spark", - run_config=fstore.RunConfig(local=False, kind="remote-spark"), + run_config=fstore.RunConfig(local=self.run_local, kind="remote-spark"), spark_service=self.spark_service, target=target, ) spark_engine_res = resp.to_dataframe().sort_index(axis=1) - self._logger.info(f"result of LOCAL engine merger:\n {local_engine_res}") - self._logger.info(f"result of SPARK engine merger:\n {spark_engine_res}") + assert_frame_equal(expected_df, spark_engine_res) + + @pytest.mark.parametrize( + "timestamp_for_filtering", + [None, "other_ts", "bad_ts", {"fs1": "other_ts"}, {"fs1": "bad_ts"}], + ) + @pytest.mark.parametrize("passthrough", [True, False]) + def test_time_filter(self, 
timestamp_for_filtering, passthrough): + test_base_time = datetime.fromisoformat("2020-07-21T12:00:00") + + df = pd.DataFrame( + { + "ent": ["a", "b", "c", "d"], + "ts_key": [ + test_base_time - pd.Timedelta(minutes=1), + test_base_time - pd.Timedelta(minutes=2), + test_base_time - pd.Timedelta(minutes=3), + test_base_time - pd.Timedelta(minutes=4), + ], + "other_ts": [ + test_base_time - pd.Timedelta(minutes=4), + test_base_time - pd.Timedelta(minutes=3), + test_base_time - pd.Timedelta(minutes=2), + test_base_time - pd.Timedelta(minutes=1), + ], + "val": [1, 2, 3, 4], + } + ) + + base_path = self.test_output_subdir_path(url=False) + path = f"{base_path}/df_for_filter.parquet" - assert spark_engine_res.shape == (2, 2) - assert local_engine_res.equals(spark_engine_res) + fsys = fsspec.filesystem( + "file" if self.run_local else v3iofs.fs.V3ioFS.protocol + ) + fsys.makedirs(base_path, exist_ok=True) + df.to_parquet(path=path, filesystem=fsys) + source = ParquetSource("pq1", path=path) + + fset1 = fstore.FeatureSet( + "fs1", entities=["ent"], timestamp_key="ts_key", passthrough=passthrough + ) + self.set_targets(fset1, also_in_remote=True) + + fstore.ingest(fset1, source) + + vec = fstore.FeatureVector("vec1", ["fs1.val"]) + + target = ParquetTarget( + "mytarget", path=f"{self.output_dir()}-get_offline_features" + ) + + if isinstance(timestamp_for_filtering, dict): + timestamp_for_filtering_str = timestamp_for_filtering["fs1"] + else: + timestamp_for_filtering_str = timestamp_for_filtering + if timestamp_for_filtering_str != "bad_ts": + resp = fstore.get_offline_features( + feature_vector=vec, + start_time=test_base_time - pd.Timedelta(minutes=3), + end_time=test_base_time, + timestamp_for_filtering=timestamp_for_filtering, + engine="spark", + run_config=fstore.RunConfig(local=self.run_local, kind="remote-spark"), + spark_service=self.spark_service, + target=target, + ) + res_df = resp.to_dataframe().sort_index(axis=1) + + if not timestamp_for_filtering_str: + assert 
res_df["val"].tolist() == [1, 2] + elif timestamp_for_filtering_str == "other_ts": + assert res_df["val"].tolist() == [3, 4] + + assert res_df.columns == ["val"] + else: + err = ( + mlrun.errors.MLRunInvalidArgumentError + if self.run_local + else mlrun.runtimes.utils.RunError + ) + with pytest.raises( + err, + match="Feature set `fs1` does not have a column named `bad_ts` to filter on.", + ): + fstore.get_offline_features( + feature_vector=vec, + start_time=test_base_time - pd.Timedelta(minutes=3), + end_time=test_base_time, + timestamp_for_filtering=timestamp_for_filtering, + engine="spark", + run_config=fstore.RunConfig( + local=self.run_local, kind="remote-spark" + ), + spark_service=self.spark_service, + target=target, + ) diff --git a/tests/system/feature_store/test_sql_db.py b/tests/system/feature_store/test_sql_db.py index 57ba4dbc71a7..ed9334d0ae03 100644 --- a/tests/system/feature_store/test_sql_db.py +++ b/tests/system/feature_store/test_sql_db.py @@ -99,12 +99,12 @@ def run_around_tests(self): engine.dispose() @pytest.mark.parametrize( - "source_name, key, time_fields", + "source_name, key, parse_dates", [("stocks", "ticker", None), ("trades", "ind", ["time"])], ) @pytest.mark.parametrize("fset_engine", ["pandas", "storey"]) def test_sql_source_basic( - self, source_name: str, key: str, time_fields: List[str], fset_engine: str + self, source_name: str, key: str, parse_dates: List[str], fset_engine: str ): from sqlalchemy_utils import create_database, database_exists @@ -123,7 +123,7 @@ def test_sql_source_basic( source = SQLSource( table_name=source_name, key_field=key, - time_fields=time_fields, + parse_dates=parse_dates, ) feature_set = fs.FeatureSet( @@ -162,7 +162,7 @@ def test_sql_source_with_step( source = SQLSource( table_name=source_name, key_field=key, - time_fields=["time"] if source_name == "quotes" else None, + parse_dates=["time"] if source_name == "quotes" else None, ) feature_set = fs.FeatureSet( f"fs-{source_name}", 
entities=[fs.Entity(key)], engine=fset_engine @@ -206,7 +206,7 @@ def test_sql_source_with_aggregation( ) # test source - source = SQLSource(table_name=source_name, key_field=key, time_fields=["time"]) + source = SQLSource(table_name=source_name, key_field=key, parse_dates=["time"]) feature_set = fs.FeatureSet(f"fs-{source_name}", entities=[fs.Entity(key)]) feature_set.add_aggregation( aggr_col, ["sum", "max"], "1h", "10m", name=f"{aggr_col}1" @@ -238,7 +238,7 @@ def test_sql_target_basic(self, target_name: str, key: str, fset_engine: str): create_table=True, schema=schema, primary_key_column=key, - time_fields=["time"], + parse_dates=["time"], ) feature_set = fs.FeatureSet( f"fs-{target_name}-tr", entities=[fs.Entity(key)], engine=fset_engine @@ -271,7 +271,7 @@ def test_sql_target_without_create( table_name=target_name, create_table=False, primary_key_column=key, - time_fields=["time"] if target_name == "trades" else None, + parse_dates=["time"] if target_name == "trades" else None, ) feature_set = fs.FeatureSet( f"fs-{target_name}-tr", entities=[fs.Entity(key)], engine=fset_engine @@ -299,7 +299,7 @@ def test_sql_get_online_feature_basic( create_table=True, schema=schema, primary_key_column=key, - time_fields=["time"], + parse_dates=["time"], ) feature_set = fs.FeatureSet( f"fs-{target_name}-tr", entities=[fs.Entity(key)], engine=fset_engine @@ -349,7 +349,7 @@ def test_sql_source_and_target_basic(self, name: str, key: str, fset_engine: str source = SQLSource( table_name=table_name, key_field=key, - time_fields=["time"] if name == "trades" else None, + parse_dates=["time"] if name == "trades" else None, ) target = SQLTarget( @@ -357,7 +357,7 @@ def test_sql_source_and_target_basic(self, name: str, key: str, fset_engine: str create_table=True, schema=schema, primary_key_column=key, - time_fields=["time"] if name == "trades" else None, + parse_dates=["time"] if name == "trades" else None, ) targets = [target] diff --git 
a/tests/system/model_monitoring/test_model_monitoring.py b/tests/system/model_monitoring/test_model_monitoring.py index 23348f4b750d..ab1c8e56e55a 100644 --- a/tests/system/model_monitoring/test_model_monitoring.py +++ b/tests/system/model_monitoring/test_model_monitoring.py @@ -28,12 +28,13 @@ import mlrun import mlrun.api.crud -import mlrun.api.schemas import mlrun.artifacts.model +import mlrun.common.model_monitoring as model_monitoring_constants +import mlrun.common.schemas import mlrun.feature_store -import mlrun.model_monitoring.constants as model_monitoring_constants import mlrun.utils -from mlrun.api.schemas import ( +from mlrun.common.model_monitoring import EndpointType, ModelMonitoringMode +from mlrun.common.schemas import ( ModelEndpoint, ModelEndpointMetadata, ModelEndpointSpec, @@ -42,7 +43,6 @@ from mlrun.errors import MLRunNotFoundError from mlrun.model import BaseMetadata from mlrun.runtimes import BaseRuntime -from mlrun.utils.model_monitoring import EndpointType from mlrun.utils.v3io_clients import get_frames_client from tests.system.base import TestMLRunSystem @@ -62,7 +62,7 @@ def test_clear_endpoint(self): db = mlrun.get_run_db() db.create_model_endpoint( - endpoint.metadata.project, endpoint.metadata.uid, endpoint + endpoint.metadata.project, endpoint.metadata.uid, endpoint.dict() ) endpoint_response = db.get_model_endpoint( @@ -88,14 +88,14 @@ def test_store_endpoint_update_existing(self): db.create_model_endpoint( project=endpoint.metadata.project, endpoint_id=endpoint.metadata.uid, - model_endpoint=endpoint, + model_endpoint=endpoint.dict(), ) endpoint_before_update = db.get_model_endpoint( project=endpoint.metadata.project, endpoint_id=endpoint.metadata.uid ) - assert endpoint_before_update.status.state is None + assert endpoint_before_update.status.state == "null" updated_state = "testing...testing...1 2 1 2" drift_status = "DRIFT_DETECTED" @@ -133,7 +133,7 @@ def test_store_endpoint_update_existing(self): def 
test_list_endpoints_on_empty_project(self): endpoints_out = mlrun.get_run_db().list_model_endpoints(self.project_name) - assert len(endpoints_out.endpoints) == 0 + assert len(endpoints_out) == 0 def test_list_endpoints(self): db = mlrun.get_run_db() @@ -145,13 +145,13 @@ def test_list_endpoints(self): for endpoint in endpoints_in: db.create_model_endpoint( - endpoint.metadata.project, endpoint.metadata.uid, endpoint + endpoint.metadata.project, endpoint.metadata.uid, endpoint.dict() ) endpoints_out = db.list_model_endpoints(self.project_name) in_endpoint_ids = set(map(lambda e: e.metadata.uid, endpoints_in)) - out_endpoint_ids = set(map(lambda e: e.metadata.uid, endpoints_out.endpoints)) + out_endpoint_ids = set(map(lambda e: e.metadata.uid, endpoints_out)) endpoints_intersect = in_endpoint_ids.intersection(out_endpoint_ids) assert len(endpoints_intersect) == number_of_endpoints @@ -176,32 +176,33 @@ def test_list_endpoints_filter(self): db.create_model_endpoint( endpoint_details.metadata.project, endpoint_details.metadata.uid, - endpoint_details, + endpoint_details.dict(), ) filter_model = db.list_model_endpoints(self.project_name, model="filterme") - assert len(filter_model.endpoints) == 1 - - filter_labels = db.list_model_endpoints( - self.project_name, labels=["filtermex=1"] - ) - assert len(filter_labels.endpoints) == 4 - - filter_labels = db.list_model_endpoints( - self.project_name, labels=["filtermex=1", "filtermey=2"] - ) - assert len(filter_labels.endpoints) == 4 - - filter_labels = db.list_model_endpoints( - self.project_name, labels=["filtermey=2"] - ) - assert len(filter_labels.endpoints) == 4 - - @staticmethod - def _get_auth_info() -> mlrun.api.schemas.AuthInfo: - return mlrun.api.schemas.AuthInfo( - data_session=os.environ.get("V3IO_ACCESS_KEY") - ) + assert len(filter_model) == 1 + + # TODO: Uncomment the following assertions once the KV labels filters is fixed. 
+ # Following the implementation of supporting SQL store for model endpoints records, this table + # has static schema. That means, in order to keep the schema logic for both SQL and KV, + # it is not possible to add new label columns dynamically to the KV table. Therefore, the label filtering + # process for the KV should be updated accordingly. + # + + # filter_labels = db.list_model_endpoints( + # self.project_name, labels=["filtermex=1"] + # ) + # assert len(filter_labels) == 4 + # + # filter_labels = db.list_model_endpoints( + # self.project_name, labels=["filtermex=1", "filtermey=2"] + # ) + # assert len(filter_labels) == 4 + # + # filter_labels = db.list_model_endpoints( + # self.project_name, labels=["filtermey=2"] + # ) + # assert len(filter_labels) == 4 def _mock_random_endpoint(self, state: Optional[str] = None) -> ModelEndpoint: def random_labels(): @@ -211,7 +212,9 @@ def random_labels(): return ModelEndpoint( metadata=ModelEndpointMetadata( - project=self.project_name, labels=random_labels() + project=self.project_name, + labels=random_labels(), + uid=str(randint(1000, 5000)), ), spec=ModelEndpointSpec( function_uri=f"test/function_{randint(0, 100)}:v{randint(0, 100)}", @@ -253,7 +256,7 @@ def test_basic_model_monitoring(self): # Import the serving function from the function hub serving_fn = mlrun.import_function( - "hub://v2_model_server", project=self.project_name + "hub://v2-model-server", project=self.project_name ).apply(mlrun.auto_mount()) # enable model monitoring serving_fn.set_tracking() @@ -279,9 +282,8 @@ def test_basic_model_monitoring(self): # Deploy the function serving_fn.deploy() - # Simulating Requests + # Simulating valid requests iris_data = iris["data"].tolist() - t_end = monotonic() + simulation_time while monotonic() < t_end: data_point = choice(iris_data) @@ -290,19 +292,19 @@ def test_basic_model_monitoring(self): ) sleep(uniform(0.2, 1.1)) - # test metrics + # Test metrics endpoints_list = 
mlrun.get_run_db().list_model_endpoints( self.project_name, metrics=["predictions_per_second"] ) - assert len(endpoints_list.endpoints) == 1 + assert len(endpoints_list) == 1 - endpoint = endpoints_list.endpoints[0] + endpoint = endpoints_list[0] assert len(endpoint.status.metrics) > 0 - predictions_per_second = endpoint.status.metrics["predictions_per_second"] - assert predictions_per_second.name == "predictions_per_second" - - total = sum((m[1] for m in predictions_per_second.values)) + predictions_per_second = endpoint.status.metrics["real_time"][ + "predictions_per_second" + ] + total = sum((m[1] for m in predictions_per_second)) assert total > 0 @@ -365,8 +367,8 @@ def test_model_monitoring_with_regression(self): fv, target=mlrun.datastore.targets.ParquetTarget() ) - # Train the model using the auto trainer from the marketplace - train = mlrun.import_function("hub://auto_trainer", new_name="train") + # Train the model using the auto trainer from the hub + train = mlrun.import_function("hub://auto-trainer", new_name="train") train.deploy() model_class = "sklearn.linear_model.LinearRegression" model_name = "diabetes_model" @@ -396,7 +398,7 @@ def test_model_monitoring_with_regression(self): # Set the serving topology to simple model routing # with data enrichment and imputing from the feature vector - serving_fn = mlrun.import_function("hub://v2_model_server", new_name="serving") + serving_fn = mlrun.import_function("hub://v2-model-server", new_name="serving") serving_fn.set_topology( "router", mlrun.serving.routers.EnrichmentModelRouter( @@ -430,14 +432,11 @@ def test_model_monitoring_with_regression(self): # Validate a single endpoint endpoints_list = mlrun.get_run_db().list_model_endpoints(self.project_name) - assert len(endpoints_list.endpoints) == 1 + assert len(endpoints_list) == 1 # Validate monitoring mode - model_endpoint = endpoints_list.endpoints[0] - assert ( - model_endpoint.spec.monitoring_mode - == 
mlrun.api.schemas.ModelMonitoringMode.enabled.value - ) + model_endpoint = endpoints_list[0] + assert model_endpoint.spec.monitoring_mode == ModelMonitoringMode.enabled.value # Validate tracking policy batch_job = db.get_schedule( @@ -446,7 +445,7 @@ def test_model_monitoring_with_regression(self): assert batch_job.cron_trigger.hour == "*/3" # TODO: uncomment the following assertion once the auto trainer function - # from mlrun marketplace is upgraded to 1.0.8 + # from mlrun hub is upgraded to 1.0.8 # assert len(model_obj.spec.feature_stats) == len( # model_endpoint.spec.feature_names # ) + len(model_endpoint.spec.label_names) @@ -482,6 +481,8 @@ def test_model_monitoring_voting_ensemble(self): # 2 - deployment status of monitoring stream nuclio function # 3 - model endpoints types for both children and router # 4 - metrics and drift status per model endpoint + # 5 - invalid records are considered in the aggregated error count value + # 6 - KV schema file is generated as expected simulation_time = 120 # 120 seconds to allow tsdb batching @@ -512,7 +513,7 @@ def test_model_monitoring_voting_ensemble(self): # Import the serving function from the function hub serving_fn = mlrun.import_function( - "hub://v2_model_server", project=self.project_name + "hub://v2-model-server", project=self.project_name ).apply(mlrun.auto_mount()) serving_fn.set_topology( @@ -529,11 +530,10 @@ def test_model_monitoring_voting_ensemble(self): "sklearn_AdaBoostClassifier": "sklearn.ensemble.AdaBoostClassifier", } - # Import the auto trainer function from the marketplace (hub://) - train = mlrun.import_function("hub://auto_trainer") + # Import the auto trainer function from the hub (hub://) + train = mlrun.import_function("hub://auto-trainer") for name, pkg in model_names.items(): - # Run the function and specify input dataset path and some parameters (algorithm and label column name) train_run = train.run( name=name, @@ -584,6 +584,15 @@ def test_model_monitoring_voting_ensemble(self): # 
invoke the model before running the model monitoring batch job iris_data = iris["data"].tolist() + # Simulating invalid request + invalid_input = ["n", "s", "o", "-"] + with pytest.raises(RuntimeError): + serving_fn.invoke( + "v2/models/VotingEnsemble/infer", + json.dumps({"inputs": [invalid_input]}), + ) + + # Simulating valid requests t_end = monotonic() + simulation_time start_time = datetime.now(timezone.utc) data_sent = 0 @@ -602,6 +611,9 @@ def test_model_monitoring_voting_ensemble(self): # it can take ~1 minute for the batch pod to finish running sleep(60) + # Check that the KV schema has been generated as expected + self._check_kv_schema_file() + tsdb_path = f"/pipelines/{self.project_name}/model-endpoints/events/" client = get_frames_client( token=os.environ.get("V3IO_ACCESS_KEY"), @@ -614,25 +626,23 @@ def test_model_monitoring_voting_ensemble(self): self.project_name, top_level=True ) - assert len(top_level_endpoints.endpoints) == 1 - assert ( - top_level_endpoints.endpoints[0].status.endpoint_type == EndpointType.ROUTER - ) + assert len(top_level_endpoints) == 1 + assert top_level_endpoints[0].status.endpoint_type == EndpointType.ROUTER - children_list = top_level_endpoints.endpoints[0].status.children_uids + children_list = top_level_endpoints[0].status.children_uids assert len(children_list) == len(model_names) endpoints_children_list = mlrun.get_run_db().list_model_endpoints( self.project_name, uids=children_list ) - assert len(endpoints_children_list.endpoints) == len(model_names) - for child in endpoints_children_list.endpoints: + assert len(endpoints_children_list) == len(model_names) + for child in endpoints_children_list: assert child.status.endpoint_type == EndpointType.LEAF_EP # list model endpoints and perform analysis for each endpoint endpoints_list = mlrun.get_run_db().list_model_endpoints(self.project_name) - for endpoint in endpoints_list.endpoints: + for endpoint in endpoints_list: # Validate that the model endpoint record has been 
updated through the stream process assert endpoint.status.first_request != endpoint.status.last_request data = client.read( @@ -691,7 +701,142 @@ def test_model_monitoring_voting_ensemble(self): assert measure in drift_measures assert type(drift_measures[measure]) == float + # Validate error count value + assert endpoint.status.error_count == 1 + def _check_monitoring_building_state(self, base_runtime): # Check if model monitoring stream function is ready stat = mlrun.get_run_db().get_builder_status(base_runtime) assert base_runtime.status.state == "ready", stat + + def _check_kv_schema_file(self): + """Check that the KV schema has been generated as expected""" + + # Initialize V3IO client object that will be used to retrieve the KV schema + client = mlrun.utils.v3io_clients.get_v3io_client( + endpoint=mlrun.mlconf.v3io_api + ) + + # Get the schema raw object + schema_raw = client.object.get( + container="users", + path=f"pipelines/{self.project_name}/model-endpoints/endpoints/.#schema", + access_key=os.environ.get("V3IO_ACCESS_KEY"), + ) + + # Convert the content into a dict + schema = json.loads(schema_raw.body) + + # Validate the schema key value + assert schema["key"] == model_monitoring_constants.EventFieldType.UID + + # Create a new dictionary of field_name:field_type out of the schema dictionary + fields_dict = {item["name"]: item["type"] for item in schema["fields"]} + + # Validate the type of several keys + assert fields_dict["error_count"] == "long" + assert fields_dict["function_uri"] == "string" + assert fields_dict["endpoint_type"] == "string" + assert fields_dict["active"] == "boolean" + + +@TestMLRunSystem.skip_test_if_env_not_configured +@pytest.mark.enterprise +class TestModelMonitoringKafka(TestMLRunSystem): + """Deploy a basic iris model configured with kafka stream""" + + brokers = ( + os.environ["MLRUN_SYSTEM_TESTS_KAFKA_BROKERS"] + if "MLRUN_SYSTEM_TESTS_KAFKA_BROKERS" in os.environ + and os.environ["MLRUN_SYSTEM_TESTS_KAFKA_BROKERS"] + else 
None + ) + project_name = "pr-kafka-model-monitoring" + + @pytest.mark.timeout(300) + @pytest.mark.skipif( + not brokers, reason="MLRUN_SYSTEM_TESTS_KAFKA_BROKERS not defined" + ) + def test_model_monitoring_with_kafka_stream(self): + project = mlrun.get_run_db().get_project(self.project_name) + + iris = load_iris() + train_set = pd.DataFrame( + iris["data"], + columns=[ + "sepal_length_cm", + "sepal_width_cm", + "petal_length_cm", + "petal_width_cm", + ], + ) + + # Import the serving function from the function hub + serving_fn = mlrun.import_function( + "hub://v2_model_server", project=self.project_name + ).apply(mlrun.auto_mount()) + + model_name = "sklearn_RandomForestClassifier" + + # Upload the model through the projects API so that it is available to the serving function + project.log_model( + model_name, + model_dir=os.path.relpath(self.assets_path), + model_file="model.pkl", + training_set=train_set, + artifact_path=f"v3io:///projects/{project.metadata.name}", + ) + # Add the model to the serving function's routing spec + serving_fn.add_model( + model_name, + model_path=project.get_artifact_uri( + key=model_name, category="model", tag="latest" + ), + ) + + project.set_model_monitoring_credentials(stream_path=f"kafka://{self.brokers}") + + # enable model monitoring + serving_fn.set_tracking() + # Deploy the function + serving_fn.deploy() + + monitoring_stream_fn = project.get_function("model-monitoring-stream") + + function_config = monitoring_stream_fn.spec.config + + # Validate kakfa stream trigger configurations + assert function_config["spec.triggers.kafka"] + assert ( + function_config["spec.triggers.kafka"]["attributes"]["topics"][0] + == f"monitoring_stream_{self.project_name}" + ) + assert ( + function_config["spec.triggers.kafka"]["attributes"]["brokers"][0] + == self.brokers + ) + + import kafka + + # Validate that the topic exist as expected + consumer = kafka.KafkaConsumer(bootstrap_servers=[self.brokers]) + topics = consumer.topics() + assert 
f"monitoring_stream_{self.project_name}" in topics + + # Simulating Requests + iris_data = iris["data"].tolist() + + for i in range(100): + data_point = choice(iris_data) + serving_fn.invoke( + f"v2/models/{model_name}/infer", json.dumps({"inputs": [data_point]}) + ) + sleep(uniform(0.02, 0.03)) + + # Validate that the model endpoint metrics were updated as indication for the sanity of the flow + model_endpoint = mlrun.get_run_db().list_model_endpoints( + project=self.project_name + )[0] + + assert model_endpoint.status.metrics["generic"]["latency_avg_5m"] > 0 + assert model_endpoint.status.metrics["generic"]["predictions_count_5m"] > 0 diff --git a/tests/system/projects/assets/handler_workflow.py b/tests/system/projects/assets/handler_workflow.py new file mode 100644 index 000000000000..0464d2c5b10d --- /dev/null +++ b/tests/system/projects/assets/handler_workflow.py @@ -0,0 +1,22 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from kfp import dsl + +funcs = {} + + +@dsl.pipeline(name="Demo training pipeline", description="Tests simple handler") +def job_pipeline(): + funcs["my-func"].as_step() diff --git a/tests/system/projects/assets/kflow.py b/tests/system/projects/assets/kflow.py index 87ef60c9083b..522f6926aa38 100644 --- a/tests/system/projects/assets/kflow.py +++ b/tests/system/projects/assets/kflow.py @@ -42,7 +42,7 @@ def kfpipeline(model_class=default_pkg_class, build=0): # train the model using a library (hub://) function and the generated data # no need to define handler in this step because the train function is the default handler - train = funcs["auto_trainer"].as_step( + train = funcs["auto-trainer"].as_step( name="train", inputs={"dataset": prep_data.outputs["cleaned_data"]}, params={ @@ -53,7 +53,7 @@ def kfpipeline(model_class=default_pkg_class, build=0): ) # test the model using a library (hub://) function and the generated model - funcs["auto_trainer"].as_step( + funcs["auto-trainer"].as_step( name="test", handler="evaluate", params={"label_columns": "label", "model": train.outputs["model"]}, diff --git a/tests/system/projects/assets/newflow.py b/tests/system/projects/assets/newflow.py index 6e0e9a96ec08..67b1dc69b70f 100644 --- a/tests/system/projects/assets/newflow.py +++ b/tests/system/projects/assets/newflow.py @@ -52,7 +52,7 @@ def newpipe(): # train with hyper-paremeters train = run_function( - "auto_trainer", + "auto-trainer", name="train", params={"label_columns": LABELS, "train_test_split_size": 0.10}, hyperparams={ @@ -70,7 +70,7 @@ def newpipe(): # test and visualize our model run_function( - "auto_trainer", + "auto-trainer", name="test", handler="evaluate", params={"label_columns": LABELS, "model": train.outputs["model"]}, @@ -87,7 +87,7 @@ def newpipe(): # test out new model server (via REST API calls), use imported function run_function( - "hub://v2_model_tester", + "hub://v2-model-tester", name="model-tester", params={"addr": 
deploy.outputs["endpoint"], "model": f"{DATASET}:v1"}, inputs={"table": train.outputs["test_set"]}, diff --git a/tests/system/projects/assets/sleep.py b/tests/system/projects/assets/sleep.py new file mode 100644 index 000000000000..3c49d566bccd --- /dev/null +++ b/tests/system/projects/assets/sleep.py @@ -0,0 +1,25 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import datetime +import time + + +def handler(context, time_to_sleep=1): + print("started", str(datetime.datetime.now())) + print(f"Sleeping for {time_to_sleep} seconds") + context.log_result("started", str(datetime.datetime.now())) + time.sleep(int(time_to_sleep)) + context.log_result("finished", str(datetime.datetime.now())) + print("finished", str(datetime.datetime.now())) diff --git a/tests/system/projects/assets/workflow.py b/tests/system/projects/assets/workflow.py new file mode 100644 index 000000000000..d80a604ed65a --- /dev/null +++ b/tests/system/projects/assets/workflow.py @@ -0,0 +1,27 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +import mlrun + + +def kfpipeline(): + time_to_sleep = 60 + + step_1 = mlrun.run_function( + "func-1", params={"time_to_sleep": time_to_sleep}, outputs=["return"] + ) + + mlrun.run_function( + "func-2", params={"time_to_sleep": time_to_sleep}, outputs=["return"] + ).after(step_1) diff --git a/tests/system/projects/test_project.py b/tests/system/projects/test_project.py index 881bb20bbd07..23237d4de04c 100644 --- a/tests/system/projects/test_project.py +++ b/tests/system/projects/test_project.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import io import os import pathlib import re @@ -23,9 +24,9 @@ from kfp import dsl import mlrun +import mlrun.utils.logger from mlrun.artifacts import Artifact from mlrun.model import EntrypointParam -from mlrun.utils import logger from tests.conftest import out_path from tests.system.base import TestMLRunSystem @@ -45,7 +46,7 @@ def exec_project(args): @dsl.pipeline(name="test pipeline", description="test") def pipe_test(): # train the model using a library (hub://) function and the generated data - funcs["auto_trainer"].as_step( + funcs["auto-trainer"].as_step( name="train", inputs={"dataset": data_url}, params={"model_class": model_class, "label_columns": "label"}, @@ -59,14 +60,25 @@ def pipe_test(): class TestProject(TestMLRunSystem): project_name = "project-system-test-project" custom_project_names_to_delete = [] + _logger_redirected = False def custom_setup(self): pass def custom_teardown(self): + if self._logger_redirected: + mlrun.utils.logger.replace_handler_stream("default", sys.stdout) + self._logger_redirected = False + + self._logger.debug( + "Deleting custom projects", + num_projects_to_delete=len(self.custom_project_names_to_delete), + ) for name in self.custom_project_names_to_delete: self._delete_test_project(name) + 
self.custom_project_names_to_delete = [] + @property def assets_path(self): return ( @@ -86,8 +98,8 @@ def _create_project(self, project_name, with_repo=False, overwrite=False): with_repo=with_repo, ) proj.set_function("hub://describe") - proj.set_function("hub://auto_trainer", "auto_trainer") - proj.set_function("hub://v2_model_server", "serving") + proj.set_function("hub://auto-trainer", "auto-trainer") + proj.set_function("hub://v2-model-server", "serving") proj.set_artifact("data", Artifact(target_path=data_url)) proj.spec.params = {"label_columns": "label"} arg = EntrypointParam( @@ -122,8 +134,38 @@ def test_project_persists_function_changes(self): == commands ) + def test_build_function_image_usability(self): + func_name = "my-func" + fn = self.project.set_function( + str(self.assets_path / "handler.py"), + func_name, + kind="job", + image="mlrun/mlrun", + ) + + # redirect logger to capture logs and check for warnings + self._logger_redirected = True + _stdout = io.StringIO() + mlrun.utils.logger.replace_handler_stream("default", _stdout) + + # build function with image that has a protocol prefix + self.project.build_function( + fn, + image=f"https://{mlrun.config.config.httpdb.builder.docker_registry}/test/image:v3", + base_image="mlrun/mlrun", + commands=["echo 1"], + ) + out = _stdout.getvalue() + assert ( + "[warning] The image has an unexpected protocol prefix ('http://' or 'https://'). " + "If you wish to use the default configured registry, no protocol prefix is required " + "(note that you can also use '.' instead of the full URL " + "where is a placeholder). " + "Removing protocol prefix from image." 
in out + ) + def test_run(self): - name = "pipe1" + name = "pipe0" self.custom_project_names_to_delete.append(name) # create project in context self._create_project(name) @@ -186,7 +228,7 @@ def test_run_git_load(self): project2 = mlrun.load_project( project_dir, "git://github.com/mlrun/project-demo.git#main", name=name ) - logger.info("run pipeline from git") + self._logger.info("run pipeline from git") # run project, load source into container at runtime project2.spec.load_source_on_run = True @@ -204,13 +246,12 @@ def test_run_git_build(self): project2 = mlrun.load_project( project_dir, "git://github.com/mlrun/project-demo.git#main", name=name ) - logger.info("run pipeline from git") + self._logger.info("run pipeline from git") project2.spec.load_source_on_run = False run = project2.run( "main", artifact_path=f"v3io:///projects/{name}", arguments={"build": 1}, - workflow_path=str(self.assets_path / "kflow.py"), ) run.wait_for_completion() assert run.state == mlrun.run.RunStatuses.succeeded, "pipeline failed" @@ -242,7 +283,7 @@ def test_run_cli(self): project_dir, ] out = exec_project(args) - print(out) + self._logger.debug("executed project", out=out) # load the project from local dir and change a workflow project2 = mlrun.load_project(project_dir) @@ -250,7 +291,7 @@ def test_run_cli(self): project2.spec.workflows = {} project2.set_workflow("kf", "./kflow.py") project2.save() - print(project2.to_yaml()) + self._logger.debug("saved project", project2=project2.to_yaml()) # exec the workflow args = [ @@ -282,7 +323,7 @@ def test_cli_with_remote(self): project_dir, ] out = exec_project(args) - print(out) + self._logger.debug("executed project", out=out) # exec the workflow args = [ @@ -426,7 +467,7 @@ def _test_new_pipeline(self, name, engine): handler="iris_generator", requirements=["requests"], ) - print(project.to_yaml()) + self._logger.debug("set project function", project=project.to_yaml()) run = project.run( "newflow", engine=engine, @@ -494,6 +535,15 @@ def 
test_remote_pipeline_with_local_engine_from_github(self): local=True, ) + def test_non_existent_run_id_in_pipeline(self): + project_name = "default" + db = mlrun.get_run_db() + + with pytest.raises(mlrun.errors.MLRunNotFoundError): + db.get_pipeline( + "25811259-6d21-4caf-86e8-badc0ffee000", project=project_name + ) + def test_remote_from_archive(self): name = "pipe6" self.custom_project_names_to_delete.append(name) @@ -502,7 +552,7 @@ def test_remote_from_archive(self): project.export(archive_path) project.spec.source = archive_path project.save() - print(project.to_yaml()) + self._logger.debug("saved project", project=project.to_yaml()) run = project.run( "main", watch=True, @@ -511,6 +561,38 @@ def test_remote_from_archive(self): assert run.state == mlrun.run.RunStatuses.succeeded, "pipeline failed" assert run.run_id, "workflow's run id failed to fetch" + def test_kfp_from_local_code(self): + name = "kfp-from-local-code" + self.custom_project_names_to_delete.append(name) + + # change cwd to the current file's dir to make sure the handler file is found + current_file_abspath = os.path.abspath(__file__) + current_dirname = os.path.dirname(current_file_abspath) + os.chdir(current_dirname) + + project = mlrun.get_or_create_project(name, user_project=True, context="./") + + handler_fn = project.set_function( + func="./assets/handler.py", + handler="my_func", + name="my-func", + kind="job", + image="mlrun/mlrun", + ) + project.build_function(handler_fn) + + project.set_workflow( + "main", "./assets/handler_workflow.py", handler="job_pipeline" + ) + project.save() + + run = project.run( + "main", + watch=True, + ) + assert run.state == mlrun.run.RunStatuses.succeeded, "pipeline failed" + assert run.run_id, "workflow's run id failed to fetch" + def test_local_cli(self): # load project from git name = "lclclipipe" @@ -523,7 +605,7 @@ def test_local_cli(self): handler="iris_generator", ) project.save() - print(project.to_yaml()) + self._logger.debug("saved project", 
project=project.to_yaml()) # exec the workflow args = [ @@ -539,7 +621,7 @@ def test_local_cli(self): str(self.assets_path), ] out = exec_project(args) - print("OUT:\n", out) + self._logger.debug("executed project", out=out) assert ( out.find("pipeline run finished, state=Succeeded") != -1 ), "pipeline failed" @@ -564,11 +646,13 @@ def test_run_cli_watch_with_timeout(self): ] out = exec_project(args) - print("OUT:\n", out) + self._logger.debug("executed project", out=out) assert ( out.find( - "Exception: failed to execute command by the given deadline. last_exception: " - "pipeline run has not completed yet, function_name: get_pipeline_if_completed, timeout: 1" + "failed to execute command by the given deadline. " + "last_exception: pipeline run has not completed yet, " + "function_name: _wait_for_pipeline_completion, timeout: 1, " + "caused by: pipeline run has not completed yet" ) != -1 ) @@ -818,3 +902,153 @@ def test_remote_workflow_source(self): def _assert_scheduled(self, project_name, schedule_str): schedule = self._run_db.get_schedule(project_name, "main") assert schedule.scheduled_object["schedule"] == schedule_str + + def test_remote_workflow_source_with_subpath(self): + # Test running remote workflow when the project files are store in a relative path (the subpath) + project_source = "git://github.com/mlrun/system-tests.git#main" + project_context = "./test_subpath_remote" + project_name = "test-remote-workflow-source-with-subpath" + self.custom_project_names_to_delete.append(project_name) + project = mlrun.load_project( + context=project_context, + url=project_source, + subpath="./test_remote_workflow_subpath", + name=project_name, + ) + project.run("main", arguments={"x": 1}, engine="remote:kfp", watch=True) + + @pytest.mark.parametrize("pull_state_mode", ["disabled", "enabled"]) + def test_abort_step_in_workflow(self, pull_state_mode): + project_name = "test-abort-step" + self.custom_project_names_to_delete.append(project_name) + project = 
mlrun.new_project(project_name, context=str(self.assets_path)) + + # when pull_state mode is enabled it simulates the flow of wait_for_completion + mlrun.mlconf.httpdb.logs.pipelines.pull_state.mode = pull_state_mode + + code_path = str(self.assets_path / "sleep.py") + workflow_path = str(self.assets_path / "workflow.py") + + project.set_function( + name="func-1", + func=code_path, + kind="job", + image="mlrun/mlrun", + handler="handler", + ) + project.set_function( + name="func-2", + func=code_path, + kind="job", + image="mlrun/mlrun", + handler="handler", + ) + + def _assert_workflow_status(workflow, status): + assert workflow.state == status + + # set and run a two-step workflow in the project + project.set_workflow("main", workflow_path) + workflow = project.run("main", engine="kfp") + + mlrun.utils.retry_until_successful( + 1, + 20, + self._logger, + True, + _assert_workflow_status, + workflow, + mlrun.run.RunStatuses.running, + ) + + # obtain the first run in the workflow when it began running + runs = [] + while len(runs) != 1: + runs = project.list_runs( + labels=[f"workflow={workflow.run_id}"], state="running" + ) + + # abort the first workflow step + db = mlrun.get_run_db() + db.abort_run(runs.to_objects()[0].uid()) + + # when a step is aborted, assert that the entire workflow failed and did not continue + mlrun.utils.retry_until_successful( + 5, + 60, + self._logger, + True, + _assert_workflow_status, + workflow, + mlrun.run.RunStatuses.failed, + ) + + def test_project_build_image(self): + name = "test-build-image" + self.custom_project_names_to_delete.append(name) + project = mlrun.new_project(name, context=str(self.assets_path)) + + image_name = ".test-custom-image" + project.build_image( + image=image_name, + set_as_default=True, + with_mlrun=False, + base_image="mlrun/mlrun", + requirements=["vaderSentiment"], + commands=["echo 1"], + ) + + assert project.default_image == image_name + + # test with user provided function object + 
project.set_function( + str(self.assets_path / "sentiment.py"), + name="scores", + kind="job", + handler="handler", + ) + + run_result = project.run_function("scores", params={"text": "good morning"}) + assert run_result.output("score") + + def test_project_build_config_export_import(self): + # Verify that the build config is exported properly by the project, and a new project loaded from it + # can build default image directly without needing additional details. + + name_export = "test-build-image-export" + name_import = "test-build-image-import" + self.custom_project_names_to_delete.extend([name_export, name_import]) + + project = mlrun.new_project(name_export, context=str(self.assets_path)) + image_name = ".test-export-custom-image" + + project.build_config( + image=image_name, + set_as_default=True, + with_mlrun=False, + base_image="mlrun/mlrun", + requirements=["vaderSentiment"], + commands=["echo 1"], + ) + assert project.default_image == image_name + + project_dir = f"{projects_dir}/{name_export}" + proj_file_path = project_dir + "/project.yaml" + project.export(proj_file_path) + + new_project = mlrun.load_project(project_dir, name=name_import) + new_project.build_image() + + new_project.set_function( + str(self.assets_path / "sentiment.py"), + name="scores", + kind="job", + handler="handler", + ) + + run_result = new_project.run_function( + "scores", params={"text": "terrible evening"} + ) + assert run_result.output("score") + + shutil.rmtree(project_dir, ignore_errors=True) diff --git a/tests/system/runtimes/assets/function_with_kwargs.py b/tests/system/runtimes/assets/function_with_kwargs.py new file mode 100644 index 000000000000..e41da76da4b0 --- /dev/null +++ b/tests/system/runtimes/assets/function_with_kwargs.py @@ -0,0 +1,20 @@ +# Copyright 2023 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +def func(context, x, **kwargs): + context.logger.info(x) + context.logger.info(kwargs) + if not kwargs: + raise Exception("kwargs is empty") + return kwargs diff --git a/tests/system/runtimes/test_kubejob.py b/tests/system/runtimes/test_kubejob.py index 76828dad2949..e115e2a68111 100644 --- a/tests/system/runtimes/test_kubejob.py +++ b/tests/system/runtimes/test_kubejob.py @@ -272,6 +272,23 @@ def test_new_function_with_args(self): "val-with-artifact", ] + def test_function_with_kwargs(self): + code_path = str(self.assets_path / "function_with_kwargs.py") + mlrun.get_or_create_project(self.project_name, self.results_path) + + function = mlrun.code_to_function( + name="function-with-kwargs", + kind="job", + project=self.project_name, + filename=code_path, + image="mlrun/mlrun", + ) + kwargs = {"some_arg": "a-value-123", "another_arg": "another-value-456"} + params = {"x": "2"} + params.update(kwargs) + run = function.run(params=params, handler="func") + assert run.outputs["return"] == kwargs + def test_class_handler(self): code_path = str(self.assets_path / "kubejob_function.py") cases = [ diff --git a/tests/system/runtimes/test_mpijob.py b/tests/system/runtimes/test_mpijob.py index 58fbc7671f5d..6703b99850a1 100644 --- a/tests/system/runtimes/test_mpijob.py +++ b/tests/system/runtimes/test_mpijob.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import pytest + import mlrun import tests.system.base from mlrun.runtimes.constants import RunStates @@ -21,6 +23,10 @@ class TestMpiJobRuntime(tests.system.base.TestMLRunSystem): project_name = "does-not-exist-mpijob" + # TODO: This test is failing in the open source system tests due to a lack of resources + # (running in git action worker with limited resources). + # This mark should be removed if we shift to a new CE testing environment with adequate resources + @pytest.mark.enterprise def test_run_state_completion(self): code_path = str(self.assets_path / "mpijob_function.py") diff --git a/tests/system/runtimes/test_notifications.py b/tests/system/runtimes/test_notifications.py index 120181a85f12..45f1f8688e14 100644 --- a/tests/system/runtimes/test_notifications.py +++ b/tests/system/runtimes/test_notifications.py @@ -23,7 +23,7 @@ class TestNotifications(tests.system.base.TestMLRunSystem): def test_run_notifications(self): error_notification_name = "slack-should-fail" - success_notification_name = "console-should-succeed" + success_notification_name = "slack-should-succeed" def _assert_notifications(): runs = self._run_db.list_runs( @@ -31,11 +31,13 @@ def _assert_notifications(): with_notifications=True, ) assert len(runs) == 1 - assert len(runs[0]["spec"]["notifications"]) == 2 - for notification in runs[0]["spec"]["notifications"]: - if notification["name"] == error_notification.name: + assert len(runs[0]["status"]["notifications"]) == 2 + for notification_name, notification in runs[0]["status"][ + "notifications" + ].items(): + if notification_name == error_notification.name: assert notification["status"] == "error" - elif notification["name"] == success_notification.name: + elif notification_name == success_notification.name: assert notification["status"] == "sent" error_notification = self._create_notification( @@ -46,9 +48,13 @@ def _assert_notifications(): }, ) success_notification = self._create_notification( - kind="console", + kind="slack", 
name=success_notification_name, message="should-succeed", + params={ + # dummy slack test url should return 200 + "webhook": "https://slack.com/api/api.test", + }, ) function = mlrun.new_function( @@ -67,12 +73,113 @@ def _assert_notifications(): # the notifications are sent asynchronously, so we need to wait for them mlrun.utils.retry_until_successful( 1, - 20, + 40, self._logger, True, _assert_notifications, ) + def test_set_run_notifications(self): + + notification_name = "slack-should-succeed" + + def _assert_notification_was_sent(): + runs = self._run_db.list_runs( + project=self.project_name, + with_notifications=True, + ) + assert len(runs) == 1 + assert len(runs[0]["status"]["notifications"]) == 1 + assert ( + runs[0]["status"]["notifications"][notification_name]["status"] + == "sent" + ) + + self._create_sleep_func_in_project() + + notification = self._create_notification( + name=notification_name, + message="should-succeed", + params={ + # dummy slack test url should return 200 + "webhook": "https://slack.com/api/api.test", + }, + ) + + run = self.project.run_function( + "test-sleep", local=False, params={"time_to_sleep": 10} + ) + self._run_db.set_run_notifications( + self.project_name, run.metadata.uid, [notification] + ) + + run.wait_for_completion() + + # the notifications are sent asynchronously, so we need to wait for them + mlrun.utils.retry_until_successful( + 1, + 40, + self._logger, + True, + _assert_notification_was_sent, + ) + + def test_set_schedule_notifications(self): + + notification_name = "slack-notification" + schedule_name = "test-sleep" + + def _assert_notification_in_schedule(): + schedule = self._run_db.get_schedule( + self.project_name, schedule_name, include_last_run=True + ) + schedule_spec = schedule.scheduled_object["task"]["spec"] + last_run = schedule.last_run + assert "notifications" in schedule_spec + assert len(schedule_spec["notifications"]) == 1 + assert schedule_spec["notifications"][0]["name"] == notification_name + + 
runs = self._run_db.list_runs( + uid=last_run["metadata"]["uid"], + project=self.project_name, + with_notifications=True, + ) + assert len(runs) == 1 + assert len(runs[0]["status"]["notifications"]) == 1 + assert ( + runs[0]["status"]["notifications"][notification_name]["status"] + == "sent" + ) + + self._create_sleep_func_in_project() + + notification = self._create_notification( + name=notification_name, + message="should-succeed", + params={ + # dummy slack test url should return 200 + "webhook": "https://slack.com/api/api.test", + }, + ) + + self.project.run_function( + "test-sleep", + local=False, + params={"time_to_sleep": 1}, + schedule="* * * * *", + ) + self._run_db.set_schedule_notifications( + self.project_name, schedule_name, [notification] + ) + + mlrun.utils.retry_until_successful( + 1, + 2 * 60, # 2 schedule cycles, so at least one should run + self._logger, + True, + _assert_notification_in_schedule, + ) + @staticmethod def _create_notification( kind=None, @@ -92,3 +199,19 @@ def _create_notification( severity=severity or "info", params=params or {}, ) + + def _create_sleep_func_in_project(self): + + code_path = str(self.assets_path / "sleep.py") + + sleep_func = mlrun.code_to_function( + name="test-sleep", + kind="job", + project=self.project_name, + filename=code_path, + image="mlrun/mlrun", + ) + self.project.set_function(sleep_func) + self.project.sync_functions(save=True) + + return sleep_func diff --git a/tests/system/runtimes/test_nuclio.py b/tests/system/runtimes/test_nuclio.py index a8127dea6fec..9d5eae7eba16 100644 --- a/tests/system/runtimes/test_nuclio.py +++ b/tests/system/runtimes/test_nuclio.py @@ -48,11 +48,8 @@ def test_deploy_function_with_error_handler(self): ) graph = function.set_topology("flow", engine="async") - graph.to(name="step1", handler="inc") - graph.add_step(name="catcher", handler="catcher", full_event=True, after="") - - graph.error_handler("catcher") + graph.error_handler("catcher", handler="catcher", 
full_event=True) self._logger.debug("Deploying nuclio function") deployment = function.deploy() diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 000000000000..89b0301bc087 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,48 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pathlib + +import mlrun.projects +from mlrun.__main__ import load_notification + + +def test_add_notification_to_cli_from_file(): + input_file_path = str(pathlib.Path(__file__).parent / "assets/notification.json") + notifications = (f"file={input_file_path}",) + project = mlrun.projects.MlrunProject(name="test") + load_notification(notifications, project) + + assert ( + project._notifiers._async_notifications["slack"].params.get("webhook") + == "123456" + ) + assert ( + project._notifiers._sync_notifications["ipython"].params.get("webhook") + == "1234" + ) + + +def test_add_notification_to_cli_from_dict(): + notifications = ('{"slack":{"webhook":"123456"}}', '{"ipython":{"webhook":"1234"}}') + project = mlrun.projects.MlrunProject(name="test") + load_notification(notifications, project) + + assert ( + project._notifiers._async_notifications["slack"].params.get("webhook") + == "123456" + ) + assert ( + project._notifiers._sync_notifications["ipython"].params.get("webhook") + == "1234" + ) diff --git a/tests/test_code_to_func.py b/tests/test_code_to_func.py index 6e111ad38b87..5cc4c920ee34 100644 --- a/tests/test_code_to_func.py +++ 
b/tests/test_code_to_func.py @@ -14,10 +14,8 @@ from os import path -from mlrun import code_to_function, get_run_db, new_model_server -from mlrun.runtimes.function import compile_function_config -from mlrun.utils import get_in, parse_versioned_object_uri -from tests.conftest import examples_path, results, tests_root_directory +from mlrun import code_to_function, new_model_server +from tests.conftest import examples_path, results def test_job_nb(): @@ -48,24 +46,6 @@ def test_nuclio_nb_serving(): assert fn.spec.build.origin_filename == filename, "did not record filename" -def test_job_file(): - filename = f"{examples_path}/training.py" - fn = code_to_function(filename=filename, kind="job") - assert fn.kind == "job", "kind not set, test failed" - assert fn.spec.build.functionSourceCode, "code not embedded" - assert fn.spec.build.origin_filename == filename, "did not record filename" - assert type(fn.metadata.labels) == dict, "metadata labels were not set" - run = fn.run(workdir=str(examples_path), local=True) - - project, uri, tag, hash_key = parse_versioned_object_uri(run.spec.function) - local_fn = get_run_db().get_function(uri, project, tag=tag, hash_key=hash_key) - assert local_fn["spec"]["command"] == filename, "wrong command path" - assert ( - local_fn["spec"]["build"]["functionSourceCode"] - == fn.spec.build.functionSourceCode - ), "code was not copied to local function" - - def test_job_file_noembed(): name = f"{examples_path}/training.py" fn = code_to_function(filename=name, kind="job", embed_code=False) @@ -106,24 +86,3 @@ def test_local_file_codeout(): assert path.isfile(out), "output not generated" fn.run(handler="training", params={"p1": 5}) - - -def test_nuclio_py(): - name = f"{examples_path}/training.py" - fn = code_to_function("nuclio", filename=name, kind="nuclio", handler="my_hand") - name, project, config = compile_function_config(fn) - assert fn.kind == "remote", "kind not set, test failed" - assert get_in(config, 
"spec.build.functionSourceCode"), "no source code" - assert get_in(config, "spec.runtime").startswith("py"), "runtime not set" - assert get_in(config, "spec.handler") == "training:my_hand", "wrong handler" - - -def test_nuclio_golang(): - name = f"{tests_root_directory}/assets/hello.go" - fn = code_to_function( - "nuclio", filename=name, kind="nuclio", handler="main:Handler" - ) - name, project, config = compile_function_config(fn) - assert fn.kind == "remote", "kind not set, test failed" - assert get_in(config, "spec.runtime") == "golang", "golang was not detected and set" - assert get_in(config, "spec.handler") == "main:Handler", "wrong handler" diff --git a/tests/test_config.py b/tests/test_config.py index f68b945eee2a..28e5bec41816 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest.mock from contextlib import contextmanager from os import environ from tempfile import NamedTemporaryFile @@ -23,7 +24,7 @@ import mlrun.errors from mlrun import config as mlconf -from mlrun.api.schemas import SecurityContextEnrichmentModes +from mlrun.common.schemas import SecurityContextEnrichmentModes from mlrun.db.httpdb import HTTPRunDB namespace_env_key = f"{mlconf.env_prefix}NAMESPACE" @@ -100,6 +101,68 @@ def test_file(config): assert config.namespace == ns, "not populated from file" +@pytest.mark.parametrize( + "mlrun_dbpath,v3io_api,v3io_framesd,expected_v3io_api,expected_v3io_framesd", + ( + ( + "http://mlrun-api:8080", + "", + "", + "http://v3io-webapi:8081", + "http://framesd:8080", + ), + ( + "http://mlrun-api:8080", + "http://v3io-webapi:8081", + "", + "http://v3io-webapi:8081", + "http://framesd:8080", + ), + ( + "https://mlrun-api.default-tenant.app.somedev.cluster.amzn.com", + "", + "", + "https://webapi.default-tenant.app.somedev.cluster.amzn.com", + "https://framesd.default-tenant.app.somedev.cluster.amzn.com", + ), + 
( + "https://mlrun-api.default-tenant.app.somedev.cluster.amzn.com", + "https://webapi.default-tenant.app.somedev.cluster.amzn.com", + "", + "https://webapi.default-tenant.app.somedev.cluster.amzn.com", + "https://framesd.default-tenant.app.somedev.cluster.amzn.com", + ), + ( + "https://mlrun-api.default-tenant.app.somedev.cluster.amzn.com", + "", + "https://framesd.default-tenant.app.somedev.cluster.amzn.com", + "https://webapi.default-tenant.app.somedev.cluster.amzn.com", + "https://framesd.default-tenant.app.somedev.cluster.amzn.com", + ), + ), +) +def test_v3io_api_and_framesd_enrichment_from_dbpath( + config, + mlrun_dbpath, + v3io_api, + v3io_framesd, + expected_v3io_api, + expected_v3io_framesd, + monkeypatch, +): + with unittest.mock.patch.object(mlrun.db, "get_run_db", return_value=None): + env = { + "MLRUN_DBPATH": mlrun_dbpath, + "V3IO_API": v3io_api, + "V3IO_FRAMESD": v3io_framesd, + } + with patch_env(env): + mlconf.config.reload() + + assert config.v3io_api == expected_v3io_api + assert config.v3io_framesd == expected_v3io_framesd + + def test_env(config): ns = "orange" with patch_env({namespace_env_key: ns}): diff --git a/tests/test_datastores.py b/tests/test_datastores.py index 8f714f504dff..acd5b564cb76 100644 --- a/tests/test_datastores.py +++ b/tests/test_datastores.py @@ -52,7 +52,7 @@ def test_in_memory(): ), "failed to log in mem artifact" -def test_file(): +def test_file(rundb_mock): with TemporaryDirectory() as tmpdir: print(tmpdir) diff --git a/tests/test_execution.py b/tests/test_execution.py new file mode 100644 index 000000000000..d63851d54163 --- /dev/null +++ b/tests/test_execution.py @@ -0,0 +1,189 @@ +# Copyright 2018 Iguazio +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import datetime +import unittest.mock + +import pytest + +import mlrun +import mlrun.artifacts +import mlrun.errors +from tests.conftest import out_path + + +def test_local_context(rundb_mock): + project_name = "xtst" + mlrun.mlconf.artifact_path = out_path + context = mlrun.get_or_create_ctx("xx", project=project_name, upload_artifacts=True) + db = mlrun.get_run_db() + run = db.read_run(context._uid, project=project_name) + assert run["status"]["state"] == "running", "run status not updated in db" + + # calls __exit__ and commits the context + with context: + context.log_artifact("xx", body="123", local_path="a.txt") + context.log_model("mdl", body="456", model_file="mdl.pkl", artifact_path="+/mm") + context.get_param("p1", 1) + context.get_param("p2", "a string") + context.log_result("accuracy", 16) + context.set_label("label-key", "label-value") + context.set_annotation("annotation-key", "annotation-value") + context._set_input("input-key", "input-url") + + artifact = context.get_cached_artifact("xx") + artifact.format = "z" + context.update_artifact(artifact) + + assert context._state == "completed", "task did not complete" + + run = db.read_run(context._uid, project=project_name) + + # run state should not be updated by the context + assert run["status"]["state"] == "running", "run status was updated in db" + assert ( + run["status"]["artifacts"][0]["metadata"]["key"] == "xx" + ), "artifact not updated in db" + assert ( + run["status"]["artifacts"][0]["spec"]["format"] == "z" + ), "run/artifact attribute not updated in db" + assert 
run["status"]["artifacts"][1]["spec"]["target_path"].startswith( + out_path + ), "artifact not uploaded to subpath" + + db_artifact = db.read_artifact(artifact.db_key, project=project_name) + assert db_artifact["spec"]["format"] == "z", "artifact attribute not updated in db" + + assert run["spec"]["parameters"]["p1"] == 1, "param not updated in db" + assert run["spec"]["parameters"]["p2"] == "a string", "param not updated in db" + assert run["status"]["results"]["accuracy"] == 16, "result not updated in db" + assert run["metadata"]["labels"]["label-key"] == "label-value", "label not updated" + assert ( + run["metadata"]["annotations"]["annotation-key"] == "annotation-value" + ), "annotation not updated" + + assert run["spec"]["inputs"]["input-key"] == "input-url", "input not updated" + + +def test_context_from_dict_when_start_time_is_string(): + context = mlrun.get_or_create_ctx("ctx") + context_dict = context.to_dict() + context = mlrun.MLClientCtx.from_dict(context_dict) + assert isinstance(context._start_time, datetime.datetime) + + +@pytest.mark.parametrize( + "is_api", + [True, False], +) +def test_context_from_run_dict(is_api): + with unittest.mock.patch("mlrun.config.is_running_as_api", return_value=is_api): + run_dict = _generate_run_dict() + + # create run object from dict and dict again to mock the run serialization + run = mlrun.run.RunObject.from_dict(run_dict) + context = mlrun.MLClientCtx.from_dict(run.to_dict(), is_api=is_api) + + assert context.name == run_dict["metadata"]["name"] + assert context._project == run_dict["metadata"]["project"] + assert context._labels == run_dict["metadata"]["labels"] + assert context._annotations == run_dict["metadata"]["annotations"] + assert context.get_param("p1") == run_dict["spec"]["parameters"]["p1"] + assert context.get_param("p2") == run_dict["spec"]["parameters"]["p2"] + assert ( + context.labels["label-key"] == run_dict["metadata"]["labels"]["label-key"] + ) + assert ( + context.annotations["annotation-key"] 
+ == run_dict["metadata"]["annotations"]["annotation-key"] + ) + assert context.artifact_path == run_dict["spec"]["output_path"] + + +@pytest.mark.parametrize( + "state, error, expected_state", + [ + ("running", None, "completed"), + ("completed", None, "completed"), + (None, "error message", "error"), + (None, "", "error"), + ], +) +def test_context_set_state(rundb_mock, state, error, expected_state): + project_name = "test_context_error" + mlrun.mlconf.artifact_path = out_path + context = mlrun.get_or_create_ctx("xx", project=project_name, upload_artifacts=True) + db = mlrun.get_run_db() + run = db.read_run(context._uid, project=project_name) + assert run["status"]["state"] == "running", "run status not updated in db" + + # calls __exit__ and commits the context + with context: + context.set_state(execution_state=state, error=error, commit=False) + + assert context._state == expected_state, "task state was not set correctly" + assert context._error == error, "task error was not set" + + +@pytest.mark.parametrize( + "is_api", + [True, False], +) +def test_context_inputs(rundb_mock, is_api): + with unittest.mock.patch("mlrun.config.is_running_as_api", return_value=is_api): + run_dict = _generate_run_dict() + + # create run object from dict and dict again to mock the run serialization + run = mlrun.run.RunObject.from_dict(run_dict) + context = mlrun.MLClientCtx.from_dict(run.to_dict(), is_api=is_api) + assert ( + context.get_input("input-key").artifact_url + == run_dict["spec"]["inputs"]["input-key"] + ) + assert context._inputs["input-key"] == run_dict["spec"]["inputs"]["input-key"] + + key = "store-input" + url = run_dict["spec"]["inputs"][key] + assert context._inputs[key] == run_dict["spec"]["inputs"][key] + + # 'store-input' is a store artifact, store it in the db before getting it + artifact = mlrun.artifacts.Artifact(key, b"123") + rundb_mock.store_artifact(key, artifact.to_dict(), uid="123") + mlrun.datastore.store_manager.object( + url, + key, + 
project=run_dict["metadata"]["project"], + allow_empty_resources=True, + ) + context._allow_empty_resources = True + assert context.get_input(key).artifact_url == run_dict["spec"]["inputs"][key] + + +def _generate_run_dict(): + return { + "metadata": { + "name": "test-context-from-run-dict", + "project": "default", + "labels": {"label-key": "label-value"}, + "annotations": {"annotation-key": "annotation-value"}, + }, + "spec": { + "parameters": {"p1": 1, "p2": "a string"}, + "output_path": "test_artifact_path", + "inputs": { + "input-key": "input-url", + "store-input": "store://store-input", + }, + "allow_empty_resources": True, + }, + } diff --git a/tests/test_model.py b/tests/test_model.py index 56429c4a097d..c9b56fa71c90 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import mlrun.api.schemas +import mlrun.common.schemas import mlrun.runtimes def test_enum_yaml_dump(): function = mlrun.new_function("function-name", kind="job") - function.status.state = mlrun.api.schemas.FunctionState.ready + function.status.state = mlrun.common.schemas.FunctionState.ready print(function.to_yaml()) diff --git a/tests/test_notebooks.py b/tests/test_notebooks.py deleted file mode 100644 index d0eaefe2c1e6..000000000000 --- a/tests/test_notebooks.py +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright 2018 Iguazio -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -from collections import ChainMap -from os import environ -from pathlib import Path -from subprocess import run - -import pytest -import yaml - -here = Path(__file__).absolute().parent -root = here.parent -# Need to be in root for docker context -tmp_dockerfile = Path(root / "Dockerfile.mlrun-test-nb") -with (here / "Dockerfile.test-nb").open() as fp: - dockerfile_template = fp.read() -docker_tag = "mlrun/test-notebook" - - -def iter_notebooks(): - cfg_file = here / "notebooks.yml" - with cfg_file.open() as fp: - configs = yaml.safe_load(fp) - - for config in configs: - if "env" not in config: - config["env"] = {} - yield pytest.param(config, id=config["nb"]) - - -def args_from_env(env): - env = ChainMap(env, environ) - args, cmd = [], [] - for name in env: - if not name.startswith("MLRUN_"): - continue - value = env[name] - args.append(f"ARG {name}") - cmd.extend(["--build-arg", f"{name}={value}"]) - - args = "\n".join(args) - return args, cmd - - -@pytest.mark.parametrize("notebook", iter_notebooks()) -def test_notebook(notebook): - path = f'./examples/{notebook["nb"]}' - args, args_cmd = args_from_env(notebook["env"]) - deps = [] - for dep in notebook.get("pip", []): - deps.append(f"RUN python -m pip install {dep}") - pip = "\n".join(deps) - - code = dockerfile_template.format(notebook=path, args=args, pip=pip) - with tmp_dockerfile.open("w") as out: - out.write(code) - - cmd = ( - ["docker", "build", "--file", str(tmp_dockerfile), "--tag", docker_tag] - + args_cmd - + ["."] - ) - out = run(cmd, cwd=root) - assert out.returncode == 0, f"Failed building {out.stdout} {out.stderr}" diff --git a/tests/test_requirements.py b/tests/test_requirements.py index 65c083433182..081acf6eaa75 100644 --- a/tests/test_requirements.py +++ b/tests/test_requirements.py @@ -93,9 +93,8 @@ def test_requirement_specifiers_convention(): ignored_invalid_map = { # See comment near requirement for why we're limiting to patch changes only for all of these "kfp": {"~=1.8.0, <1.8.14"}, - 
"botocore": {">=1.20.106,<1.20.107"}, - "aiobotocore": {"~=1.4.0"}, - "storey": {"~=1.3.15"}, + "aiobotocore": {"~=2.4.2"}, + "storey": {"~=1.4.3"}, "bokeh": {"~=2.4, >=2.4.2"}, "typing-extensions": {">=3.10.0,<5"}, "sphinx": {"~=4.3.0"}, @@ -111,27 +110,27 @@ def test_requirement_specifiers_convention(): "v3io-generator": { " @ git+https://github.com/v3io/data-science.git#subdirectory=generator" }, - "fsspec": {"~=2021.8.1"}, - "adlfs": {"~=2021.8.1"}, - "s3fs": {"~=2021.8.1"}, - "gcsfs": {"~=2021.8.1"}, + "fsspec": {"~=2023.1.0"}, + "adlfs": {"~=2022.2.0"}, + "s3fs": {"~=2023.1.0"}, + "gcsfs": {"~=2023.1.0"}, "distributed": {"~=2021.11.2"}, "dask": {"~=2021.11.2"}, # All of these are actually valid, they just don't use ~= so the test doesn't "understand" that # TODO: make test smart enough to understand that - "urllib3": {">=1.25.4, <1.27"}, + "urllib3": {">=1.26.9, <1.27"}, "chardet": {">=3.0.2, <4.0"}, "numpy": {">=1.16.5, <1.23.0"}, "alembic": {"~=1.4,<1.6.0"}, - "boto3": {"~=1.9, <1.17.107"}, + "boto3": {"~=1.24.59"}, "dask-ml": {"~=1.4,<1.9.0"}, - "pyarrow": {">=10,<11"}, + "pyarrow": {">=10.0, <12"}, "nbclassic": {">=0.2.8"}, - "protobuf": {">=3.13, <3.20"}, "pandas": {"~=1.2, <1.5.0"}, "ipython": {">=7.0, <9.0"}, "importlib_metadata": {">=3.6"}, "gitpython": {"~=3.1, >= 3.1.30"}, + "orjson": {"~=3.3, <3.8.12"}, "pyopenssl": {">=23"}, "google-cloud-bigquery": {"[pandas, bqstorage]~=3.2"}, # plotly artifact body in 5.12.0 may contain chars that are not encodable in 'latin-1' encoding @@ -139,6 +138,8 @@ def test_requirement_specifiers_convention(): "plotly": {"~=5.4, <5.12.0"}, # used in tests "aioresponses": {"~=0.7"}, + # conda requirements since conda does not support ~= operator + "lightgbm": {">=3.0"}, } for ( @@ -169,6 +170,9 @@ def test_requirement_specifiers_inconsistencies(): # The empty specifier is from tests/runtimes/assets/requirements.txt which is there specifically to test the # scenario of requirements without version specifiers 
"python-dotenv": {"", "~=0.17.0"}, + # conda requirements since conda does not support ~= operator and + # since platform condition is not required for docker + "lightgbm": {"~=3.0", "~=3.0; platform_machine != 'arm64'", ">=3.0"}, } for ( diff --git a/tests/utils/logger/test_logger.py b/tests/utils/logger/test_logger.py index b7e45906c0a1..692df88a3b5d 100644 --- a/tests/utils/logger/test_logger.py +++ b/tests/utils/logger/test_logger.py @@ -92,3 +92,35 @@ def test_exception_with_stack(make_stream_logger): test_logger.exception("This is just a test") assert str(err) in stream.getvalue() assert "This is just a test" in stream.getvalue() + + +# Regression test for duplicate logs bug fixed in PR #3381 +def test_redundant_logger_creation(): + stream = StringIO() + logger1 = create_logger("debug", name="test-logger", stream=stream) + logger2 = create_logger("debug", name="test-logger", stream=stream) + logger3 = create_logger("debug", name="test-logger", stream=stream) + logger1.info("1") + assert stream.getvalue().count("[info] 1\n") == 1 + logger2.info("2") + assert stream.getvalue().count("[info] 2\n") == 1 + logger3.info("3") + assert stream.getvalue().count("[info] 3\n") == 1 + + +def test_child_logger(): + stream = StringIO() + logger = create_logger( + "debug", + name="test-logger", + stream=stream, + formatter_kind=FormatterKinds.HUMAN_EXTENDED.name, + ) + child_logger = logger.get_child("child") + logger.debug("") + child_logger.debug("") + log_lines = stream.getvalue().strip().splitlines() + + # validate parent and child log lines + assert "test-logger:debug" in log_lines[0] + assert "test-logger.child:debug" in log_lines[1] diff --git a/tests/utils/test_deprecation.py b/tests/utils/test_deprecation.py index 9f486b65bc5d..ccb1e97cb59d 100644 --- a/tests/utils/test_deprecation.py +++ b/tests/utils/test_deprecation.py @@ -64,3 +64,55 @@ def warn(): with pytest.raises(FutureWarning): warn() + + +def test_deprecation_helper(): + """ + This test validates that the 
deprecation warning is shown when using a deprecated class, and that the + object is created from the new class. + """ + import mlrun.api.schemas + import mlrun.common.schemas + + with warnings.catch_warnings(record=True) as w: + # create an object using the deprecated class + obj = mlrun.api.schemas.ObjectMetadata(name="name", project="project") + + # validate that the object is created from the new class + assert type(obj) == mlrun.common.schemas.ObjectMetadata + + # validate that the warning is shown + assert len(w) == 1 + assert ( + "mlrun.api.schemas.ObjectMetadata is deprecated in version 1.4.0, " + "Please use mlrun.common.schemas.ObjectMetadata instead." + in str(w[-1].message) + ) + + +def test_deprecated_schema_as_argument(): + """ + This test validates that the deprecation warning is shown when using a deprecated schema as an argument to a + function. And that the function still works, and the schema is converted to the new schema. + The test uses the get_secrets function as an example. + """ + import mlrun.api.api.utils + import mlrun.api.schemas + import mlrun.common.schemas + + data_session = "some-data-session" + + with warnings.catch_warnings(record=True) as w: + secrets = mlrun.api.api.utils.get_secrets( + auth_info=mlrun.api.schemas.AuthInfo(data_session=data_session), + ) + + assert "V3IO_ACCESS_KEY" in secrets + assert secrets["V3IO_ACCESS_KEY"] == data_session + + # validate that the warning is shown + assert len(w) == 1 + assert ( + "mlrun.api.schemas.AuthInfo is deprecated in version 1.4.0, " + "Please use mlrun.common.schemas.AuthInfo instead." 
in str(w[-1].message) + ) diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 5d88d767444a..c83ecc8de8f8 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -83,7 +83,7 @@ def _raise_fatal_failure(): pytest.raises(mlrun.errors.MLRunInvalidArgumentError), ), ( - # Invalid because it's more then 63 characters + # Invalid because it's more than 63 characters "azsxdcfvg-azsxdcfvg-azsxdcfvg-azsxdcfvg-azsxdcfvg-azsxdcfvg-azsx", pytest.raises(mlrun.errors.MLRunInvalidArgumentError), ), diff --git a/tests/utils/test_notifications.py b/tests/utils/test_notifications.py index f223c478bf8f..ba762a45a609 100644 --- a/tests/utils/test_notifications.py +++ b/tests/utils/test_notifications.py @@ -15,14 +15,53 @@ import asyncio import builtins import unittest.mock +from contextlib import nullcontext as does_not_raise import aiohttp import pytest import tabulate +import mlrun.api.api.utils +import mlrun.api.crud +import mlrun.common.schemas.notification import mlrun.utils.notifications +@pytest.mark.parametrize( + "notification_kind", mlrun.common.schemas.notification.NotificationKind +) +def test_load_notification(notification_kind): + run_uid = "test-run-uid" + notification_name = "test-notification-name" + when_state = "completed" + notification = mlrun.model.Notification.from_dict( + { + "kind": notification_kind, + "when": when_state, + "status": "pending", + "name": notification_name, + } + ) + run = mlrun.model.RunObject.from_dict( + { + "metadata": {"uid": run_uid}, + "spec": {"notifications": [notification]}, + "status": {"state": when_state}, + } + ) + + notification_pusher = ( + mlrun.utils.notifications.notification_pusher.NotificationPusher([run]) + ) + notification_pusher._load_notification(run, notification) + loaded_notifications = ( + notification_pusher._sync_notifications + + notification_pusher._async_notifications + ) + assert len(loaded_notifications) == 1 + assert loaded_notifications[0][0].name == 
notification_name + + @pytest.mark.parametrize( "when,condition,run_state,notification_previously_sent,expected", [ @@ -30,10 +69,10 @@ (["completed"], "", "completed", True, False), (["completed"], "", "error", False, False), (["completed"], "", "error", True, False), - (["completed"], "True", "completed", False, True), - (["completed"], "True", "completed", True, False), - (["completed"], "False", "completed", False, False), - (["completed"], "False", "completed", True, False), + (["completed"], "> 4", "completed", False, True), + (["completed"], "> 4", "completed", True, False), + (["completed"], "< 4", "completed", False, False), + (["completed"], "< 4", "completed", True, False), (["error"], "", "completed", False, False), (["error"], "", "completed", True, False), (["error"], "", "error", False, True), @@ -42,20 +81,25 @@ (["completed", "error"], "", "completed", True, False), (["completed", "error"], "", "error", False, True), (["completed", "error"], "", "error", True, False), - (["completed", "error"], "True", "completed", False, True), - (["completed", "error"], "True", "completed", True, False), - (["completed", "error"], "True", "error", False, True), - (["completed", "error"], "True", "error", True, False), - (["completed", "error"], "False", "completed", False, False), - (["completed", "error"], "False", "completed", True, False), - (["completed", "error"], "False", "error", False, True), - (["completed", "error"], "False", "error", True, False), + (["completed", "error"], "> 4", "completed", False, True), + (["completed", "error"], "> 4", "completed", True, False), + (["completed", "error"], "> 4", "error", False, True), + (["completed", "error"], "> 4", "error", True, False), + (["completed", "error"], "< 4", "completed", False, False), + (["completed", "error"], "< 4", "completed", True, False), + (["completed", "error"], "< 4", "error", False, True), + (["completed", "error"], "< 4", "error", True, False), ], ) def test_notification_should_notify( 
when, condition, run_state, notification_previously_sent, expected ): - run = mlrun.model.RunObject.from_dict({"status": {"state": run_state}}) + if condition: + condition = f'{{{{ run["status"]["results"]["val"] {condition} }}}}' + + run = mlrun.model.RunObject.from_dict( + {"status": {"state": run_state, "results": {"val": 5}}} + ) notification = mlrun.model.Notification.from_dict( { "when": when, @@ -64,13 +108,35 @@ def test_notification_should_notify( } ) - assert ( - mlrun.utils.notifications.notification_pusher.NotificationPusher._should_notify( - run, notification - ) - == expected + notification_pusher = ( + mlrun.utils.notifications.notification_pusher.NotificationPusher([run]) + ) + assert notification_pusher._should_notify(run, notification) == expected + + +def test_condition_evaluation_timeout(): + condition = """ + {% for i in range(100000) %} + {% for i in range(100000) %} + {% for i in range(100000) %} + {{ i }} + {% endfor %} + {% endfor %} + {% endfor %} + """ + + run = mlrun.model.RunObject.from_dict( + {"status": {"state": "completed", "results": {"val": 5}}} + ) + notification = mlrun.model.Notification.from_dict( + {"when": ["completed"], "condition": condition, "status": "pending"} ) + notification_pusher = ( + mlrun.utils.notifications.notification_pusher.NotificationPusher([run]) + ) + assert notification_pusher._should_notify(run, notification) + @pytest.mark.parametrize( "runs,expected,is_table", @@ -212,7 +278,19 @@ def test_slack_notification(runs, expected): "token": "test-token", "gitlab": True, }, - "https://gitlab.com/api/v4/projects/test-repo/merge_requests/test-issue/notes", + "https://gitlab.com/api/v4/projects/test-repo/issues/test-issue/notes", + { + "PRIVATE-TOKEN": "test-token", + }, + ), + ( + { + "repo": "test-repo", + "merge_request": "test-merge-request", + "token": "test-token", + "gitlab": True, + }, + "https://gitlab.com/api/v4/projects/test-repo/merge_requests/test-merge-request/notes", { "PRIVATE-TOKEN": 
"test-token", }, @@ -224,7 +302,7 @@ def test_slack_notification(runs, expected): "token": "test-token", "server": "custom-gitlab", }, - "https://custom-gitlab/api/v4/projects/test-repo/merge_requests/test-issue/notes", + "https://custom-gitlab/api/v4/projects/test-repo/issues/test-issue/notes", { "PRIVATE-TOKEN": "test-token", }, @@ -274,8 +352,8 @@ def test_inverse_dependencies( ] ) - mock_console_push = unittest.mock.MagicMock() - mock_ipython_push = unittest.mock.MagicMock() + mock_console_push = unittest.mock.MagicMock(return_value=Exception()) + mock_ipython_push = unittest.mock.MagicMock(return_value=Exception()) monkeypatch.setattr( mlrun.utils.notifications.ConsoleNotification, "push", mock_console_push ) @@ -287,5 +365,193 @@ def test_inverse_dependencies( ) custom_notification_pusher.push("test-message", "info", []) + assert mock_console_push.call_count == expected_console_call_amount assert mock_ipython_push.call_count == expected_ipython_call_amount + + +def test_notification_params_masking_on_run(monkeypatch): + def _store_project_secrets(*args, **kwargs): + pass + + monkeypatch.setattr( + mlrun.api.crud.Secrets, "store_project_secrets", _store_project_secrets + ) + run_uid = "test-run-uid" + run = { + "metadata": {"uid": run_uid, "project": "test-project"}, + "spec": { + "notifications": [ + {"when": "completed", "params": {"sensitive": "sensitive-value"}} + ] + }, + } + mlrun.api.api.utils.mask_notification_params_on_task(run) + assert "sensitive" not in run["spec"]["notifications"][0]["params"] + assert "secret" in run["spec"]["notifications"][0]["params"] + assert ( + run["spec"]["notifications"][0]["params"]["secret"] + == f"mlrun.notifications.{run_uid}" + ) + + +NOTIFICATION_VALIDATION_PARMETRIZE = [ + ( + { + "kind": "invalid-kind", + }, + pytest.raises(mlrun.errors.MLRunInvalidArgumentError), + ), + ( + { + "kind": mlrun.common.schemas.notification.NotificationKind.slack, + }, + does_not_raise(), + ), + ( + { + "severity": "invalid-severity", 
+ }, + pytest.raises(mlrun.errors.MLRunInvalidArgumentError), + ), + ( + { + "severity": mlrun.common.schemas.notification.NotificationSeverity.INFO, + }, + does_not_raise(), + ), + ( + { + "status": "invalid-status", + }, + pytest.raises(mlrun.errors.MLRunInvalidArgumentError), + ), + ( + { + "status": mlrun.common.schemas.notification.NotificationStatus.PENDING, + }, + does_not_raise(), + ), +] + + +@pytest.mark.parametrize( + "notification_kwargs,expectation", + NOTIFICATION_VALIDATION_PARMETRIZE, +) +def test_notification_validation_on_object( + monkeypatch, notification_kwargs, expectation +): + with expectation: + mlrun.model.Notification(**notification_kwargs) + + +@pytest.mark.parametrize( + "notification_kwargs,expectation", + NOTIFICATION_VALIDATION_PARMETRIZE, +) +def test_notification_validation_on_run(monkeypatch, notification_kwargs, expectation): + notification = mlrun.model.Notification( + name="test-notification", when=["completed"] + ) + for key, value in notification_kwargs.items(): + setattr(notification, key, value) + function = mlrun.new_function( + "function-from-module", + kind="job", + project="test-project", + image="mlrun/mlrun", + ) + with expectation: + function.run( + handler="json.dumps", + params={"obj": {"x": 99}}, + notifications=[notification], + local=True, + ) + + +def test_notification_sent_on_handler_run(monkeypatch): + + run_many_mock = unittest.mock.Mock(return_value=[]) + push_mock = unittest.mock.Mock() + + monkeypatch.setattr(mlrun.runtimes.HandlerRuntime, "_run_many", run_many_mock) + monkeypatch.setattr(mlrun.utils.notifications.NotificationPusher, "push", push_mock) + + def hyper_func(context, p1, p2): + print(f"p1={p1}, p2={p2}, result={p1 * p2}") + context.log_result("multiplier", p1 * p2) + + notification = mlrun.model.Notification( + name="test-notification", when=["completed"] + ) + + grid_params = {"p1": [2, 4, 1], "p2": [10, 20]} + task = mlrun.new_task("grid-demo").with_hyper_params( + grid_params, 
selector="max.multiplier" + ) + mlrun.new_function().run(task, handler=hyper_func, notifications=[notification]) + run_many_mock.assert_called_once() + push_mock.assert_called_once() + + +def test_notification_sent_on_dask_run(monkeypatch): + + run_mock = unittest.mock.Mock(return_value=None) + push_mock = unittest.mock.Mock() + + monkeypatch.setattr(mlrun.runtimes.LocalRuntime, "_run", run_mock) + monkeypatch.setattr(mlrun.utils.notifications.NotificationPusher, "push", push_mock) + + notification = mlrun.model.Notification( + name="test-notification", when=["completed"] + ) + + function = mlrun.new_function( + "function-from-module", + kind="dask", + project="test-project", + image="mlrun/mlrun", + ) + + function.run( + handler="json.dumps", + params={"obj": {"x": 99}}, + notifications=[notification], + local=True, + ) + + run_mock.assert_called_once() + push_mock.assert_called_once() + + +@pytest.mark.parametrize( + "notification1_name,notification2_name,expectation", + [ + ("n1", "n1", pytest.raises(mlrun.errors.MLRunInvalidArgumentError)), + ("n1", "n2", does_not_raise()), + ], +) +def test_notification_name_uniqueness_validation( + notification1_name, notification2_name, expectation +): + notification1 = mlrun.model.Notification( + name=notification1_name, when=["completed"] + ) + notification2 = mlrun.model.Notification( + name=notification2_name, when=["completed"] + ) + function = mlrun.new_function( + "function-from-module", + kind="job", + project="test-project", + image="mlrun/mlrun", + ) + with expectation: + function.run( + handler="json.dumps", + params={"obj": {"x": 99}}, + notifications=[notification1, notification2], + local=True, + ) diff --git a/tests/utils/test_vault.py b/tests/utils/test_vault.py index 39ce925edb76..4c3035ef484b 100644 --- a/tests/utils/test_vault.py +++ b/tests/utils/test_vault.py @@ -12,123 +12,124 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import pytest - -import mlrun -from mlrun import code_to_function, get_run_db, mlconf, new_project, new_task -from mlrun.utils.vault import VaultStore -from tests.conftest import examples_path, out_path, verify_state - -# Set a proper token value for Vault test -user_token = "" - - -# Set test secrets and configurations - you may need to modify these. -def _set_vault_mlrun_configuration(api_server_port=None): - if api_server_port: - mlconf.dbpath = f"http://localhost:{api_server_port}" - mlconf.secret_stores.vault.url = "http://localhost:8200" - mlconf.secret_stores.vault.user_token = user_token - - -# Verify that local activation of Vault functionality is successful. This does not -# test the API-server implementation, which is verified in other tests -@pytest.mark.skipif(user_token == "", reason="no vault configuration") -def test_direct_vault_usage(): - - _set_vault_mlrun_configuration() - project_name = "the-blair-witch-project" - - vault = VaultStore() - vault.delete_vault_secrets(project=project_name) - secrets = vault.get_secrets(None, project=project_name) - assert len(secrets) == 0, "Secrets were not deleted" - - expected_secrets = {"secret1": "123456", "secret2": "654321"} - vault.add_vault_secrets(expected_secrets, project=project_name) - - secrets = vault.get_secrets(None, project=project_name) - assert ( - secrets == expected_secrets - ), "Vault contains different set of secrets than expected" - - secrets = vault.get_secrets(["secret1"], project=project_name) - assert len(secrets) == 1 and secrets["secret1"] == expected_secrets["secret1"] - - # Test the same thing for user - user_name = "pikachu" - vault.delete_vault_secrets(user=user_name) - secrets = vault.get_secrets(None, user=user_name) - assert len(secrets) == 0, "Secrets were not deleted" - - vault.add_vault_secrets(expected_secrets, user=user_name) - secrets = vault.get_secrets(None, user=user_name) - assert ( - secrets == expected_secrets - ), "Vault contains different set of secrets than 
expected" - - # Cleanup - vault.delete_vault_secrets(project=project_name) - vault.delete_vault_secrets(user=user_name) - - -@pytest.mark.skipif(user_token == "", reason="no vault configuration") -def test_vault_end_to_end(): - # This requires an MLRun API server to run and work with Vault. This port should - # be configured to allow access to the server. - api_server_port = 57764 - - _set_vault_mlrun_configuration(api_server_port) - project_name = "abc" - func_name = "vault-function" - aws_key_value = "1234567890" - github_key_value = "proj1Key!!!" - - project = new_project(project_name) - # This call will initialize Vault infrastructure and add the given secrets - # It executes on the API server - project.set_secrets( - {"aws_key": aws_key_value, "github_key": github_key_value}, - provider=mlrun.api.schemas.SecretProviderName.vault, - ) - - # This API executes on the client side - vault = VaultStore() - project_secrets = vault.get_secrets(["aws_key", "github_key"], project=project_name) - assert project_secrets == ["aws_key", "github_key"], "secrets not created" - - # Create function and set container configuration - function = code_to_function( - name=func_name, - filename=f"{examples_path}/vault_function.py", - handler="vault_func", - project=project_name, - kind="job", - ) - - function.spec.image = "saarcoiguazio/mlrun:unstable" - - # Create context for the execution - spec = new_task( - project=project_name, - name="vault_test_run", - handler="vault_func", - out_path=out_path, - params={"secrets": ["password", "path", "github_key", "aws_key"]}, - ) - spec.with_secrets("vault", []) - - result = function.run(spec) - verify_state(result) - - db = get_run_db().connect() - state, log = db.get_log(result.metadata.uid, project=project_name) - log = str(log) - print(state) - - assert ( - log.find(f"value: {aws_key_value}") != -1 - ), "secret value not detected in function output" - assert ( - log.find(f"value: {github_key_value}") != -1 - ), "secret value not 
detected in function output" +# TODO: Vault: uncomment when vault returns to be relevant +# import pytest +# +# import mlrun +# from mlrun import code_to_function, get_run_db, mlconf, new_project, new_task +# from mlrun.utils.vault import VaultStore +# from tests.conftest import examples_path, out_path, verify_state +# +# # Set a proper token value for Vault test +# user_token = "" +# +# +# # Set test secrets and configurations - you may need to modify these. +# def _set_vault_mlrun_configuration(api_server_port=None): +# if api_server_port: +# mlconf.dbpath = f"http://localhost:{api_server_port}" +# mlconf.secret_stores.vault.url = "http://localhost:8200" +# mlconf.secret_stores.vault.user_token = user_token +# +# +# # Verify that local activation of Vault functionality is successful. This does not +# # test the API-server implementation, which is verified in other tests +# @pytest.mark.skipif(user_token == "", reason="no vault configuration") +# def test_direct_vault_usage(): +# +# _set_vault_mlrun_configuration() +# project_name = "the-blair-witch-project" +# +# vault = VaultStore() +# vault.delete_vault_secrets(project=project_name) +# secrets = vault.get_secrets(None, project=project_name) +# assert len(secrets) == 0, "Secrets were not deleted" +# +# expected_secrets = {"secret1": "123456", "secret2": "654321"} +# vault.add_vault_secrets(expected_secrets, project=project_name) +# +# secrets = vault.get_secrets(None, project=project_name) +# assert ( +# secrets == expected_secrets +# ), "Vault contains different set of secrets than expected" +# +# secrets = vault.get_secrets(["secret1"], project=project_name) +# assert len(secrets) == 1 and secrets["secret1"] == expected_secrets["secret1"] +# +# # Test the same thing for user +# user_name = "pikachu" +# vault.delete_vault_secrets(user=user_name) +# secrets = vault.get_secrets(None, user=user_name) +# assert len(secrets) == 0, "Secrets were not deleted" +# +# vault.add_vault_secrets(expected_secrets, 
user=user_name) +# secrets = vault.get_secrets(None, user=user_name) +# assert ( +# secrets == expected_secrets +# ), "Vault contains different set of secrets than expected" +# +# # Cleanup +# vault.delete_vault_secrets(project=project_name) +# vault.delete_vault_secrets(user=user_name) +# +# +# @pytest.mark.skipif(user_token == "", reason="no vault configuration") +# def test_vault_end_to_end(): +# # This requires an MLRun API server to run and work with Vault. This port should +# # be configured to allow access to the server. +# api_server_port = 57764 +# +# _set_vault_mlrun_configuration(api_server_port) +# project_name = "abc" +# func_name = "vault-function" +# aws_key_value = "1234567890" +# github_key_value = "proj1Key!!!" +# +# project = new_project(project_name) +# # This call will initialize Vault infrastructure and add the given secrets +# # It executes on the API server +# project.set_secrets( +# {"aws_key": aws_key_value, "github_key": github_key_value}, +# provider=mlrun.api.schemas.SecretProviderName.vault, +# ) +# +# # This API executes on the client side +# vault = VaultStore() +# project_secrets = vault.get_secrets(["aws_key", "github_key"], project=project_name) +# assert project_secrets == ["aws_key", "github_key"], "secrets not created" +# +# # Create function and set container configuration +# function = code_to_function( +# name=func_name, +# filename=f"{examples_path}/vault_function.py", +# handler="vault_func", +# project=project_name, +# kind="job", +# ) +# +# function.spec.image = "saarcoiguazio/mlrun:unstable" +# +# # Create context for the execution +# spec = new_task( +# project=project_name, +# name="vault_test_run", +# handler="vault_func", +# out_path=out_path, +# params={"secrets": ["password", "path", "github_key", "aws_key"]}, +# ) +# spec.with_secrets("vault", []) +# +# result = function.run(spec) +# verify_state(result) +# +# db = get_run_db().connect() +# state, log = db.get_log(result.metadata.uid, project=project_name) +# log 
= str(log) +# print(state) +# +# assert ( +# log.find(f"value: {aws_key_value}") != -1 +# ), "secret value not detected in function output" +# assert ( +# log.find(f"value: {github_key_value}") != -1 +# ), "secret value not detected in function output"