diff --git a/.github/workflows/build_image.yml b/.github/workflows/build_image.yml new file mode 100644 index 0000000000..da688ac4a2 --- /dev/null +++ b/.github/workflows/build_image.yml @@ -0,0 +1,86 @@ +name: Publish Python Package + +on: + workflow_dispatch: + +jobs: + build-and-push-docker-images-manual: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11"] + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: "0" + - name: Set up QEMU + uses: docker/setup-qemu-action@v1 + - name: Set up Docker Buildx + id: buildx + uses: docker/setup-buildx-action@v1 + - name: Login to GitHub Container Registry + uses: docker/login-action@v1 + with: + registry: ghcr.io + username: "${{ secrets.FLYTE_BOT_USERNAME }}" + password: "${{ secrets.FLYTE_BOT_PAT }}" + - name: Prepare Flytekit Image Names + id: flytekit-names + uses: docker/metadata-action@v3 + with: + images: | + ghcr.io/${{ github.repository_owner }}/flytekit + tags: | + py${{ matrix.python-version }}-${{ github.sha }} + - name: Build & Push Flytekit Python${{ matrix.python-version }} Docker Image to Github Registry + uses: docker/build-push-action@v2 + with: + context: . + platforms: linux/arm64, linux/amd64 + push: true + tags: ${{ steps.flytekit-names.outputs.tags }} + build-args: | + VERSION=${{ github.sha }} + DOCKER_IMAGE=ghcr.io/${{ github.repository_owner }}/flytekit:py${{ matrix.python-version }}-${{ github.sha }} + PYTHON_VERSION=${{ matrix.python-version }} + file: Dockerfile + cache-from: type=gha + cache-to: type=gha,mode=max + + build-and-push-flyteagent-images-manual: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: "0" + - name: Set up QEMU + uses: docker/setup-qemu-action@v1 + - name: Set up Docker Buildx + id: buildx + uses: docker/setup-buildx-action@v1 + - name: Login to GitHub Container Registry + uses: docker/login-action@v1 + with: + registry: ghcr.io + username: "${{ secrets.FLYTE_BOT_USERNAME }}" + password: "${{ secrets.FLYTE_BOT_PAT }}" + - name: Prepare Flyte Agent Image Names + id: flyteagent-names + uses: docker/metadata-action@v3 + with: + images: | + ghcr.io/${{ github.repository_owner }}/flyteagent + tags: | + ${{ github.sha }} + - name: Push External Plugin Service Image to GitHub Registry + uses: docker/build-push-action@v2 + with: + context: "." + platforms: linux/arm64, linux/amd64 + push: true + tags: ${{ steps.flyteagent-names.outputs.tags }} + build-args: | + VERSION=${{ github.sha }} + file: ./Dockerfile.agent + cache-from: type=gha + cache-to: type=gha,mode=max diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 59b4feeb36..b5a0b02d76 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -16,13 +16,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ["3.8", "3.9", "3.10", "3.11"] - exclude: - # Ignore this test because we failed to install docker-py - # docker-py will install pywin32==227, whereas pywin only added support for python 3.10 in version 301.
- # For more detail, see https://github.com/flyteorg/flytekit/pull/856#issuecomment-1067152855 - - python-version: 3.10 - os: windows-latest + python-version: ["3.8", "3.11"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} @@ -47,13 +41,56 @@ jobs: with: fail_ci_if_error: false + build-integration: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + # python 3.11 has intermittent issues with the docker build + push step + # https://github.com/flyteorg/flytekit/actions/runs/5800978835/job/15724237979?pr=1579 + python-version: ["3.8", "3.11"] + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Cache pip + uses: actions/cache@v2 + with: + # This path is specific to Ubuntu + path: ~/.cache/pip + # Look to see if there is a cache hit for the corresponding requirements files + key: ${{ format('{0}-pip-{1}', runner.os, hashFiles('dev-requirements.in', 'requirements.in')) }} + - name: Install dependencies + run: make setup && pip freeze + - name: Install FlyteCTL + uses: unionai-oss/flytectl-setup-action@master + - name: Setup Flyte Sandbox + run: | + flytectl demo start + flytectl config init + - name: Build and push to local registry + run: | + docker build . -f Dockerfile.dev -t localhost:30000/flytekit:dev --build-arg PYTHON_VERSION=${{ matrix.python-version }} + - name: Integration Test with coverage + env: + FLYTEKIT_IMAGE: localhost:30000/flytekit:dev + FLYTEKIT_CI: 1 + run: make integration_test_codecov + - name: Codecov + uses: codecov/codecov-action@v3.1.0 + with: + fail_ci_if_error: false + build-plugins: needs: build runs-on: ubuntu-latest strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.8", "3.11"] plugin-names: # Please maintain an alphabetical order in the following list - flytekit-aws-athena @@ -104,12 +141,6 @@ jobs: plugin-names: "flytekit-greatexpectations" # onnxruntime does not support python 3.10 yet # https://github.com/microsoft/onnxruntime/issues/9782 - - python-version: 3.10 - plugin-names: "flytekit-onnx-pytorch" - - python-version: 3.10 - plugin-names: "flytekit-onnx-scikitlearn" - - python-version: 3.10 - plugin-names: "flytekit-onnx-tensorflow" - python-version: 3.11 plugin-names: "flytekit-onnx-pytorch" - python-version: 3.11 @@ -158,10 +189,14 @@ jobs: # onnx plugins does not support protobuf>4 yet (in fact it is tensorflow that # does not support that yet). More details in https://github.com/onnx/onnx/issues/4239. 
if [[ ${{ matrix.plugin-names }} == *"onnx"* || ${{ matrix.plugin-names }} == "flytekit-whylogs" || ${{ matrix.plugin-names }} == "flytekit-mlflow" ]]; then - PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python coverage run -m pytest tests + PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python coverage run -m pytest tests --cov=./ --cov-report=xml --cov-append else - coverage run -m pytest tests + coverage run -m pytest tests --cov=./ --cov-report=xml --cov-append fi + - name: Codecov + uses: codecov/codecov-action@v3.1.0 + with: + fail_ci_if_error: false lint: runs-on: ubuntu-latest steps: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3007f6e64d..c8f0e974eb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,28 +1,28 @@ repos: -- repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 - hooks: - - id: flake8 -- repo: https://github.com/psf/black - rev: 22.3.0 - hooks: - - id: black -- repo: https://github.com/PyCQA/isort - rev: 5.12.0 - hooks: - - id: isort - args: ["--profile", "black"] -- repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.2.0 - hooks: - - id: check-yaml - - id: end-of-file-fixer - - id: trailing-whitespace -- repo: https://github.com/shellcheck-py/shellcheck-py - rev: v0.8.0.4 - hooks: - - id: shellcheck -- repo: https://github.com/conorfalvey/check_pdb_hook - rev: 0.0.9 - hooks: - - id: check_pdb_hook + - repo: https://github.com/PyCQA/flake8 + rev: 4.0.1 + hooks: + - id: flake8 + - repo: https://github.com/psf/black + rev: 22.3.0 + hooks: + - id: black + - repo: https://github.com/PyCQA/isort + rev: 5.12.0 + hooks: + - id: isort + args: ["--profile", "black"] + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.2.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + - repo: https://github.com/shellcheck-py/shellcheck-py + rev: v0.8.0.4 + hooks: + - id: shellcheck + - repo: https://github.com/conorfalvey/check_pdb_hook + rev: 0.0.9 + hooks: + - id: check_pdb_hook diff --git a/.readthedocs.yml b/.readthedocs.yml index 18f4292317..a553c2f8e0 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -20,4 +20,3 @@ sphinx: python: install: - requirements: doc-requirements.txt - - requirements: docs/requirements.txt diff --git a/Dockerfile b/Dockerfile index 257fcb5143..d9662e2679 100644 --- a/Dockerfile +++ b/Dockerfile @@ -16,7 +16,6 @@ RUN apt-get update && apt-get install build-essential -y RUN pip install -U flytekit==$VERSION \ flytekitplugins-pod==$VERSION \ flytekitplugins-deck-standard==$VERSION \ - flytekitplugins-envd==$VERSION \ scikit-learn RUN useradd -u 1000 flytekit diff --git a/Dockerfile.agent b/Dockerfile.agent index 2194f5de23..79dfb5b9d0 100644 --- a/Dockerfile.agent +++ b/Dockerfile.agent @@ -4,6 +4,7 @@ MAINTAINER Flyte Team LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit ARG VERSION +RUN pip install prometheus-client RUN pip install -U flytekit==$VERSION flytekitplugins-bigquery==$VERSION CMD pyflyte serve --port 8000 diff --git a/Makefile b/Makefile index 10112fdadb..8593957743 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ export REPOSITORY=flytekit -PIP_COMPILE = pip-compile --upgrade --verbose +PIP_COMPILE = pip-compile --upgrade --verbose --resolver=backtracking MOCK_FLYTE_REPO=tests/flytekit/integration/remote/mock_flyte_repo/workflows .SILENT: help @@ -49,8 +49,6 @@ test: lint unit_test .PHONY: unit_test_codecov unit_test_codecov: - # Ensure coverage file - rm coverage.xml || true $(MAKE) CODECOV_OPTS="--cov=./ --cov-report=xml 
--cov-append" unit_test .PHONY: unit_test @@ -60,9 +58,17 @@ unit_test: pytest -m "not sandbox_test" tests/flytekit/unit/ --ignore=tests/flytekit/unit/extras/tensorflow ${CODECOV_OPTS} && \ PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python pytest tests/flytekit/unit/extras/tensorflow ${CODECOV_OPTS} +.PHONY: integration_test_codecov +integration_test_codecov: + $(MAKE) CODECOV_OPTS="--cov=./ --cov-report=xml --cov-append" integration_test + +.PHONY: integration_test +integration_test: + pytest tests/flytekit/integration/experimental ${CODECOV_OPTS} + doc-requirements.txt: export CUSTOM_COMPILE_COMMAND := make doc-requirements.txt doc-requirements.txt: doc-requirements.in install-piptools - docker run --platform linux/amd64 --rm -it --volume .:/root python:3.9-slim-buster sh -c "cd /root && apt-get update && apt-get install git -y && pip install pip-tools && pip-compile --upgrade --verbose doc-requirements.in" + $(PIP_COMPILE) $< ${MOCK_FLYTE_REPO}/requirements.txt: export CUSTOM_COMPILE_COMMAND := make ${MOCK_FLYTE_REPO}/requirements.txt ${MOCK_FLYTE_REPO}/requirements.txt: ${MOCK_FLYTE_REPO}/requirements.in install-piptools diff --git a/README.md b/README.md index 67b6d12297..95ed844bad 100644 --- a/README.md +++ b/README.md @@ -16,9 +16,10 @@ [![PyPI version fury.io](https://badge.fury.io/py/flytekit.svg)](https://pypi.python.org/pypi/flytekit/) [![PyPI download day](https://img.shields.io/pypi/dd/flytekit.svg)](https://pypi.python.org/pypi/flytekit/) [![PyPI download month](https://img.shields.io/pypi/dm/flytekit.svg)](https://pypi.python.org/pypi/flytekit/) +[![PyPI total download](https://static.pepy.tech/badge/flytekit)](https://static.pepy.tech/badge/flytekit) [![PyPI format](https://img.shields.io/pypi/format/flytekit.svg)](https://pypi.python.org/pypi/flytekit/) [![PyPI implementation](https://img.shields.io/pypi/implementation/flytekit.svg)](https://pypi.python.org/pypi/flytekit/) -![Codecov](https://img.shields.io/codecov/c/github/flyteorg/flytekit?style=plastic) +[![Codecov](https://img.shields.io/codecov/c/github/flyteorg/flytekit?style=plastic)](https://app.codecov.io/gh/flyteorg/flytekit) [![PyPI pyversions](https://img.shields.io/pypi/pyversions/flytekit.svg)](https://pypi.python.org/pypi/flytekit/) [![Docs](https://readthedocs.org/projects/flytekit/badge/?version=latest&style=plastic)](https://flytekit.rtfd.io) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) diff --git a/codecov.yml b/codecov.yml index 89ec57f646..9c38fc7bc3 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,7 +1,9 @@ -ignore_paths: +ignore: - "flytekit/bin" + - "flytekit/clis/**/*" - "test_*.py" - - "flytekit/__init__.py" - - "flytekit/extend/__init__.py" - - "flytekit/testing/__init__.py" - - "tests/*" + - "tests/**/*" + - "setup.py" + - "plugins/tests/**/*" + - "plugins/setup.py" + - "plugins/**/setup.py" diff --git a/dev-requirements.in b/dev-requirements.in index bd16ba151a..3cb16d8d3b 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -5,6 +5,7 @@ hypothesis joblib mock pytest +pytest-asyncio pytest-cov mypy pre-commit @@ -17,7 +18,7 @@ keyrings.alt # Only install tensorflow if not running on an arm Mac. 
tensorflow==2.8.1; python_version<'3.11' and (platform_machine!='arm64' or platform_system!='Darwin') # Tensorflow release candidate supports python 3.11 -tensorflow==2.12.0rc1; python_version>='3.11' and (platform_machine!='arm64' or platform_system!='Darwin') +tensorflow==2.13.0; python_version>='3.11' and (platform_machine!='arm64' or platform_system!='Darwin') # Newer versions of torch bring in nvidia dependencies that are not present in windows, so # we put this constraint while we do not have per-environment requirements files @@ -30,3 +31,4 @@ types-protobuf types-croniter types-mock autoflake +prometheus-client diff --git a/doc-requirements.in b/doc-requirements.in index 485b78be26..4a30b8afef 100644 --- a/doc-requirements.in +++ b/doc-requirements.in @@ -11,12 +11,14 @@ sphinx-autoapi sphinx-copybutton sphinx_fontawesome sphinx-panels -sphinxcontrib-youtube +sphinxcontrib-youtube==1.2.0 cryptography google-api-core[grpc] scikit-learn sphinx-tags sphinx-click +retry +mashumaro # Packages for Plugin docs # Package name Plugin needing it @@ -29,23 +31,25 @@ plotly # deck pandas_profiling # deck dolt_integrations # dolt great-expectations # greatexpectations +datasets # huggingface kubernetes # k8s-pod modin # modin pandera # pandera papermill # papermill jupyter # papermill +polars # polars pyspark # spark sqlalchemy # sqlalchemy torch # pytorch -# TODO: Remove after buf migration is done and packages updated -# skl2onnx # onnxscikitlearn -# tf2onnx # onnxtensorflow +skl2onnx # onnxscikitlearn +tf2onnx # onnxtensorflow tensorflow # onnxtensorflow -whylogs # whylogs +whylogs==1.3.3 # whylogs whylabs-client # whylogs -ray # ray +ray==2.6.3 # ray scikit-learn # scikit-learn dask[distributed] # dask vaex # vaex -mlflow # mlflow +mlflow==2.7.0 # mlflow duckdb # duckdb +snowflake-connector-python # snowflake diff --git a/doc-requirements.txt b/doc-requirements.txt index 072d70087e..c35ba0c62e 100644 --- a/doc-requirements.txt +++ b/doc-requirements.txt @@ -2,7 +2,7 @@ # This file is autogenerated by pip-compile with Python 3.9 # by the following command: # -# pip-compile doc-requirements.in +# make doc-requirements.txt # -e file:.#egg=flytekit # via -r doc-requirements.in @@ -10,14 +10,16 @@ absl-py==1.4.0 # via # tensorboard # tensorflow -adlfs==2023.4.0 +adlfs==2023.8.0 # via flytekit -aiobotocore==2.5.2 +aiobotocore==2.5.4 # via s3fs aiohttp==3.8.5 # via # adlfs # aiobotocore + # datasets + # fsspec # gcsfs # s3fs aioitertools==0.11.0 @@ -28,7 +30,7 @@ aiosignal==1.3.1 # ray alabaster==0.7.13 # via sphinx -alembic==1.11.1 +alembic==1.12.0 # via mlflow altair==4.2.2 # via great-expectations @@ -36,12 +38,17 @@ ansiwrap==0.8.4 # via papermill anyio==3.7.1 # via + # fastapi # jupyter-server # starlette # watchfiles aplus==0.11.0 # via vaex-core -argon2-cffi==21.3.0 +appnope==0.1.3 + # via + # ipykernel + # ipython +argon2-cffi==23.1.0 # via jupyter-server argon2-cffi-bindings==21.2.0 # via argon2-cffi @@ -49,17 +56,21 @@ arrow==1.2.3 # via # cookiecutter # isoduration +asn1crypto==1.5.1 + # via + # oscrypto + # snowflake-connector-python astroid==2.15.6 # via sphinx-autoapi -astropy==5.3.1 +astropy==5.3.3 # via vaex-astro -asttokens==2.2.1 +asttokens==2.4.0 # via stack-data astunparse==1.6.3 # via tensorflow -async-lru==2.0.3 +async-lru==2.0.4 # via jupyterlab -async-timeout==4.0.2 +async-timeout==4.0.3 # via aiohttp attrs==23.1.0 # via @@ -67,16 +78,16 @@ attrs==23.1.0 # jsonschema # referencing # visions -azure-core==1.28.0 +azure-core==1.29.1 # via # adlfs # azure-identity # 
azure-storage-blob azure-datalake-store==0.0.53 # via adlfs -azure-identity==1.13.0 +azure-identity==1.14.0 # via adlfs -azure-storage-blob==12.17.0 +azure-storage-blob==12.18.1 # via adlfs babel==2.12.1 # via @@ -98,7 +109,7 @@ bleach==6.0.0 # via nbconvert blinker==1.6.2 # via flask -botocore==1.29.161 +botocore==1.31.17 # via # -r doc-requirements.in # aiobotocore @@ -116,20 +127,23 @@ certifi==2023.7.22 # via # kubernetes # requests + # snowflake-connector-python cffi==1.15.1 # via # argon2-cffi-bindings # azure-datalake-store # cryptography -cfgv==3.3.1 + # snowflake-connector-python +cfgv==3.4.0 # via pre-commit -chardet==5.1.0 +chardet==5.2.0 # via binaryornot charset-normalizer==3.2.0 # via # aiohttp # requests -click==8.1.3 + # snowflake-connector-python +click==8.1.7 # via # cookiecutter # dask @@ -151,19 +165,19 @@ cloudpickle==2.2.1 # flytekit # mlflow # vaex-core -cmake==3.27.0 - # via triton colorama==0.4.6 # via great-expectations -comm==0.1.3 - # via ipykernel -contourpy==1.1.0 +comm==0.1.4 + # via + # ipykernel + # ipywidgets +contourpy==1.1.1 # via matplotlib -cookiecutter==2.2.3 +cookiecutter==2.3.0 # via flytekit croniter==1.4.1 # via flytekit -cryptography==41.0.2 +cryptography==41.0.3 # via # -r doc-requirements.in # azure-identity @@ -172,14 +186,14 @@ cryptography==41.0.2 # msal # pyjwt # pyopenssl - # secretstorage + # snowflake-connector-python css-html-js-minify==2.5.5 # via sphinx-material cycler==0.11.0 # via matplotlib dacite==1.8.1 # via ydata-profiling -dask[distributed]==2023.7.1 +dask[distributed]==2023.9.2 # via # -r doc-requirements.in # distributed @@ -190,21 +204,28 @@ dataclasses-json==0.5.9 # via # dolt-integrations # flytekit -debugpy==1.6.7 +datasets==2.14.5 + # via -r doc-requirements.in +debugpy==1.8.0 # via ipykernel decorator==5.1.1 # via # gcsfs # ipython + # retry defusedxml==0.7.1 # via nbconvert deprecated==1.2.14 # via flytekit -diskcache==5.6.1 +dill==0.3.7 + # via + # datasets + # multiprocess +diskcache==5.6.3 # via flytekit distlib==0.3.7 # via virtualenv -distributed==2023.7.1 +distributed==2023.9.2 # via dask docker==6.1.3 # via @@ -230,28 +251,33 @@ entrypoints==0.4 # altair # mlflow # papermill -exceptiongroup==1.1.2 - # via anyio +exceptiongroup==1.1.3 + # via + # anyio + # ipython executing==1.2.0 # via stack-data -fastapi==0.100.0 +fastapi==0.103.1 # via vaex-server fastjsonschema==2.18.0 # via nbformat -filelock==3.12.2 +filelock==3.12.4 # via + # huggingface-hub # ray + # snowflake-connector-python # torch - # triton # vaex-core # virtualenv -flask==2.3.2 +flask==2.3.3 # via mlflow flatbuffers==23.5.26 - # via tensorflow -flyteidl==1.5.13 + # via + # tensorflow + # tf2onnx +flyteidl==1.5.17 # via flytekit -fonttools==4.41.1 +fonttools==4.42.1 # via matplotlib fqdn==1.5.1 # via jsonschema @@ -262,13 +288,15 @@ frozenlist==1.4.0 # aiohttp # aiosignal # ray -fsspec==2023.6.0 +fsspec[http]==2023.6.0 # via # -r doc-requirements.in # adlfs # dask + # datasets # flytekit # gcsfs + # huggingface-hub # modin # s3fs furo @ git+https://github.com/flyteorg/furo@main @@ -281,7 +309,7 @@ gcsfs==2023.6.0 # via flytekit gitdb==4.0.10 # via gitpython -gitpython==3.1.32 +gitpython==3.1.36 # via # flytekit # mlflow @@ -291,7 +319,7 @@ google-api-core[grpc]==2.11.1 # google-cloud-bigquery # google-cloud-core # google-cloud-storage -google-auth==2.22.0 +google-auth==2.23.0 # via # gcsfs # google-api-core @@ -318,21 +346,19 @@ google-crc32c==1.5.0 # via google-resumable-media google-pasta==0.2.0 # via tensorflow -google-resumable-media==2.5.0 
+google-resumable-media==2.6.0 # via # google-cloud-bigquery # google-cloud-storage -googleapis-common-protos==1.59.1 +googleapis-common-protos==1.60.0 # via # flyteidl # flytekit # google-api-core # grpcio-status -great-expectations==0.17.6 +great-expectations==0.17.16 # via -r doc-requirements.in -greenlet==2.0.2 - # via sqlalchemy -grpcio==1.56.2 +grpcio==1.53.0 # via # -r doc-requirements.in # flytekit @@ -342,11 +368,11 @@ grpcio==1.56.2 # ray # tensorboard # tensorflow -grpcio-status==1.56.2 +grpcio-status==1.53.0 # via # flytekit # google-api-core -gunicorn==20.1.0 +gunicorn==21.2.0 # via mlflow h11==0.14.0 # via uvicorn @@ -358,13 +384,16 @@ htmlmin==0.1.12 # via ydata-profiling httptools==0.6.0 # via uvicorn -identify==2.5.26 +huggingface-hub==0.17.1 + # via datasets +identify==2.5.29 # via pre-commit idna==3.4 # via # anyio # jsonschema # requests + # snowflake-connector-python # yarl imagehash==4.3.1 # via @@ -377,7 +406,6 @@ importlib-metadata==6.8.0 # dask # flask # flytekit - # great-expectations # jupyter-client # jupyter-lsp # jupyterlab @@ -387,22 +415,21 @@ importlib-metadata==6.8.0 # mlflow # nbconvert # sphinx -importlib-resources==6.0.0 +importlib-resources==6.0.1 # via matplotlib ipydatawidgets==4.3.5 # via pythreejs -ipykernel==6.25.0 +ipykernel==6.25.2 # via - # ipywidgets # jupyter # jupyter-console # jupyterlab # qtconsole -ipyleaflet==0.17.3 +ipyleaflet==0.17.4 # via vaex-jupyter ipympl==0.9.3 # via vaex-jupyter -ipython==8.14.0 +ipython==8.15.0 # via # great-expectations # ipykernel @@ -415,7 +442,7 @@ ipython-genutils==0.2.0 # qtconsole ipyvolume==0.6.3 # via vaex-jupyter -ipyvue==1.9.2 +ipyvue==1.10.1 # via # ipyvolume # ipyvuetify @@ -425,7 +452,7 @@ ipyvuetify==1.8.10 # vaex-jupyter ipywebrtc==0.6.0 # via ipyvolume -ipywidgets==8.0.7 +ipywidgets==8.1.1 # via # bqplot # great-expectations @@ -444,12 +471,8 @@ itsdangerous==2.1.2 # via flask jaraco-classes==3.3.0 # via keyring -jedi==0.18.2 +jedi==0.19.0 # via ipython -jeepney==0.8.0 - # via - # keyring - # secretstorage jinja2==3.1.2 # via # altair @@ -470,7 +493,7 @@ jinja2==3.1.2 # ydata-profiling jmespath==1.0.1 # via botocore -joblib==1.3.1 +joblib==1.3.2 # via # flytekit # phik @@ -479,11 +502,13 @@ json5==0.9.14 # via jupyterlab-server jsonpatch==1.33 # via great-expectations +jsonpickle==3.0.2 + # via flytekit jsonpointer==2.4 # via # jsonpatch # jsonschema -jsonschema[format-nongpl]==4.18.0 +jsonschema[format-nongpl]==4.19.0 # via # altair # great-expectations @@ -495,7 +520,7 @@ jsonschema-specifications==2023.7.1 # via jsonschema jupyter==1.0.0 # via -r doc-requirements.in -jupyter-client==8.3.0 +jupyter-client==8.3.1 # via # ipykernel # jupyter-console @@ -515,11 +540,11 @@ jupyter-core==5.3.1 # nbconvert # nbformat # qtconsole -jupyter-events==0.6.3 +jupyter-events==0.7.0 # via jupyter-server jupyter-lsp==2.2.0 # via jupyterlab -jupyter-server==2.7.0 +jupyter-server==2.7.3 # via # jupyter-lsp # jupyterlab @@ -528,21 +553,21 @@ jupyter-server==2.7.0 # notebook-shim jupyter-server-terminals==0.4.4 # via jupyter-server -jupyterlab==4.0.3 +jupyterlab==4.0.6 # via notebook jupyterlab-pygments==0.2.2 # via nbconvert -jupyterlab-server==2.24.0 +jupyterlab-server==2.25.0 # via # jupyterlab # notebook -jupyterlab-widgets==3.0.8 +jupyterlab-widgets==3.0.9 # via ipywidgets keras==2.13.1 # via tensorflow keyring==24.2.0 # via flytekit -kiwisolver==1.4.4 +kiwisolver==1.4.5 # via matplotlib kubernetes==27.2.0 # via @@ -591,7 +616,11 @@ marshmallow-enum==1.5.1 # flytekit marshmallow-jsonschema==0.13.0 # via 
flytekit -matplotlib==3.7.2 +mashumaro==3.10 + # via + # -r doc-requirements.in + # flytekit +matplotlib==3.8.0 # via # ipympl # ipyvolume @@ -611,15 +640,15 @@ mistune==3.0.1 # via # great-expectations # nbconvert -mlflow==2.5.0 +mlflow==2.7.0 # via -r doc-requirements.in modin==0.22.3 # via -r doc-requirements.in -more-itertools==10.0.0 +more-itertools==10.1.0 # via jaraco-classes mpmath==1.3.0 # via sympy -msal==1.23.0 +msal==1.24.0 # via # azure-datalake-store # azure-identity @@ -639,6 +668,8 @@ multimethod==1.9.1 # pandera # visions # ydata-profiling +multiprocess==0.70.15 + # via datasets mypy-extensions==1.0.0 # via typing-inspect natsort==8.4.0 @@ -647,18 +678,18 @@ nbclient==0.8.0 # via # nbconvert # papermill -nbconvert==7.7.3 +nbconvert==7.8.0 # via # jupyter # jupyter-server -nbformat==5.9.1 +nbformat==5.9.2 # via # great-expectations # jupyter-server # nbclient # nbconvert # papermill -nest-asyncio==1.5.6 +nest-asyncio==1.5.8 # via # ipykernel # vaex-core @@ -668,7 +699,7 @@ networkx==3.1 # visions nodeenv==1.8.0 # via pre-commit -notebook==7.0.0 +notebook==7.0.3 # via # great-expectations # jupyter @@ -684,6 +715,7 @@ numpy==1.23.5 # astropy # bqplot # contourpy + # datasets # flytekit # great-expectations # h5py @@ -695,6 +727,8 @@ numpy==1.23.5 # mlflow # modin # numba + # onnx + # onnxconverter-common # opt-einsum # pandas # pandera @@ -711,53 +745,41 @@ numpy==1.23.5 # statsmodels # tensorboard # tensorflow + # tf2onnx # vaex-core # visions # wordcloud # xarray # ydata-profiling -nvidia-cublas-cu11==11.10.3.66 - # via - # nvidia-cudnn-cu11 - # nvidia-cusolver-cu11 - # torch -nvidia-cuda-cupti-cu11==11.7.101 - # via torch -nvidia-cuda-nvrtc-cu11==11.7.99 - # via torch -nvidia-cuda-runtime-cu11==11.7.99 - # via torch -nvidia-cudnn-cu11==8.5.0.96 - # via torch -nvidia-cufft-cu11==10.9.0.58 - # via torch -nvidia-curand-cu11==10.2.10.91 - # via torch -nvidia-cusolver-cu11==11.4.0.1 - # via torch -nvidia-cusparse-cu11==11.7.4.91 - # via torch -nvidia-nccl-cu11==2.14.3 - # via torch -nvidia-nvtx-cu11==11.7.91 - # via torch oauthlib==3.2.2 # via # databricks-cli # kubernetes # requests-oauthlib +onnx==1.14.1 + # via + # onnxconverter-common + # skl2onnx + # tf2onnx +onnxconverter-common==1.13.0 + # via skl2onnx opt-einsum==3.3.0 # via tensorflow -overrides==7.3.1 +oscrypto==1.3.0 + # via snowflake-connector-python +overrides==7.4.0 # via jupyter-server packaging==23.1 # via # astropy # dask + # datasets # distributed # docker # google-cloud-bigquery # great-expectations + # gunicorn + # huggingface-hub # ipykernel # jupyter-server # jupyterlab @@ -767,11 +789,13 @@ packaging==23.1 # mlflow # modin # nbconvert + # onnxconverter-common # pandera # plotly # qtconsole # qtpy # ray + # snowflake-connector-python # sphinx # statsmodels # tensorflow @@ -780,6 +804,7 @@ pandas==1.5.3 # via # altair # bqplot + # datasets # dolt-integrations # flytekit # great-expectations @@ -813,7 +838,7 @@ phik==0.12.3 # via ydata-profiling pickleshare==0.7.5 # via ipython -pillow==10.0.0 +pillow==10.0.1 # via # imagehash # ipympl @@ -822,16 +847,19 @@ pillow==10.0.0 # vaex-viz # visions # wordcloud -platformdirs==3.9.1 +platformdirs==3.8.1 # via # jupyter-core + # snowflake-connector-python # virtualenv # whylogs -plotly==5.15.0 +plotly==5.17.0 + # via -r doc-requirements.in +polars==0.19.3 # via -r doc-requirements.in -portalocker==2.7.0 +portalocker==2.8.2 # via msal-extensions -pre-commit==3.3.3 +pre-commit==3.4.0 # via sphinx-tags progressbar2==4.2.0 # via vaex-core @@ -843,7 +871,7 @@ 
prompt-toolkit==3.0.39 # jupyter-console proto-plus==1.22.3 # via google-cloud-bigquery -protobuf==4.23.4 +protobuf==4.24.3 # via # flyteidl # google-api-core @@ -851,6 +879,8 @@ protobuf==4.23.4 # googleapis-common-protos # grpcio-status # mlflow + # onnx + # onnxconverter-common # proto-plus # protoc-gen-swagger # ray @@ -870,10 +900,13 @@ ptyprocess==0.7.0 # terminado pure-eval==0.2.2 # via stack-data +py==1.11.0 + # via retry py4j==0.10.9.7 # via pyspark pyarrow==10.0.1 # via + # datasets # flytekit # mlflow # vaex-core @@ -885,6 +918,8 @@ pyasn1-modules==0.3.0 # via google-auth pycparser==2.21 # via cffi +pycryptodomex==3.19.0 + # via snowflake-connector-python pydantic==1.10.12 # via # fastapi @@ -894,7 +929,7 @@ pydantic==1.10.12 # ydata-profiling pyerfa==2.0.0.3 # via astropy -pygments==2.15.1 +pygments==2.16.1 # via # furo # ipython @@ -908,9 +943,12 @@ pyjwt[crypto]==2.8.0 # via # databricks-cli # msal + # snowflake-connector-python pyopenssl==23.2.0 - # via flytekit -pyparsing==3.0.9 + # via + # flytekit + # snowflake-connector-python +pyparsing==3.1.1 # via # great-expectations # matplotlib @@ -945,12 +983,13 @@ pythreejs==2.4.2 # via ipyvolume pytimeparse==1.1.8 # via flytekit -pytz==2023.3 +pytz==2023.3.post1 # via # flytekit # great-expectations # mlflow # pandas + # snowflake-connector-python pywavelets==1.4.1 # via imagehash pyyaml==6.0.1 @@ -958,8 +997,10 @@ pyyaml==6.0.1 # astropy # cookiecutter # dask + # datasets # distributed # flytekit + # huggingface-hub # jupyter-events # kubernetes # mlflow @@ -970,26 +1011,27 @@ pyyaml==6.0.1 # uvicorn # vaex-core # ydata-profiling -pyzmq==25.1.0 +pyzmq==25.1.1 # via # ipykernel # jupyter-client # jupyter-console # jupyter-server # qtconsole -qtconsole==5.4.3 +qtconsole==5.4.4 # via jupyter -qtpy==2.3.1 +qtpy==2.4.0 # via qtconsole querystring-parser==1.2.4 # via mlflow -ray==2.6.1 +ray==2.6.3 # via -r doc-requirements.in -referencing==0.30.0 +referencing==0.30.2 # via # jsonschema # jsonschema-specifications -regex==2023.6.3 + # jupyter-events +regex==2023.8.8 # via docker-image-py requests==2.31.0 # via @@ -997,13 +1039,16 @@ requests==2.31.0 # azure-datalake-store # cookiecutter # databricks-cli + # datasets # docker # flytekit + # fsspec # gcsfs # google-api-core # google-cloud-bigquery # google-cloud-storage # great-expectations + # huggingface-hub # ipyvolume # jupyterlab-server # kubernetes @@ -1012,9 +1057,11 @@ requests==2.31.0 # papermill # ray # requests-oauthlib + # snowflake-connector-python # sphinx # sphinxcontrib-youtube # tensorboard + # tf2onnx # vaex-core # whylogs # ydata-profiling @@ -1022,6 +1069,8 @@ requests-oauthlib==1.3.1 # via # google-auth-oauthlib # kubernetes +retry==0.9.2 + # via -r doc-requirements.in rfc3339-validator==0.1.4 # via # jsonschema @@ -1030,14 +1079,15 @@ rfc3986-validator==0.1.1 # via # jsonschema # jupyter-events -rich==13.4.2 +rich==13.5.3 # via + # cookiecutter # flytekit # rich-click # vaex-core rich-click==1.6.1 # via flytekit -rpds-py==0.9.2 +rpds-py==0.10.3 # via # jsonschema # referencing @@ -1053,7 +1103,8 @@ scikit-learn==1.3.0 # via # -r doc-requirements.in # mlflow -scipy==1.10.1 + # skl2onnx +scipy==1.11.2 # via # great-expectations # imagehash @@ -1064,8 +1115,6 @@ scipy==1.10.1 # ydata-profiling seaborn==0.12.2 # via ydata-profiling -secretstorage==3.3.3 - # via keyring send2trash==1.8.2 # via jupyter-server six==1.16.0 @@ -1073,10 +1122,8 @@ six==1.16.0 # asttokens # astunparse # azure-core - # azure-identity # bleach # databricks-cli - # google-auth # google-pasta # 
isodate # kubernetes @@ -1086,18 +1133,24 @@ six==1.16.0 # rfc3339-validator # sphinx-code-include # tensorflow + # tf2onnx # vaex-core -smmap==5.0.0 +skl2onnx==1.15.0 + # via -r doc-requirements.in +smmap==5.0.1 # via gitdb sniffio==1.3.0 # via anyio snowballstemmer==2.2.0 # via sphinx +snowflake-connector-python==3.2.0 + # via -r doc-requirements.in sortedcontainers==2.4.0 # via # distributed # flytekit -soupsieve==2.4.1 + # snowflake-connector-python +soupsieve==2.5 # via beautifulsoup4 sphinx==4.5.0 # via @@ -1119,7 +1172,7 @@ sphinx-autoapi==2.0.1 # via -r doc-requirements.in sphinx-basic-ng==1.0.0b2 # via furo -sphinx-click==4.4.0 +sphinx-click==5.0.1 # via -r doc-requirements.in sphinx-code-include==1.1.1 # via -r doc-requirements.in @@ -1127,9 +1180,9 @@ sphinx-copybutton==0.5.2 # via -r doc-requirements.in sphinx-fontawesome==0.0.6 # via -r doc-requirements.in -sphinx-gallery==0.13.0 +sphinx-gallery==0.14.0 # via -r doc-requirements.in -sphinx-material==0.0.35 +sphinx-material==0.0.36 # via -r doc-requirements.in sphinx-panels==0.6.0 # via -r doc-requirements.in @@ -1151,7 +1204,7 @@ sphinxcontrib-serializinghtml==1.1.5 # via sphinx sphinxcontrib-youtube==1.2.0 # via -r doc-requirements.in -sqlalchemy==2.0.19 +sqlalchemy==2.0.20 # via # -r doc-requirements.in # alembic @@ -1176,7 +1229,7 @@ tangled-up-in-unicode==0.2.0 # via visions tblib==2.0.0 # via distributed -tenacity==8.2.2 +tenacity==8.2.3 # via # papermill # plotly @@ -1188,7 +1241,7 @@ tensorflow==2.13.0 # via -r doc-requirements.in tensorflow-estimator==2.13.0 # via tensorflow -tensorflow-io-gcs-filesystem==0.32.0 +tensorflow==2.13.0 # via tensorflow termcolor==2.3.0 # via tensorflow @@ -1200,12 +1253,16 @@ text-unidecode==1.3 # via python-slugify textwrap3==0.9.2 # via ansiwrap +tf2onnx==1.8.4 + # via -r doc-requirements.in threadpoolctl==3.2.0 # via scikit-learn tinycss2==1.2.1 # via nbconvert tomli==2.0.1 # via jupyterlab +tomlkit==0.12.1 + # via snowflake-connector-python toolz==0.12.0 # via # altair @@ -1213,10 +1270,8 @@ toolz==0.12.0 # distributed # partd torch==2.0.1 - # via - # -r doc-requirements.in - # triton -tornado==6.3.2 + # via -r doc-requirements.in +tornado==6.3.3 # via # distributed # ipykernel @@ -1226,12 +1281,14 @@ tornado==6.3.2 # notebook # terminado # vaex-server -tqdm==4.65.0 +tqdm==4.66.1 # via + # datasets # great-expectations + # huggingface-hub # papermill # ydata-profiling -traitlets==5.9.0 +traitlets==5.10.0 # via # bqplot # comm @@ -1260,8 +1317,6 @@ traittypes==0.2.1 # ipydatawidgets # ipyleaflet # ipyvolume -triton==2.0.0 - # via torch typed-ast==1.5.5 # via doltcli typeguard==2.13.3 @@ -1281,9 +1336,13 @@ typing-extensions==4.5.0 # fastapi # flytekit # great-expectations + # huggingface-hub # ipython + # mashumaro + # onnx # pydantic # python-utils + # snowflake-connector-python # sqlalchemy # starlette # tensorflow @@ -1314,8 +1373,9 @@ urllib3==1.26.16 # great-expectations # kubernetes # requests + # snowflake-connector-python # whylabs-client -uvicorn[standard]==0.23.1 +uvicorn[standard]==0.23.2 # via vaex-server uvloop==0.17.0 # via uvicorn @@ -1344,11 +1404,11 @@ vaex-viz==0.5.4 # via # vaex # vaex-jupyter -virtualenv==20.24.2 +virtualenv==20.24.1 # via pre-commit visions[type_image_path]==0.7.5 # via ydata-profiling -watchfiles==0.19.0 +watchfiles==0.20.0 # via uvicorn wcwidth==0.2.6 # via prompt-toolkit @@ -1358,37 +1418,31 @@ webencodings==0.5.1 # via # bleach # tinycss2 -websocket-client==1.6.1 +websocket-client==1.6.3 # via # docker # jupyter-server # kubernetes 
websockets==11.0.3 # via uvicorn -werkzeug==2.3.6 +werkzeug==2.3.7 # via # flask # tensorboard -wheel==0.41.0 +wheel==0.41.2 # via # astunparse # flytekit - # nvidia-cublas-cu11 - # nvidia-cuda-cupti-cu11 - # nvidia-cuda-runtime-cu11 - # nvidia-curand-cu11 - # nvidia-cusparse-cu11 - # nvidia-nvtx-cu11 # tensorboard -whylabs-client==0.5.3 +whylabs-client==0.5.7 # via # -r doc-requirements.in # whylogs -whylogs==1.2.6 +whylogs==1.3.3 # via -r doc-requirements.in whylogs-sketching==3.4.1.dev3 # via whylogs -widgetsnbextension==4.0.8 +widgetsnbextension==4.0.9 # via ipywidgets wordcloud==1.9.2 # via ydata-profiling @@ -1400,13 +1454,15 @@ wrapt==1.15.0 # flytekit # pandera # tensorflow -xarray==2023.7.0 +xarray==2023.8.0 # via vaex-jupyter +xxhash==3.3.0 + # via datasets xyzservices==2023.7.0 # via ipyleaflet yarl==1.9.2 # via aiohttp -ydata-profiling==4.3.2 +ydata-profiling==4.5.1 # via pandas-profiling zict==3.0.0 # via distributed diff --git a/docs/Makefile b/docs/Makefile index afa73807cb..ab89fb24cf 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -21,4 +21,4 @@ help: clean: - rm -rf ./build ./source/generated + rm -rf ./build ./source/generated ./source/plugins/generated diff --git a/docs/requirements.txt b/docs/requirements.txt deleted file mode 100644 index 1fb1b91359..0000000000 --- a/docs/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -# TODO: Remove after buf migration is done and packages updated, see doc-requirements.in -# skl2onnx and tf2onnx added here so that the plugin API reference is rendered, -# with the caveat that the docs build environment has the environment variable -# PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python set so that protobuf can be parsed -# using Python, which is acceptable for docs building. -skl2onnx -tf2onnx diff --git a/docs/source/_templates/custom.rst b/docs/source/_templates/custom.rst index 9566f75d21..17c9b00963 100644 --- a/docs/source/_templates/custom.rst +++ b/docs/source/_templates/custom.rst @@ -26,6 +26,7 @@ .. rubric:: {{ _('Attributes') }} {% for item in attributes %} .. autoattribute:: {{ item }} + :noindex: {%- endfor %} {% endif %} diff --git a/docs/source/conf.py b/docs/source/conf.py index 6c0663f6b5..b12e355845 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -31,7 +31,7 @@ # -- Project information ----------------------------------------------------- project = "Flytekit" -copyright = "2021, Flyte" +copyright = "2023, Flyte" author = "Flyte" # The full version, including alpha/beta/rc tags diff --git a/docs/source/data.extend.rst b/docs/source/data.extend.rst deleted file mode 100644 index 517df6dd3f..0000000000 --- a/docs/source/data.extend.rst +++ /dev/null @@ -1,40 +0,0 @@ -###################### -Data Persistence Layer -###################### - -.. tags:: Data, AWS, GCP, Intermediate - -Flytekit provides a data persistence layer, which is used for recording metadata that is shared with the Flyte backend. This persistence layer is available for various types to store raw user data and is designed to be cross-cloud compatible. -Moreover, it is designed to be extensible and users can bring their own data persistence plugins by following the persistence interface. - -.. note:: - This will become extensive for a variety of use-cases, but the core set of APIs have been battle tested. - -.. automodule:: flytekit.core.data_persistence - :no-members: - :no-inherited-members: - :no-special-members: - -.. 
automodule:: flytekit.extras.persistence - :no-members: - :no-inherited-members: - :no-special-members: - -The ``fsspec`` Data Plugin --------------------------- - -Flytekit ships with a default storage driver that uses aws-cli on AWS and gsutil on GCP. By default, Flyte uploads the task outputs to S3 or GCS using these storage drivers. - -Why ``fsspec``? -^^^^^^^^^^^^^^^ - -You can use the fsspec plugin implementation to utilize all its available plugins with flytekit. The `fsspec `_ plugin provides an implementation of the data persistence layer in Flytekit. For example: HDFS, FTP are supported in fsspec, so you can use them with flytekit too. -The data persistence layer helps store logs of metadata and raw user data. -As a consequence of the implementation, an S3 driver can be installed using ``pip install s3fs``. - -`Here `__ is a code snippet that shows protocols mapped to the class it implements. - -Once you install the plugin, it overrides all default implementations of the `DataPersistencePlugins `_ and provides the ones supported by fsspec. - -.. note:: - This plugin installs fsspec core only. To install all the fsspec plugins, see `here `_. diff --git a/docs/source/design/clis.rst b/docs/source/design/clis.rst index bde51e774d..32ba6e9edb 100644 --- a/docs/source/design/clis.rst +++ b/docs/source/design/clis.rst @@ -16,9 +16,9 @@ The client code is located in ``flytekit/clients`` and there are two. * Similar to the :ref:`design-models` files, but a bit more complex, the ``raw`` one is basically a wrapper around the protobuf generated code, with some handling for authentication in place, and acts as a mechanism for autocompletion and comments. * The ``friendly`` client uses the ``raw`` client, adds handling of things like pagination, and is structurally more aligned with the functionality and call pattern of the CLI itself. -.. autoclass:: flytekit.clients.friendly.SynchronousFlyteClient +:py:class:`clients.friendly.SynchronousFlyteClient` -.. autoclass:: flytekit.clients.raw.RawSynchronousFlyteClient +:py:class:`clients.raw.RawSynchronousFlyteClient` *********************** Command Line Interfaces diff --git a/docs/source/design/control_plane.rst b/docs/source/design/control_plane.rst index 156bc46212..4c5357cb87 100644 --- a/docs/source/design/control_plane.rst +++ b/docs/source/design/control_plane.rst @@ -10,7 +10,7 @@ For those who require programmatic access to the control plane, the :mod:`~flyte certain operations in a Python runtime environment. Since this section naturally deals with the control plane, this discussion is only relevant for those who have a Flyte -backend set up and have access to it (a :std:ref:`local sandbox ` will suffice as well). +backend set up and have access to it (a local demo cluster will suffice as well). ***************************** Creating a FlyteRemote Object @@ -51,7 +51,7 @@ Sandbox ======= The :py:class:`~flytekit.configuration.Config` class's :py:meth:`~flytekit.configuration.Config.for_sandbox` method can be used to -construct the ``Config`` object, specifically to connect to the :std:ref:`sandbox `. +construct the ``Config`` object, specifically to connect to the Flyte cluster. .. code-block:: python diff --git a/docs/source/experimental.rst b/docs/source/experimental.rst new file mode 100644 index 0000000000..50d6c42ad9 --- /dev/null +++ b/docs/source/experimental.rst @@ -0,0 +1,16 @@ +Experimental Features +===================== + +.. currentmodule:: flytekit + +.. 
important:: + + The constructs below are experimental and the API is subject to breaking changes. + +.. autosummary:: + :nosignatures: + :toctree: generated/ + + ~experimental.map_task + ~experimental.eager + ~experimental.EagerException diff --git a/docs/source/index.rst b/docs/source/index.rst index db5902391b..f123248cec 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -83,6 +83,6 @@ Expected output: plugins/index tasks.extend types.extend - data.extend + experimental pyflyte contributing diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 04e2a5debe..75037d3370 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -10,7 +10,7 @@ Basic Authoring =============== -These are the essentials needed to get started writing tasks and workflows. The elements here correspond well with :std:ref:`Basics ` section of the user guide. +These are the essentials needed to get started writing tasks and workflows. .. autosummary:: :nosignatures: @@ -29,9 +29,10 @@ ~core.promise.NodeOutput FlyteContextManager -Running Locally ------------------- -Tasks and Workflows can both be locally run (assuming the relevant tasks are capable of local execution). This is useful for unit testing. +.. important:: + + Tasks and Workflows can both be locally run, assuming the relevant tasks are capable of local execution. + This is useful for unit testing. Branching and Conditionals @@ -176,7 +177,6 @@ :template: custom.rst :toctree: generated/ - Deck HashMethod Documentation @@ -224,6 +224,7 @@ from flytekit.core.resources import Resources from flytekit.core.schedule import CronSchedule, FixedRate from flytekit.core.task import Secret, reference_task, task +from flytekit.core.type_engine import BatchSize from flytekit.core.workflow import ImperativeWorkflow as Workflow from flytekit.core.workflow import WorkflowFailurePolicy, reference_workflow, workflow from flytekit.deck import Deck @@ -235,6 +236,7 @@ from flytekit.models.documentation import Description, Documentation, SourceCode from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar from flytekit.models.types import LiteralType +from flytekit.sensor.sensor_engine import SensorEngine from flytekit.types import directory, file, iterator from flytekit.types.structured.structured_dataset import ( StructuredDataset, diff --git a/flytekit/bin/entrypoint.py b/flytekit/bin/entrypoint.py index c85092569f..adda627286 100644 --- a/flytekit/bin/entrypoint.py +++ b/flytekit/bin/entrypoint.py @@ -1,7 +1,10 @@ +import asyncio import contextlib import datetime as _datetime +import inspect import os import pathlib +import signal import subprocess import tempfile import traceback as _traceback @@ -18,6 +21,7 @@ ) from flytekit.core import constants as _constants from flytekit.core import utils +from flytekit.core.array_node_map_task import ArrayNodeMapTaskResolver from flytekit.core.base_task import IgnoreOutputs, PythonTask from flytekit.core.checkpointer import SyncCheckpoint from flytekit.core.context_manager import ExecutionParameters, ExecutionState, FlyteContext, FlyteContextManager @@ -89,6 +93,11 @@ def _dispatch_execute( # Decorate the dispatch execute function before calling it, this wraps all exceptions into one # of the FlyteScopedExceptions outputs = _scoped_exceptions.system_entry_point(task_def.dispatch_execute)(ctx, idl_input_literals) + if inspect.iscoroutine(outputs): + # Handle eager-mode (async) tasks + logger.info("Output is a coroutine") + outputs = asyncio.run(outputs) + # Step3a if isinstance(outputs, 
VoidPromise): logger.warning("Task produces no outputs") @@ -375,6 +384,7 @@ def _execute_map_task( prev_checkpoint: Optional[str] = None, dynamic_addl_distro: Optional[str] = None, dynamic_dest_dir: Optional[str] = None, + experimental: Optional[bool] = False, ): """ This function should be called by map task and aws-batch task @@ -399,11 +409,14 @@ def _execute_map_task( with setup_execution( raw_output_data_prefix, checkpoint_path, prev_checkpoint, dynamic_addl_distro, dynamic_dest_dir ) as ctx: - mtr = MapTaskResolver() - map_task = mtr.load_task(loader_args=resolver_args, max_concurrency=max_concurrency) - task_index = _compute_array_job_index() - output_prefix = os.path.join(output_prefix, str(task_index)) + if experimental: + mtr = ArrayNodeMapTaskResolver() + else: + mtr = MapTaskResolver() + output_prefix = os.path.join(output_prefix, str(task_index)) + + map_task = mtr.load_task(loader_args=resolver_args, max_concurrency=max_concurrency) if test: logger.info( @@ -514,8 +527,15 @@ def fast_execute_task_cmd(additional_distribution: str, dest_dir: str, task_exec # Use the commandline to run the task execute command rather than calling it directly in python code # since the current runtime bytecode references the older user code, rather than the downloaded distribution. - p = subprocess.run(cmd, check=False) - exit(p.returncode) + p = subprocess.Popen(cmd) + + def handle_sigterm(signum, frame): + logger.info(f"passing signum {signum} [frame={frame}] to subprocess") + p.send_signal(signum) + + signal.signal(signal.SIGTERM, handle_sigterm) + returncode = p.wait() + exit(returncode) @_pass_through.command("pyflyte-map-execute") @@ -529,6 +549,7 @@ def fast_execute_task_cmd(additional_distribution: str, dest_dir: str, task_exec @_click.option("--resolver", required=True) @_click.option("--checkpoint-path", required=False) @_click.option("--prev-checkpoint", required=False) +@_click.option("--experimental", is_flag=True, default=False, required=False) @_click.argument( "resolver-args", type=_click.UNPROCESSED, @@ -545,6 +566,7 @@ def map_execute_task_cmd( resolver, resolver_args, prev_checkpoint, + experimental, checkpoint_path, ): logger.info(get_version_message()) @@ -565,6 +587,7 @@ def map_execute_task_cmd( resolver_args=resolver_args, checkpoint_path=checkpoint_path, prev_checkpoint=prev_checkpoint, + experimental=experimental, ) diff --git a/flytekit/clients/auth/auth_client.py b/flytekit/clients/auth/auth_client.py index ec1fd4d3e1..29a995ca89 100644 --- a/flytekit/clients/auth/auth_client.py +++ b/flytekit/clients/auth/auth_client.py @@ -184,6 +184,11 @@ def __init__( redirect_uri: typing.Optional[str] = None, endpoint_metadata: typing.Optional[EndpointMetadata] = None, verify: typing.Optional[typing.Union[bool, str]] = None, + session: typing.Optional[_requests.Session] = None, + request_auth_code_params: typing.Optional[typing.Dict[str, str]] = None, + request_access_token_params: typing.Optional[typing.Dict[str, str]] = None, + refresh_access_token_params: typing.Optional[typing.Dict[str, str]] = None, + add_request_auth_code_params_to_request_access_token_params: typing.Optional[bool] = False, ): """ Create new AuthorizationClient @@ -192,7 +197,9 @@ def __init__( :param auth_endpoint: str endpoint where auth metadata can be found :param token_endpoint: str endpoint to retrieve token from :param scopes: list[str] oauth2 scopes - :param client_id + :param client_id: oauth2 client id + :param redirect_uri: oauth2 redirect uri + :param endpoint_metadata: EndpointMetadata object 
to control the rendering of the page on login successful or failure :param verify: (optional) Either a boolean, in which case it controls whether we verify the server's TLS certificate, or a string, in which case it must be a path to a CA bundle to use. Defaults to ``True``. When set to @@ -201,6 +208,15 @@ def __init__( certificates, which will make your application vulnerable to man-in-the-middle (MitM) attacks. Setting verify to ``False`` may be useful during local development or testing. + :param session: (optional) A custom requests.Session object to use for making HTTP requests. + If not provided, a new Session object will be created. + :param request_auth_code_params: (optional) dict of parameters to add to login uri opened in the browser + :param request_access_token_params: (optional) dict of parameters to add when exchanging the auth code for the access token + :param refresh_access_token_params: (optional) dict of parameters to add when refreshing the access token + :param add_request_auth_code_params_to_request_access_token_params: Whether to add the `request_auth_code_params` to + the parameters sent when exchanging the auth code for the access token. Defaults to False. + Required e.g. for the PKCE flow with flyteadmin. + Not required for e.g. the standard OAuth2 flow on GCP. """ self._endpoint = endpoint self._auth_endpoint = auth_endpoint @@ -213,15 +229,13 @@ def __init__( self._client_id = client_id self._scopes = scopes or [] self._redirect_uri = redirect_uri - self._code_verifier = _generate_code_verifier() - code_challenge = _create_code_challenge(self._code_verifier) - self._code_challenge = code_challenge state = _generate_state_parameter() self._state = state self._verify = verify self._headers = {"content-type": "application/x-www-form-urlencoded"} + self._session = session or _requests.Session() - self._params = { + self._request_auth_code_params = { "client_id": client_id, # This must match the Client ID of the OAuth application. "response_type": "code", # Indicates the authorization code grant "scope": " ".join(s.strip("' ") for s in self._scopes).strip( @@ -230,10 +244,18 @@ def __init__( # callback location where the user-agent will be directed to. 
"redirect_uri": self._redirect_uri, "state": state, - "code_challenge": code_challenge, - "code_challenge_method": "S256", } + if request_auth_code_params: + # Allow adding additional parameters to the request_auth_code_params + self._request_auth_code_params.update(request_auth_code_params) + + self._request_access_token_params = request_access_token_params or {} + self._refresh_access_token_params = refresh_access_token_params or {} + + if add_request_auth_code_params_to_request_access_token_params: + self._request_access_token_params.update(self._request_auth_code_params) + def __repr__(self): return f"AuthorizationClient({self._auth_endpoint}, {self._token_endpoint}, {self._client_id}, {self._scopes}, {self._redirect_uri})" @@ -249,7 +271,7 @@ def _create_callback_server(self): def _request_authorization_code(self): scheme, netloc, path, _, _, _ = _urlparse.urlparse(self._auth_endpoint) - query = _urlencode(self._params) + query = _urlencode(self._request_auth_code_params) endpoint = _urlparse.urlunparse((scheme, netloc, path, None, query, None)) logging.debug(f"Requesting authorization code through {endpoint}") _webbrowser.open_new_tab(endpoint) @@ -262,9 +284,12 @@ def _credentials_from_response(self, auth_token_resp) -> Credentials: "refresh_token": "bar", "token_type": "Bearer" } + + Can additionally contain "expires_in" and "id_token" fields. """ response_body = auth_token_resp.json() refresh_token = None + id_token = None if "access_token" not in response_body: raise ValueError('Expected "access_token" in response from oauth server') if "refresh_token" in response_body: @@ -272,23 +297,25 @@ def _credentials_from_response(self, auth_token_resp) -> Credentials: if "expires_in" in response_body: expires_in = response_body["expires_in"] access_token = response_body["access_token"] + if "id_token" in response_body: + id_token = response_body["id_token"] - return Credentials(access_token, refresh_token, self._endpoint, expires_in=expires_in) + return Credentials(access_token, refresh_token, self._endpoint, expires_in=expires_in, id_token=id_token) def _request_access_token(self, auth_code) -> Credentials: if self._state != auth_code.state: raise ValueError(f"Unexpected state parameter [{auth_code.state}] passed") - self._params.update( - { - "code": auth_code.code, - "code_verifier": self._code_verifier, - "grant_type": "authorization_code", - } - ) - resp = _requests.post( + params = { + "code": auth_code.code, + "grant_type": "authorization_code", + } + + params.update(self._request_access_token_params) + + resp = self._session.post( url=self._token_endpoint, - data=self._params, + data=params, headers=self._headers, allow_redirects=False, verify=self._verify, @@ -332,13 +359,17 @@ def refresh_access_token(self, credentials: Credentials) -> Credentials: if credentials.refresh_token is None: raise ValueError("no refresh token available with which to refresh authorization credentials") - resp = _requests.post( + data = { + "refresh_token": credentials.refresh_token, + "grant_type": "refresh_token", + "client_id": self._client_id, + } + + data.update(self._refresh_access_token_params) + + resp = self._session.post( url=self._token_endpoint, - data={ - "grant_type": "refresh_token", - "client_id": self._client_id, - "refresh_token": credentials.refresh_token, - }, + data=data, headers=self._headers, allow_redirects=False, verify=self._verify, diff --git a/flytekit/clients/auth/authenticator.py b/flytekit/clients/auth/authenticator.py index b2b82831c7..0d9ee6ef95 100644 --- 
a/flytekit/clients/auth/authenticator.py +++ b/flytekit/clients/auth/authenticator.py @@ -5,6 +5,7 @@ from dataclasses import dataclass import click +import requests from . import token_client from .auth_client import AuthorizationClient @@ -95,6 +96,7 @@ def __init__( cfg_store: ClientConfigStore, header_key: typing.Optional[str] = None, verify: typing.Optional[typing.Union[bool, str]] = None, + session: typing.Optional[requests.Session] = None, ): """ Initialize with default creds from KeyStore using the endpoint name @@ -102,9 +104,16 @@ def __init__( super().__init__(endpoint, header_key, KeyringStore.retrieve(endpoint), verify=verify) self._cfg_store = cfg_store self._auth_client = None + self._session = session or requests.Session() def _initialize_auth_client(self): if not self._auth_client: + + from .auth_client import _create_code_challenge, _generate_code_verifier + + code_verifier = _generate_code_verifier() + code_challenge = _create_code_challenge(code_verifier) + cfg = self._cfg_store.get_client_config() self._set_header_key(cfg.header_key) self._auth_client = AuthorizationClient( @@ -115,6 +124,16 @@ def _initialize_auth_client(self): auth_endpoint=cfg.authorization_endpoint, token_endpoint=cfg.token_endpoint, verify=self._verify, + session=self._session, + request_auth_code_params={ + "code_challenge": code_challenge, + "code_challenge_method": "S256", + }, + request_access_token_params={ + "code_verifier": code_verifier, + }, + refresh_access_token_params={}, + add_request_auth_code_params_to_request_access_token_params=True, ) def refresh_credentials(self): @@ -176,6 +195,7 @@ def __init__( http_proxy_url: typing.Optional[str] = None, verify: typing.Optional[typing.Union[bool, str]] = None, audience: typing.Optional[str] = None, + session: typing.Optional[requests.Session] = None, ): if not client_id or not client_secret: raise ValueError("Client ID and Client SECRET both are required.") @@ -186,6 +206,7 @@ def __init__( self._client_id = client_id self._client_secret = client_secret self._audience = audience or cfg.audience + self._session = session or requests.Session() super().__init__(endpoint, cfg.header_key or header_key, http_proxy_url=http_proxy_url, verify=verify) def refresh_credentials(self): @@ -211,6 +232,7 @@ def refresh_credentials(self): verify=self._verify, scopes=scopes, audience=audience, + session=self._session, ) logging.info("Retrieved new token, expires in {}".format(expires_in)) @@ -234,6 +256,7 @@ def __init__( audience: typing.Optional[str] = None, http_proxy_url: typing.Optional[str] = None, verify: typing.Optional[typing.Union[bool, str]] = None, + session: typing.Optional[requests.Session] = None, ): self._audience = audience cfg = cfg_store.get_client_config() @@ -245,6 +268,7 @@ def __init__( raise AuthenticationError( "Device Authentication is not available on the Flyte backend / authentication server" ) + self._session = session or requests.Session() super().__init__( endpoint=endpoint, header_key=header_key or cfg.header_key, @@ -255,7 +279,13 @@ def __init__( def refresh_credentials(self): resp = token_client.get_device_code( - self._device_auth_endpoint, self._client_id, self._audience, self._scope, self._http_proxy_url, self._verify + self._device_auth_endpoint, + self._client_id, + self._audience, + self._scope, + self._http_proxy_url, + self._verify, + self._session, ) text = f"To Authenticate, navigate in a browser to the following URL: {click.style(resp.verification_uri, fg='blue', underline=True)} and enter code: 
{click.style(resp.user_code, fg='blue')}" click.secho(text) diff --git a/flytekit/clients/auth/keyring.py b/flytekit/clients/auth/keyring.py index 79f5e86c68..2d4b4488f0 100644 --- a/flytekit/clients/auth/keyring.py +++ b/flytekit/clients/auth/keyring.py @@ -3,7 +3,7 @@ from dataclasses import dataclass import keyring as _keyring -from keyring.errors import NoKeyringError +from keyring.errors import NoKeyringError, PasswordDeleteError @dataclass @@ -16,6 +16,7 @@ class Credentials(object): refresh_token: str = "na" for_endpoint: str = "flyte-default" expires_in: typing.Optional[int] = None + id_token: typing.Optional[str] = None class KeyringStore: @@ -25,20 +26,28 @@ class KeyringStore: _access_token_key = "access_token" _refresh_token_key = "refresh_token" + _id_token_key = "id_token" @staticmethod def store(credentials: Credentials) -> Credentials: try: - _keyring.set_password( - credentials.for_endpoint, - KeyringStore._refresh_token_key, - credentials.refresh_token, - ) + if credentials.refresh_token: + _keyring.set_password( + credentials.for_endpoint, + KeyringStore._refresh_token_key, + credentials.refresh_token, + ) _keyring.set_password( credentials.for_endpoint, KeyringStore._access_token_key, credentials.access_token, ) + if credentials.id_token: + _keyring.set_password( + credentials.for_endpoint, + KeyringStore._id_token_key, + credentials.id_token, + ) except NoKeyringError as e: logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}") return credentials @@ -48,18 +57,23 @@ def retrieve(for_endpoint: str) -> typing.Optional[Credentials]: try: refresh_token = _keyring.get_password(for_endpoint, KeyringStore._refresh_token_key) access_token = _keyring.get_password(for_endpoint, KeyringStore._access_token_key) + id_token = _keyring.get_password(for_endpoint, KeyringStore._id_token_key) except NoKeyringError as e: logging.debug(f"KeyRing not available, tokens will not be cached. Error: {e}") return None - if not access_token: + if not access_token and not id_token: return None - return Credentials(access_token, refresh_token, for_endpoint) + return Credentials(access_token, refresh_token, for_endpoint, id_token=id_token) @staticmethod def delete(for_endpoint: str): try: _keyring.delete_password(for_endpoint, KeyringStore._access_token_key) _keyring.delete_password(for_endpoint, KeyringStore._refresh_token_key) + try: + _keyring.delete_password(for_endpoint, KeyringStore._id_token_key) + except PasswordDeleteError as e: + logging.debug(f"Id token not found in key store, not deleting. Error: {e}") except NoKeyringError as e: logging.debug(f"KeyRing not available, tokens will not be cached. 
Error: {e}") diff --git a/flytekit/clients/auth/token_client.py b/flytekit/clients/auth/token_client.py index e5eae32ed7..4584866b21 100644 --- a/flytekit/clients/auth/token_client.py +++ b/flytekit/clients/auth/token_client.py @@ -78,6 +78,7 @@ def get_token( grant_type: GrantType = GrantType.CLIENT_CREDS, http_proxy_url: typing.Optional[str] = None, verify: typing.Optional[typing.Union[bool, str]] = None, + session: typing.Optional[requests.Session] = None, ) -> typing.Tuple[str, int]: """ :rtype: (Text,Int) The first element is the access token retrieved from the IDP, the second is the expiration @@ -103,7 +104,10 @@ def get_token( body["audience"] = audience proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None - response = requests.post(token_endpoint, data=body, headers=headers, proxies=proxies, verify=verify) + + if not session: + session = requests.Session() + response = session.post(token_endpoint, data=body, headers=headers, proxies=proxies, verify=verify) if not response.ok: j = response.json() @@ -125,6 +129,7 @@ def get_device_code( scope: typing.Optional[typing.List[str]] = None, http_proxy_url: typing.Optional[str] = None, verify: typing.Optional[typing.Union[bool, str]] = None, + session: typing.Optional[requests.Session] = None, ) -> DeviceCodeResponse: """ Retrieves the device Authentication code that can be done to authenticate the request using a browser on a @@ -133,7 +138,9 @@ def get_device_code( _scope = " ".join(scope) if scope is not None else "" payload = {"client_id": client_id, "scope": _scope, "audience": audience} proxies = {"https": http_proxy_url, "http": http_proxy_url} if http_proxy_url else None - resp = requests.post(device_auth_endpoint, payload, proxies=proxies, verify=verify) + if not session: + session = requests.Session() + resp = session.post(device_auth_endpoint, payload, proxies=proxies, verify=verify) if not resp.ok: raise AuthenticationError(f"Unable to retrieve Device Authentication Code for {payload}, Reason {resp.reason}") return DeviceCodeResponse.from_json_response(resp.json()) diff --git a/flytekit/clients/auth_helper.py b/flytekit/clients/auth_helper.py index 5c4fafe579..75bc52378e 100644 --- a/flytekit/clients/auth_helper.py +++ b/flytekit/clients/auth_helper.py @@ -1,7 +1,9 @@ import logging import ssl +from http import HTTPStatus import grpc +import requests from flyteidl.service.auth_pb2 import OAuth2MetadataRequest, PublicClientAuthConfigRequest from flyteidl.service.auth_pb2_grpc import AuthMetadataServiceStub from OpenSSL import crypto @@ -16,6 +18,7 @@ PKCEAuthenticator, ) from flytekit.clients.grpc_utils.auth_interceptor import AuthUnaryInterceptor +from flytekit.clients.grpc_utils.default_metadata_interceptor import DefaultMetadataInterceptor from flytekit.clients.grpc_utils.wrap_exception_interceptor import RetryExceptionWrapperInterceptor from flytekit.configuration import AuthType, PlatformConfig @@ -65,8 +68,10 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth elif cfg.ca_cert_file_path: verify = cfg.ca_cert_file_path + session = get_session(cfg) + if cfg_auth == AuthType.STANDARD or cfg_auth == AuthType.PKCE: - return PKCEAuthenticator(cfg.endpoint, cfg_store, verify=verify) + return PKCEAuthenticator(cfg.endpoint, cfg_store, verify=verify, session=session) elif cfg_auth == AuthType.BASIC or cfg_auth == AuthType.CLIENT_CREDENTIALS or cfg_auth == AuthType.CLIENTSECRET: return ClientCredentialsAuthenticator( endpoint=cfg.endpoint, @@ -77,6 +82,7 @@ def 
get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth audience=cfg.audience, http_proxy_url=cfg.http_proxy_url, verify=verify, + session=session, ) elif cfg_auth == AuthType.EXTERNAL_PROCESS or cfg_auth == AuthType.EXTERNALCOMMAND: client_cfg = None @@ -93,6 +99,7 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth audience=cfg.audience, http_proxy_url=cfg.http_proxy_url, verify=verify, + session=session, ) else: raise ValueError( @@ -100,6 +107,28 @@ def get_authenticator(cfg: PlatformConfig, cfg_store: ClientConfigStore) -> Auth ) +def get_proxy_authenticator(cfg: PlatformConfig) -> Authenticator: + return CommandAuthenticator( + command=cfg.proxy_command, + header_key="proxy-authorization", + ) + + +def upgrade_channel_to_proxy_authenticated(cfg: PlatformConfig, in_channel: grpc.Channel) -> grpc.Channel: + """ + If activated in the platform config, given a grpc.Channel, preferably a secure channel, it returns a composed + channel that uses Interceptor to perform authentication with a proxy in front of Flyte + :param cfg: PlatformConfig + :param in_channel: grpc.Channel Precreated channel + :return: grpc.Channel. New composite channel + """ + if cfg.proxy_command: + proxy_authenticator = get_proxy_authenticator(cfg) + return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(proxy_authenticator)) + else: + return in_channel + + def upgrade_channel_to_authenticated(cfg: PlatformConfig, in_channel: grpc.Channel) -> grpc.Channel: """ Given a grpc.Channel, preferrably a secure channel, it returns a composed channel that uses Interceptor to @@ -121,6 +150,7 @@ def get_authenticated_channel(cfg: PlatformConfig) -> grpc.Channel: if cfg.insecure else grpc.secure_channel(cfg.endpoint, grpc.ssl_channel_credentials()) ) # noqa + channel = upgrade_channel_to_proxy_authenticated(cfg, channel) return upgrade_channel_to_authenticated(cfg, channel) @@ -171,7 +201,7 @@ def get_channel(cfg: PlatformConfig, **kwargs) -> grpc.Channel: :return: grpc.Channel (secure / insecure) """ if cfg.insecure: - return grpc.insecure_channel(cfg.endpoint, **kwargs) + return grpc.intercept_channel(grpc.insecure_channel(cfg.endpoint, **kwargs), DefaultMetadataInterceptor()) credentials = None if "credentials" not in kwargs: @@ -189,11 +219,14 @@ def get_channel(cfg: PlatformConfig, **kwargs) -> grpc.Channel: ) else: credentials = kwargs["credentials"] - return grpc.secure_channel( - target=cfg.endpoint, - credentials=credentials, - options=kwargs.get("options", None), - compression=kwargs.get("compression", None), + return grpc.intercept_channel( + grpc.secure_channel( + target=cfg.endpoint, + credentials=credentials, + options=kwargs.get("options", None), + compression=kwargs.get("compression", None), + ), + DefaultMetadataInterceptor(), ) @@ -209,3 +242,64 @@ def wrap_exceptions_channel(cfg: PlatformConfig, in_channel: grpc.Channel) -> gr :return: grpc.Channel """ return grpc.intercept_channel(in_channel, RetryExceptionWrapperInterceptor(max_retries=cfg.rpc_retries)) + + +class AuthenticationHTTPAdapter(requests.adapters.HTTPAdapter): + """ + A custom HTTPAdapter that adds authentication headers to requests of a session. + """ + + def __init__(self, authenticator, *args, **kwargs): + self.authenticator = authenticator + super().__init__(*args, **kwargs) + + def add_auth_header(self, request): + """ + Adds authentication headers to the request. + :param request: The request object to add headers to.
+ """ + if self.authenticator.get_credentials() is None: + self.authenticator.refresh_credentials() + + auth_header_key, auth_header_val = self.authenticator.fetch_grpc_call_auth_metadata() + request.headers[auth_header_key] = auth_header_val + + def send(self, request, *args, **kwargs): + """ + Sends the request with added authentication headers. + If the response returns a 401 status code, refreshes the credentials and retries the request. + :param request: The request object to send. + :return: The response object. + """ + self.add_auth_header(request) + response = super().send(request, *args, **kwargs) + if response.status_code == HTTPStatus.UNAUTHORIZED: + self.authenticator.refresh_credentials() + self.add_auth_header(request) + response = super().send(request, *args, **kwargs) + return response + + +def upgrade_session_to_proxy_authenticated(cfg: PlatformConfig, session: requests.Session) -> requests.Session: + """ + Given a requests.Session, it returns a new session that uses a custom HTTPAdapter to + perform authentication with a proxy in front of Flyte + + :param cfg: PlatformConfig + :param session: requests.Session Precreated session + :return: requests.Session. New session with custom HTTPAdapter mounted + """ + proxy_authenticator = get_proxy_authenticator(cfg) + adapter = AuthenticationHTTPAdapter(proxy_authenticator) + + session.mount("http://", adapter) + session.mount("https://", adapter) + return session + + +def get_session(cfg: PlatformConfig, **kwargs) -> requests.Session: + """Return a new session for the given platform config.""" + session = requests.Session() + if cfg.proxy_command: + session = upgrade_session_to_proxy_authenticated(cfg, session) + return session diff --git a/flytekit/clients/friendly.py b/flytekit/clients/friendly.py index 2bae266a53..d6d0581b2a 100644 --- a/flytekit/clients/friendly.py +++ b/flytekit/clients/friendly.py @@ -802,8 +802,8 @@ def list_task_executions_paginated( :param flytekit.models.core.identifier.NodeExecutionIdentifier node_execution_identifier: :param int limit: :param Text token: [Optional] If specified, this specifies where in the rows of results to skip before reading. - If you previously retrieved a page response with token="foo" and you want the next page, - specify token="foo". + If you previously retrieved a page response with token="foo" and you want the next page, + specify token="foo". :param list[flytekit.models.filters.Filter] filters: :param flytekit.models.admin.common.Sort sort_by: [Optional] If provided, the results will be sorted. :rtype: (list[flytekit.models.admin.task_execution.TaskExecution], Text) @@ -981,7 +981,8 @@ def get_upload_signed_url( self, project: str, domain: str, content_md5: bytes, filename: str = None, expires_in: datetime.timedelta = None ) -> _data_proxy_pb2.CreateUploadLocationResponse: """ - Get a signed url to be used during fast registration + Get a signed url to be used during fast registration. + :param str project: Project to create the upload location for :param str domain: Domain to create the upload location for :param bytes content_md5: ContentMD5 restricts the upload location to the specific MD5 provided.
The content_md5 diff --git a/flytekit/clients/grpc_utils/auth_interceptor.py b/flytekit/clients/grpc_utils/auth_interceptor.py index 21bcc30136..e467801a77 100644 --- a/flytekit/clients/grpc_utils/auth_interceptor.py +++ b/flytekit/clients/grpc_utils/auth_interceptor.py @@ -32,7 +32,7 @@ def _call_details_with_auth_metadata(self, client_call_details: grpc.ClientCallD """ Returns new ClientCallDetails with metadata added. """ - metadata = None + metadata = client_call_details.metadata auth_metadata = self._authenticator.fetch_grpc_call_auth_metadata() if auth_metadata: metadata = [] @@ -61,7 +61,7 @@ def intercept_unary_unary( fut: grpc.Future = continuation(updated_call_details, request) e = fut.exception() if e: - if e.code() == grpc.StatusCode.UNAUTHENTICATED: + if e.code() == grpc.StatusCode.UNAUTHENTICATED or e.code() == grpc.StatusCode.UNKNOWN: self._authenticator.refresh_credentials() updated_call_details = self._call_details_with_auth_metadata(client_call_details) return continuation(updated_call_details, request) diff --git a/flytekit/clients/grpc_utils/default_metadata_interceptor.py b/flytekit/clients/grpc_utils/default_metadata_interceptor.py new file mode 100644 index 0000000000..12b06cca03 --- /dev/null +++ b/flytekit/clients/grpc_utils/default_metadata_interceptor.py @@ -0,0 +1,43 @@ +import typing + +import grpc + +from flytekit.clients.grpc_utils.auth_interceptor import _ClientCallDetails + + +class DefaultMetadataInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor): + def _inject_default_metadata(self, call_details: grpc.ClientCallDetails): + metadata = [("accept", "application/grpc")] + if call_details.metadata: + metadata.extend(list(call_details.metadata)) + new_details = _ClientCallDetails( + call_details.method, + call_details.timeout, + metadata, + call_details.credentials, + ) + return new_details + + def intercept_unary_unary( + self, + continuation: typing.Callable, + client_call_details: grpc.ClientCallDetails, + request: typing.Any, + ): + """ + Intercepts unary calls and injects default metadata + """ + updated_call_details = self._inject_default_metadata(client_call_details) + return continuation(updated_call_details, request) + + def intercept_unary_stream( + self, + continuation: typing.Callable, + client_call_details: grpc.ClientCallDetails, + request: typing.Any, + ): + """ + Handles a stream call and injects default metadata + """ + updated_call_details = self._inject_default_metadata(client_call_details) + return continuation(updated_call_details, request) diff --git a/flytekit/clients/raw.py b/flytekit/clients/raw.py index 836d5ffa3b..6cb80d4b8f 100644 --- a/flytekit/clients/raw.py +++ b/flytekit/clients/raw.py @@ -11,7 +11,12 @@ from flyteidl.service import signal_pb2_grpc as signal_service from flyteidl.service.dataproxy_pb2_grpc import DataProxyServiceStub -from flytekit.clients.auth_helper import get_channel, upgrade_channel_to_authenticated, wrap_exceptions_channel +from flytekit.clients.auth_helper import ( + get_channel, + upgrade_channel_to_authenticated, + upgrade_channel_to_proxy_authenticated, + wrap_exceptions_channel, +) from flytekit.configuration import PlatformConfig from flytekit.loggers import cli_logger @@ -41,7 +46,9 @@ def __init__(self, cfg: PlatformConfig, **kwargs): insecure: if insecure is desired """ self._cfg = cfg - self._channel = wrap_exceptions_channel(cfg, upgrade_channel_to_authenticated(cfg, get_channel(cfg))) + self._channel = wrap_exceptions_channel( cfg,
upgrade_channel_to_authenticated(cfg, upgrade_channel_to_proxy_authenticated(cfg, get_channel(cfg))) + ) self._stub = _admin_service.AdminServiceStub(self._channel) self._signal = signal_service.SignalServiceStub(self._channel) self._dataproxy_stub = dataproxy_service.DataProxyServiceStub(self._channel) diff --git a/flytekit/clis/sdk_in_container/backfill.py b/flytekit/clis/sdk_in_container/backfill.py index 49c2667d5b..7723a98f44 100644 --- a/flytekit/clis/sdk_in_container/backfill.py +++ b/flytekit/clis/sdk_in_container/backfill.py @@ -3,8 +3,10 @@ import rich_click as click +from flytekit import WorkflowFailurePolicy from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context -from flytekit.clis.sdk_in_container.run import DateTimeType, DurationParamType +from flytekit.clis.sdk_in_container.utils import domain_option_dec, project_option_dec +from flytekit.interaction.click_types import DateTimeType, DurationParamType _backfill_help = """ The backfill command generates and registers a new workflow based on the input launchplan to run an @@ -42,22 +44,8 @@ def resolve_backfill_window( @click.command("backfill", help=_backfill_help) -@click.option( - "-p", - "--project", - required=False, - type=str, - default="flytesnacks", - help="Project to register and run this workflow in", -) -@click.option( - "-d", - "--domain", - required=False, - type=str, - default="development", - help="Domain to register and run this workflow in", -) +@project_option_dec +@domain_option_dec @click.option( "-v", "--version", @@ -125,6 +113,17 @@ def resolve_backfill_window( "backfills between. This is needed with from-date / to-date. Optional if both from-date and to-date are " "provided", ) +@click.option( + "--fail-fast/--no-fail-fast", + required=False, + type=bool, + is_flag=True, + default=True, + show_default=True, + help="If set to true, the backfill will fail immediately (WorkflowFailurePolicy.FAIL_IMMEDIATELY) if any of the " + "backfill steps fail. 
If set to false, the backfill will continue to run even if some of the backfill steps " + "fail (WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE).", +) @click.argument( "launchplan", required=True, @@ -151,6 +150,7 @@ def backfill( parallel: bool, execution_name: str, version: str, + fail_fast: bool, ): from_date, to_date = resolve_backfill_window(from_date, to_date, backfill_window) remote = get_and_save_remote_with_click_context(ctx, project, domain) @@ -167,6 +167,9 @@ def backfill( dry_run=dry_run, execute=execute, parallel=parallel, + failure_policy=WorkflowFailurePolicy.FAIL_IMMEDIATELY + if fail_fast + else WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE, ) if dry_run: return diff --git a/flytekit/clis/sdk_in_container/build.py b/flytekit/clis/sdk_in_container/build.py index 33b8346f10..c4eb819eb6 100644 --- a/flytekit/clis/sdk_in_container/build.py +++ b/flytekit/clis/sdk_in_container/build.py @@ -1,32 +1,30 @@ -import os -import pathlib import typing +from dataclasses import dataclass import rich_click as click from typing_extensions import OrderedDict -from flytekit.clis.sdk_in_container.constants import CTX_MODULE, CTX_PROJECT_ROOT -from flytekit.clis.sdk_in_container.run import RUN_LEVEL_PARAMS_KEY, get_entities_in_file, load_naive_entity +from flytekit.clis.sdk_in_container.run import RunCommand, RunLevelParams, WorkflowCommand +from flytekit.clis.sdk_in_container.utils import make_field from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.workflow import PythonFunctionWorkflow -from flytekit.tools.script_mode import _find_project_root from flytekit.tools.translator import get_serializable -def get_workflow_command_base_params() -> typing.List[click.Option]: - """ - Return the set of base parameters added to every pyflyte build workflow subcommand. - """ - return [ +@dataclass +class BuildParams(RunLevelParams): + + fast: bool = make_field( click.Option( param_decls=["--fast"], required=False, is_flag=True, default=False, + show_default=True, help="Use fast serialization. The image won't contain the source code. The value is false by default.", - ), - ] + ) + ) def build_command(ctx: click.Context, entity: typing.Union[PythonFunctionWorkflow, PythonTask]): @@ -37,84 +35,60 @@ def build_command(ctx: click.Context, entity: typing.Union[PythonFunctionWorkflo def _build(*args, **kwargs): m = OrderedDict() options = None - run_level_params = ctx.obj[RUN_LEVEL_PARAMS_KEY] + build_params: BuildParams = ctx.obj - project, domain = run_level_params.get("project"), run_level_params.get("domain") serialization_settings = SerializationSettings( - project=project, - domain=domain, + project=build_params.project, + domain=build_params.domain, image_config=ImageConfig.auto_default_image(), ) - if not run_level_params.get("fast"): - serialization_settings.source_root = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT_ROOT) + if not build_params.fast: + serialization_settings.source_root = build_params.computed_params.project_root _ = get_serializable(m, settings=serialization_settings, entity=entity, options=options) return _build -class WorkflowCommand(click.MultiCommand): +class BuildWorkflowCommand(WorkflowCommand): """ click multicommand at the python file layer, subcommands should be all the workflows in the file. 
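The build refactor below reuses the run command machinery; as a point of reference, here is a toy example of the lazy click.MultiCommand pattern that RunCommand/BuildCommand follow, with file discovery in list_commands and per-entity command construction in get_command. It is illustrative only, not flytekit code.

import pathlib

import click


class FileCommands(click.MultiCommand):
    """Exposes every .py file in the working directory as a subcommand."""

    def list_commands(self, ctx):
        return sorted(str(p) for p in pathlib.Path(".").glob("*.py") if p.name != "__init__.py")

    def get_command(self, ctx, name):
        # A real implementation loads the file and exposes its workflows/tasks as
        # nested subcommands; this toy just echoes the chosen file.
        return click.Command(name, callback=lambda: click.echo(f"selected {name}"))


cli = FileCommands(help="Pick a Python file to operate on")

if __name__ == "__main__":
    cli()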
""" - def __init__(self, filename: str, *args, **kwargs): - super().__init__(*args, **kwargs) - self._filename = pathlib.Path(filename).resolve() - - def list_commands(self, ctx): - entities = get_entities_in_file(self._filename.__str__(), False) - return entities.all() - - def get_command(self, ctx, exe_entity): - """ - This command uses the filename with which this command was created, and the string name of the entity passed - after the Python filename on the command line, to load the Python object, and then return the Command that - click should run. - :param ctx: The click Context object. - :param exe_entity: string of the flyte entity provided by the user. Should be the name of a workflow, or task - function. - :return: - """ - rel_path = os.path.relpath(self._filename) - if rel_path.startswith(".."): - raise ValueError( - f"You must call pyflyte from the same or parent dir, {self._filename} not under {os.getcwd()}" - ) - - project_root = _find_project_root(self._filename) - rel_path = self._filename.relative_to(project_root) - module = os.path.splitext(rel_path)[0].replace(os.path.sep, ".") - - ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_PROJECT_ROOT] = project_root - ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_MODULE] = module - - entity = load_naive_entity(module, exe_entity, project_root) - + def _create_command( + self, + ctx: click.Context, + entity_name: str, + run_level_params: RunLevelParams, + loaded_entity: typing.Any, + is_workflow: bool, + ): cmd = click.Command( - name=exe_entity, - callback=build_command(ctx, entity), - help=f"Build an image for {module}.{exe_entity}.", + name=entity_name, + callback=build_command(ctx, loaded_entity), + help=f"Build an image for {run_level_params.computed_params.module}.{entity_name}.", ) return cmd -class BuildCommand(click.MultiCommand): +class BuildCommand(RunCommand): """ A click command group for building a image for flyte workflows & tasks in a file. """ def __init__(self, *args, **kwargs): - params = get_workflow_command_base_params() - super().__init__(*args, params=params, **kwargs) + params = BuildParams.options() + kwargs["params"] = params + super().__init__(*args, **kwargs) - def list_commands(self, ctx): - return [str(p) for p in pathlib.Path(".").glob("*.py") if str(p) != "__init__.py"] + def list_commands(self, ctx, *args, **kwargs): + return super().list_commands(ctx, add_remote=False) def get_command(self, ctx, filename): - if ctx.obj: - ctx.obj[RUN_LEVEL_PARAMS_KEY] = ctx.params - return WorkflowCommand(filename, name=filename, help="Build an image for [workflow|task]") + if ctx.obj is None: + ctx.obj = {} + ctx.obj = BuildParams.from_dict(ctx.params) + return BuildWorkflowCommand(filename, name=filename, help=f"Build an image for [workflow|task] from {filename}") _build_help = """ diff --git a/flytekit/clis/sdk_in_container/constants.py b/flytekit/clis/sdk_in_container/constants.py index 8059d4d14d..dd9c6f4e87 100644 --- a/flytekit/clis/sdk_in_container/constants.py +++ b/flytekit/clis/sdk_in_container/constants.py @@ -1,5 +1,3 @@ -import rich_click as _click - CTX_PROJECT = "project" CTX_DOMAIN = "domain" CTX_VERSION = "version" @@ -7,31 +5,4 @@ CTX_PACKAGES = "pkgs" CTX_NOTIFICATIONS = "notifications" CTX_CONFIG_FILE = "config_file" -CTX_PROJECT_ROOT = "project_root" -CTX_MODULE = "module" CTX_VERBOSE = "verbose" -CTX_COPY_ALL = "copy_all" -CTX_FILE_NAME = "file_name" - - -project_option = _click.option( - "-p", - "--project", - required=True, - type=str, - help="Flyte project to use. 
You can have more than one project per repo", -) -domain_option = _click.option( - "-d", - "--domain", - required=True, - type=str, - help="This is usually development, staging, or production", -) -version_option = _click.option( - "-v", - "--version", - required=False, - type=str, - help="This is the version to apply globally for this context", -) diff --git a/flytekit/clis/sdk_in_container/helpers.py b/flytekit/clis/sdk_in_container/helpers.py index 4c66a2046c..03bc4a6ccb 100644 --- a/flytekit/clis/sdk_in_container/helpers.py +++ b/flytekit/clis/sdk_in_container/helpers.py @@ -1,3 +1,4 @@ +import typing from dataclasses import replace from typing import Optional @@ -11,6 +12,19 @@ FLYTE_REMOTE_INSTANCE_KEY = "flyte_remote" +def get_remote(cfg_file_path: typing.Optional[str], project: str, domain: str) -> FlyteRemote: + cfg_file = get_config_file(cfg_file_path) + if cfg_file is None: + cfg_obj = Config.for_sandbox() + cli_logger.info("No config files found, creating remote with sandbox config") + else: + cfg_obj = Config.auto(cfg_file_path) + cli_logger.info( + f"Creating remote with config {cfg_obj}" + (f" with file {cfg_file_path}" if cfg_file_path else "") + ) + return FlyteRemote(cfg_obj, default_project=project, default_domain=domain) + + def get_and_save_remote_with_click_context( ctx: click.Context, project: str, domain: str, save: bool = True ) -> FlyteRemote: @@ -24,17 +38,10 @@ def get_and_save_remote_with_click_context( :param save: If false, will not mutate the context.obj dict :return: FlyteRemote instance """ + if ctx.obj.get(FLYTE_REMOTE_INSTANCE_KEY) is not None: + return ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] cfg_file_location = ctx.obj.get(CTX_CONFIG_FILE) - cfg_file = get_config_file(cfg_file_location) - if cfg_file is None: - cfg_obj = Config.for_sandbox() - cli_logger.info("No config files found, creating remote with sandbox config") - else: - cfg_obj = Config.auto(cfg_file_location) - cli_logger.info( - f"Creating remote with config {cfg_obj}" + (f" with file {cfg_file_location}" if cfg_file_location else "") - ) - r = FlyteRemote(cfg_obj, default_project=project, default_domain=domain) + r = get_remote(cfg_file_location, project, domain) if save: ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] = r return r diff --git a/flytekit/clis/sdk_in_container/launchplan.py b/flytekit/clis/sdk_in_container/launchplan.py index 2d33e2e3d7..b5022ef5d7 100644 --- a/flytekit/clis/sdk_in_container/launchplan.py +++ b/flytekit/clis/sdk_in_container/launchplan.py @@ -1,6 +1,8 @@ import rich_click as click +from rich.progress import Progress from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context +from flytekit.clis.sdk_in_container.utils import domain_option_dec, project_option_dec from flytekit.models.launch_plan import LaunchPlanState _launchplan_help = """ @@ -13,22 +15,8 @@ @click.command("launchplan", help=_launchplan_help) -@click.option( - "-p", - "--project", - required=False, - type=str, - default="flytesnacks", - help="Fecth launchplan from this project", -) -@click.option( - "-d", - "--domain", - required=False, - type=str, - default="development", - help="Fetch launchplan from this domain", -) +@project_option_dec +@domain_option_dec @click.option( "--activate/--deactivate", required=True, @@ -57,18 +45,25 @@ def launchplan( launchplan_version: str, ): remote = get_and_save_remote_with_click_context(ctx, project, domain) - try: - launchplan = remote.fetch_launch_plan( - project=project, - domain=domain, - name=launchplan, - version=launchplan_version, - ) 
- state = LaunchPlanState.ACTIVE if activate else LaunchPlanState.INACTIVE - remote.client.update_launch_plan(id=launchplan.id, state=state) - click.secho( - f"\n Launchplan was set to {LaunchPlanState.enum_to_string(state)}: {launchplan.name}:{launchplan.id.version}", - fg="green", - ) - except StopIteration as e: - click.secho(f"{e.value}", fg="red") + with Progress() as progress: + t1 = progress.add_task(f"[cyan] {'Activating' if activate else 'Deactivating'}...", total=1) + try: + progress.start_task(t1) + launchplan = remote.fetch_launch_plan( + project=project, + domain=domain, + name=launchplan, + version=launchplan_version, + ) + progress.advance(t1) + + state = LaunchPlanState.ACTIVE if activate else LaunchPlanState.INACTIVE + remote.client.update_launch_plan(id=launchplan.id, state=state) + progress.advance(t1) + progress.update(t1, completed=True, visible=False) + click.secho( + f"\n Launchplan was set to {LaunchPlanState.enum_to_string(state)}: {launchplan.name}:{launchplan.id.version}", + fg="green", + ) + except StopIteration as e: + click.secho(f"{e.value}", fg="red") diff --git a/flytekit/clis/sdk_in_container/pyflyte.py b/flytekit/clis/sdk_in_container/pyflyte.py index 67bc99ab19..6b5e1f0cf2 100644 --- a/flytekit/clis/sdk_in_container/pyflyte.py +++ b/flytekit/clis/sdk_in_container/pyflyte.py @@ -1,9 +1,7 @@ import os import typing -import grpc import rich_click as click -from google.protobuf.json_format import MessageToJson from flytekit import configuration from flytekit.clis.sdk_in_container.backfill import backfill @@ -18,77 +16,12 @@ from flytekit.clis.sdk_in_container.run import run from flytekit.clis.sdk_in_container.serialize import serialize from flytekit.clis.sdk_in_container.serve import serve +from flytekit.clis.sdk_in_container.utils import ErrorHandlingCommand, validate_package from flytekit.configuration.file import FLYTECTL_CONFIG_ENV_VAR, FLYTECTL_CONFIG_ENV_VAR_OVERRIDE from flytekit.configuration.internal import LocalSDK -from flytekit.exceptions.base import FlyteException -from flytekit.exceptions.user import FlyteInvalidInputException from flytekit.loggers import cli_logger -def validate_package(ctx, param, values): - pkgs = [] - for val in values: - if "/" in val or "-" in val or "\\" in val: - raise click.BadParameter( - f"Illegal package value {val} for parameter: {param}. 
Expected for the form [a.b.c]" - ) - elif "," in val: - pkgs.extend(val.split(",")) - else: - pkgs.append(val) - cli_logger.debug(f"Using packages: {pkgs}") - return pkgs - - -def pretty_print_grpc_error(e: grpc.RpcError): - if isinstance(e, grpc._channel._InactiveRpcError): # noqa - click.secho(f"RPC Failed, with Status: {e.code()}", fg="red", bold=True) - click.secho(f"\tdetails: {e.details()}", fg="magenta", bold=True) - click.secho(f"\tDebug string {e.debug_error_string()}", dim=True) - return - - -def pretty_print_exception(e: Exception): - if isinstance(e, click.exceptions.Exit): - raise e - - if isinstance(e, click.ClickException): - click.secho(e.message, fg="red") - raise e - - if isinstance(e, FlyteException): - click.secho(f"Failed with Exception Code: {e._ERROR_CODE}", fg="red") # noqa - if isinstance(e, FlyteInvalidInputException): - click.secho("Request rejected by the API, due to Invalid input.", fg="red") - click.secho(f"\tInput Request: {MessageToJson(e.request)}", dim=True) - - cause = e.__cause__ - if cause: - if isinstance(cause, grpc.RpcError): - pretty_print_grpc_error(cause) - else: - click.secho(f"Underlying Exception: {cause}") - return - - if isinstance(e, grpc.RpcError): - pretty_print_grpc_error(e) - return - - click.secho(f"Failed with Unknown Exception {type(e)} Reason: {e}", fg="red") # noqa - - -class ErrorHandlingCommand(click.RichGroup): - def invoke(self, ctx: click.Context) -> typing.Any: - try: - return super().invoke(ctx) - except Exception as e: - if CTX_VERBOSE in ctx.obj and ctx.obj[CTX_VERBOSE]: - print("Verbose mode on") - raise e - pretty_print_exception(e) - raise SystemExit(e) - - @click.group("pyflyte", invoke_without_command=True, cls=ErrorHandlingCommand) @click.option( "--verbose", required=False, default=False, is_flag=True, help="Show verbose messages and exception traces" diff --git a/flytekit/clis/sdk_in_container/register.py b/flytekit/clis/sdk_in_container/register.py index afc7aeb99e..2313b00fc6 100644 --- a/flytekit/clis/sdk_in_container/register.py +++ b/flytekit/clis/sdk_in_container/register.py @@ -6,6 +6,7 @@ from flytekit.clis.helpers import display_help_with_error from flytekit.clis.sdk_in_container import constants from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context, patch_image_config +from flytekit.clis.sdk_in_container.utils import domain_option_dec, project_option_dec from flytekit.configuration import ImageConfig from flytekit.configuration.default_images import DefaultImages from flytekit.loggers import cli_logger @@ -27,14 +28,8 @@ @click.command("register", help=_register_help) -@click.option( - "-p", - "--project", - required=False, - type=str, - default="flytesnacks", - help="Project to register and run this workflow in", -) +@project_option_dec +@domain_option_dec @click.option( "-d", "--domain", @@ -113,6 +108,13 @@ is_flag=True, help="Execute registration in dry-run mode. Skips actual registration to remote", ) +@click.option( + "--activate-launchplans", + "--activate-launchplan", + default=False, + is_flag=True, + help="Activate newly registered Launchplans. 
This operation deactivates previous versions of Launchplans.", +) @click.argument("package-or-module", type=click.Path(exists=True, readable=True, resolve_path=True), nargs=-1) @click.pass_context def register( @@ -129,6 +131,7 @@ def register( non_fast: bool, package_or_module: typing.Tuple[str], dry_run: bool, + activate_launchplans: bool, ): """ see help @@ -179,6 +182,7 @@ def register( package_or_module=package_or_module, remote=remote, dry_run=dry_run, + activate_launchplans=activate_launchplans, ) except Exception as e: raise e diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py index 1406276263..2205511b94 100644 --- a/flytekit/clis/sdk_in_container/run.py +++ b/flytekit/clis/sdk_in_container/run.py @@ -1,544 +1,87 @@ -import datetime +import asyncio import functools import importlib +import inspect import json -import logging import os import pathlib import typing -from dataclasses import dataclass -from typing import cast +from dataclasses import dataclass, field, fields +from typing import cast, get_args -import cloudpickle import rich_click as click -import yaml from dataclasses_json import DataClassJsonMixin -from pytimeparse import parse -from typing_extensions import get_args - -from flytekit import BlobType, Literal, Scalar -from flytekit.clis.sdk_in_container.constants import ( - CTX_CONFIG_FILE, - CTX_COPY_ALL, - CTX_DOMAIN, - CTX_FILE_NAME, - CTX_MODULE, - CTX_PROJECT, - CTX_PROJECT_ROOT, +from rich.progress import Progress + +from flytekit import Annotations, FlyteContext, Labels, Literal +from flytekit.clis.sdk_in_container.helpers import get_remote, patch_image_config +from flytekit.clis.sdk_in_container.utils import ( + PyFlyteParams, + domain_option, + get_option_from_metadata, + make_field, + pretty_print_exception, + project_option, ) -from flytekit.clis.sdk_in_container.helpers import ( - FLYTE_REMOTE_INSTANCE_KEY, - get_and_save_remote_with_click_context, - patch_image_config, -) -from flytekit.configuration import ImageConfig -from flytekit.configuration.default_images import DefaultImages +from flytekit.configuration import DefaultImages, ImageConfig from flytekit.core import context_manager from flytekit.core.base_task import PythonTask -from flytekit.core.context_manager import FlyteContext, FlyteContextManager -from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import PythonFunctionWorkflow, WorkflowBase -from flytekit.models import literals -from flytekit.models.interface import Variable -from flytekit.models.literals import Blob, BlobMetadata, LiteralCollection, LiteralMap, Primitive, Union -from flytekit.models.types import LiteralType, SimpleType +from flytekit.exceptions.system import FlyteSystemException +from flytekit.interaction.click_types import FlyteLiteralConverter, key_value_callback +from flytekit.models import security +from flytekit.models.common import RawOutputDataConfig +from flytekit.models.interface import Parameter, Variable +from flytekit.models.types import SimpleType +from flytekit.remote import FlyteLaunchPlan, FlyteRemote, FlyteTask, FlyteWorkflow from flytekit.remote.executions import FlyteWorkflowExecution -from flytekit.tools import module_loader, script_mode +from flytekit.tools import module_loader from flytekit.tools.script_mode import _find_project_root from flytekit.tools.translator import Options -from flytekit.types.pickle.pickle import FlytePickleTransformer - -REMOTE_FLAG_KEY = "remote" 
-RUN_LEVEL_PARAMS_KEY = "run_level_params" -DATA_PROXY_CALLBACK_KEY = "data_proxy" - - -def remove_prefix(text, prefix): - if text.startswith(prefix): - return text[len(prefix) :] - return text - - -@dataclass -class Directory(object): - dir_path: str - local_file: typing.Optional[pathlib.Path] = None - local: bool = True - - -class DirParamType(click.ParamType): - name = "directory path" - - def convert( - self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] - ) -> typing.Any: - if FileAccessProvider.is_remote(value): - return Directory(dir_path=value, local=False) - p = pathlib.Path(value) - if p.exists() and p.is_dir(): - files = list(p.iterdir()) - if len(files) != 1: - raise ValueError( - f"Currently only directories containing one file are supported, found [{len(files)}] files found in {p.resolve()}" - ) - return Directory(dir_path=str(p), local_file=files[0].resolve()) - raise click.BadParameter(f"parameter should be a valid directory path, {value}") - - -@dataclass -class FileParam(object): - filepath: str - local: bool = True - - -class FileParamType(click.ParamType): - name = "file path" - - def convert( - self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] - ) -> typing.Any: - if FileAccessProvider.is_remote(value): - return FileParam(filepath=value, local=False) - p = pathlib.Path(value) - if p.exists() and p.is_file(): - return FileParam(filepath=str(p.resolve())) - raise click.BadParameter(f"parameter should be a valid file path, {value}") - - -class PickleParamType(click.ParamType): - name = "pickle" - - def convert( - self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] - ) -> typing.Any: - - uri = FlyteContextManager.current_context().file_access.get_random_local_path() - with open(uri, "w+b") as outfile: - cloudpickle.dump(value, outfile) - return FileParam(filepath=str(pathlib.Path(uri).resolve())) - - -class DateTimeType(click.DateTime): - - _NOW_FMT = "now" - _ADDITONAL_FORMATS = [_NOW_FMT] - - def __init__(self): - super().__init__() - self.formats.extend(self._ADDITONAL_FORMATS) - - def convert( - self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] - ) -> typing.Any: - if value in self._ADDITONAL_FORMATS: - if value == self._NOW_FMT: - return datetime.datetime.now() - return super().convert(value, param, ctx) - - -class DurationParamType(click.ParamType): - name = "[1:24 | :22 | 1 minute | 10 days | ...]" - - def convert( - self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] - ) -> typing.Any: - if value is None: - raise click.BadParameter("None value cannot be converted to a Duration type.") - return datetime.timedelta(seconds=parse(value)) - - -class JsonParamType(click.ParamType): - name = "json object OR json/yaml file path" - - def convert( - self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] - ) -> typing.Any: - if value is None: - raise click.BadParameter("None value cannot be converted to a Json type.") - if type(value) == dict or type(value) == list: - return value - try: - return json.loads(value) - except Exception: # noqa - try: - # We failed to load the json, so we'll try to load it as a file - if os.path.exists(value): - # if the value is a yaml file, we'll try to load it as yaml - if value.endswith(".yaml") or value.endswith(".yml"): - with open(value, "r") 
as f: - return yaml.safe_load(f) - with open(value, "r") as f: - return json.load(f) - raise - except json.JSONDecodeError as e: - raise click.BadParameter(f"parameter {param} should be a valid json object, {value}, error: {e}") @dataclass -class DefaultConverter(object): - click_type: click.ParamType - primitive_type: typing.Optional[str] = None - scalar_type: typing.Optional[str] = None - - def convert(self, value: typing.Any, python_type_hint: typing.Optional[typing.Type] = None) -> Scalar: - if self.primitive_type: - return Scalar(primitive=Primitive(**{self.primitive_type: value})) - if self.scalar_type: - return Scalar(**{self.scalar_type: value}) - - raise NotImplementedError("Not implemented yet!") - - -class FlyteLiteralConverter(object): - name = "literal_type" - - SIMPLE_TYPE_CONVERTER: typing.Dict[SimpleType, DefaultConverter] = { - SimpleType.FLOAT: DefaultConverter(click.FLOAT, primitive_type="float_value"), - SimpleType.INTEGER: DefaultConverter(click.INT, primitive_type="integer"), - SimpleType.STRING: DefaultConverter(click.STRING, primitive_type="string_value"), - SimpleType.BOOLEAN: DefaultConverter(click.BOOL, primitive_type="boolean"), - SimpleType.DURATION: DefaultConverter(DurationParamType(), primitive_type="duration"), - SimpleType.DATETIME: DefaultConverter(click.DateTime(), primitive_type="datetime"), - } - - def __init__( - self, - ctx: click.Context, - flyte_ctx: FlyteContext, - literal_type: LiteralType, - python_type: typing.Type, - get_upload_url_fn: typing.Callable, - ): - self._remote = ctx.obj[REMOTE_FLAG_KEY] - self._literal_type = literal_type - self._python_type = python_type - self._create_upload_fn = get_upload_url_fn - self._flyte_ctx = flyte_ctx - self._click_type = click.UNPROCESSED - - if self._literal_type.simple: - if self._literal_type.simple == SimpleType.STRUCT: - self._click_type = JsonParamType() - self._click_type.name = f"JSON object {self._python_type.__name__}" - elif self._literal_type.simple not in self.SIMPLE_TYPE_CONVERTER: - raise NotImplementedError(f"Type {self._literal_type.simple} is not supported in pyflyte run") - else: - self._converter = self.SIMPLE_TYPE_CONVERTER[self._literal_type.simple] - self._click_type = self._converter.click_type - - if self._literal_type.enum_type: - self._converter = self.SIMPLE_TYPE_CONVERTER[SimpleType.STRING] - self._click_type = click.Choice(self._literal_type.enum_type.values) - - if self._literal_type.structured_dataset_type: - self._click_type = DirParamType() - - if self._literal_type.collection_type or self._literal_type.map_value_type: - self._click_type = JsonParamType() - if self._literal_type.collection_type: - self._click_type.name = "json list" - else: - self._click_type.name = "json dictionary" - - if self._literal_type.blob: - if self._literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE: - if self._literal_type.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT: - self._click_type = PickleParamType() - else: - self._click_type = FileParamType() - else: - self._click_type = DirParamType() - - @property - def click_type(self) -> click.ParamType: - return self._click_type - - def is_bool(self) -> bool: - if self._literal_type.simple: - return self._literal_type.simple == SimpleType.BOOLEAN - return False - - def get_uri_for_dir( - self, ctx: typing.Optional[click.Context], value: Directory, remote_filename: typing.Optional[str] = None - ): - uri = value.dir_path - - if self._remote and value.local: - md5, _ = script_mode.hash_file(value.local_file) - if 
not remote_filename: - remote_filename = value.local_file.name - remote = ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] - _, native_url = remote.upload_file(value.local_file) - uri = native_url[: -len(remote_filename)] - - return uri - - def convert_to_structured_dataset( - self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: Directory - ) -> Literal: - - uri = self.get_uri_for_dir(ctx, value, "00000.parquet") - - lit = Literal( - scalar=Scalar( - structured_dataset=literals.StructuredDataset( - uri=uri, - metadata=literals.StructuredDatasetMetadata( - structured_dataset_type=self._literal_type.structured_dataset_type - ), - ), - ), - ) - - return lit - - def convert_to_blob( - self, - ctx: typing.Optional[click.Context], - param: typing.Optional[click.Parameter], - value: typing.Union[Directory, FileParam], - ) -> Literal: - if isinstance(value, Directory): - uri = self.get_uri_for_dir(ctx, value) - else: - uri = value.filepath - if self._remote and value.local: - fp = pathlib.Path(value.filepath) - remote = ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] - _, uri = remote.upload_file(fp) - - lit = Literal( - scalar=Scalar( - blob=Blob( - metadata=BlobMetadata(type=self._literal_type.blob), - uri=uri, - ), - ), - ) - - return lit - - def convert_to_union( - self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: typing.Any - ) -> Literal: - lt = self._literal_type - for i in range(len(self._literal_type.union_type.variants)): - variant = self._literal_type.union_type.variants[i] - python_type = get_args(self._python_type)[i] - converter = FlyteLiteralConverter( - ctx, - self._flyte_ctx, - variant, - python_type, - self._create_upload_fn, - ) - try: - # Here we use click converter to convert the input in command line to native python type, - # and then use flyte converter to convert it to literal. - python_val = converter._click_type.convert(value, param, ctx) - literal = converter.convert_to_literal(ctx, param, python_val) - return Literal(scalar=Scalar(union=Union(literal, variant))) - except (Exception or AttributeError) as e: - logging.debug(f"Failed to convert python type {python_type} to literal type {variant}", e) - raise ValueError(f"Failed to convert python type {self._python_type} to literal type {lt}") - - def convert_to_list( - self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: list - ) -> Literal: - """ - Convert a python list into a Flyte Literal - """ - if not value: - raise click.BadParameter("Expected non-empty list") - if not isinstance(value, list): - raise click.BadParameter(f"Expected json list '[...]', parsed value is {type(value)}") - converter = FlyteLiteralConverter( - ctx, - self._flyte_ctx, - self._literal_type.collection_type, - type(value[0]), - self._create_upload_fn, - ) - lt = Literal(collection=LiteralCollection([])) - for v in value: - click_val = converter._click_type.convert(v, param, ctx) - lt.collection.literals.append(converter.convert_to_literal(ctx, param, click_val)) - return lt - - def convert_to_map( - self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: dict - ) -> Literal: - """ - Convert a python dict into a Flyte Literal. - It is assumed that the click parameter type is a JsonParamType. The map is also assumed to be univariate. 
- """ - if not value: - raise click.BadParameter("Expected non-empty dict") - if not isinstance(value, dict): - raise click.BadParameter(f"Expected json dict '{{...}}', parsed value is {type(value)}") - converter = FlyteLiteralConverter( - ctx, - self._flyte_ctx, - self._literal_type.map_value_type, - type(value[list(value.keys())[0]]), - self._create_upload_fn, - ) - lt = Literal(map=LiteralMap({})) - for k, v in value.items(): - click_val = converter._click_type.convert(v, param, ctx) - lt.map.literals[k] = converter.convert_to_literal(ctx, param, click_val) - return lt - - def convert_to_struct( - self, - ctx: typing.Optional[click.Context], - param: typing.Optional[click.Parameter], - value: typing.Union[dict, typing.Any], - ) -> Literal: - """ - Convert the loaded json object to a Flyte Literal struct type. - """ - if type(value) != self._python_type: - o = cast(DataClassJsonMixin, self._python_type).from_json(json.dumps(value)) - else: - o = value - return TypeEngine.to_literal(self._flyte_ctx, o, self._python_type, self._literal_type) - - def convert_to_literal( - self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: typing.Any - ) -> Literal: - if self._literal_type.structured_dataset_type: - return self.convert_to_structured_dataset(ctx, param, value) - - if self._literal_type.blob: - return self.convert_to_blob(ctx, param, value) - - if self._literal_type.collection_type: - return self.convert_to_list(ctx, param, value) - - if self._literal_type.map_value_type: - return self.convert_to_map(ctx, param, value) - - if self._literal_type.union_type: - return self.convert_to_union(ctx, param, value) - - if self._literal_type.simple or self._literal_type.enum_type: - if self._literal_type.simple and self._literal_type.simple == SimpleType.STRUCT: - return self.convert_to_struct(ctx, param, value) - return Literal(scalar=self._converter.convert(value, self._python_type)) - - if self._literal_type.schema: - raise DeprecationWarning("Schema Types are not supported in pyflyte run. Use StructuredDataset instead.") - - raise NotImplementedError( - f"CLI parsing is not available for Python Type:`{self._python_type}`, LiteralType:`{self._literal_type}`." - ) - - def convert(self, ctx, param, value) -> typing.Union[Literal, typing.Any]: - try: - lit = self.convert_to_literal(ctx, param, value) - if not self._remote: - return TypeEngine.to_python_value(self._flyte_ctx, lit, self._python_type) - return lit - except click.BadParameter: - raise - except Exception as e: - raise click.BadParameter(f"Failed to convert param {param}, {value} to {self._python_type}") from e - - -def to_click_option( - ctx: click.Context, - flyte_ctx: FlyteContext, - input_name: str, - literal_var: Variable, - python_type: typing.Type, - default_val: typing.Any, - get_upload_url_fn: typing.Callable, -) -> click.Option: +class RunLevelComputedParams: """ - This handles converting workflow input types to supported click parameters with callbacks to initialize - the input values to their expected types. + This class is used to store the computed parameters that are used to run a workflow / task / launchplan. 
+ Computed parameters are created during the execution """ - literal_converter = FlyteLiteralConverter( - ctx, flyte_ctx, literal_type=literal_var.type, python_type=python_type, get_upload_url_fn=get_upload_url_fn - ) - if literal_converter.is_bool() and not default_val: - default_val = False + project_root: typing.Optional[str] = None + module: typing.Optional[str] = None + temp_file_name: typing.Optional[str] = None # Used to store the temporary location of the file downloaded - if literal_var.type.simple == SimpleType.STRUCT: - if default_val: - if type(default_val) == dict or type(default_val) == list: - default_val = json.dumps(default_val) - else: - default_val = cast(DataClassJsonMixin, default_val).to_json() - return click.Option( - param_decls=[f"--{input_name}"], - type=literal_converter.click_type, - is_flag=literal_converter.is_bool(), - default=default_val, - show_default=True, - required=default_val is None, - help=literal_var.description, - callback=literal_converter.convert, - ) - - -def set_is_remote(ctx: click.Context, param: str, value: str): - ctx.obj[REMOTE_FLAG_KEY] = bool(value) - - -def get_workflow_command_base_params() -> typing.List[click.Option]: +@dataclass +class RunLevelParams(PyFlyteParams): """ - Return the set of base parameters added to every pyflyte run workflow subcommand. + This class is used to store the parameters that are used to run a workflow / task / launchplan. """ - return [ - click.Option( - param_decls=["--remote"], - required=False, - is_flag=True, - default=False, - expose_value=False, # since we're handling in the callback, no need to expose this in params - is_eager=True, - callback=set_is_remote, - help="Whether to register and run the workflow on a Flyte deployment", - ), - click.Option( - param_decls=["-p", "--project"], - required=False, - type=str, - default="flytesnacks", - help="Project to register and run this workflow in", - ), - click.Option( - param_decls=["-d", "--domain"], - required=False, - type=str, - default="development", - help="Domain to register and run this workflow in", - ), - click.Option( - param_decls=["--name"], - required=False, - type=str, - help="Name to assign to this execution", - ), + + project: str = make_field(project_option) + domain: str = make_field(domain_option) + destination_dir: str = make_field( click.Option( param_decls=["--destination-dir", "destination_dir"], required=False, type=str, default="/root", + show_default=True, help="Directory inside the image where the tar file containing the code will be copied to", - ), + ) + ) + copy_all: bool = make_field( click.Option( param_decls=["--copy-all", "copy_all"], required=False, is_flag=True, default=False, + show_default=True, help="Copy all files in the source root directory to the destination directory", - ), + ) + ) + image_config: ImageConfig = make_field( click.Option( param_decls=["-i", "--image", "image_config"], required=False, @@ -546,43 +89,191 @@ def get_workflow_command_base_params() -> typing.List[click.Option]: type=click.UNPROCESSED, callback=ImageConfig.validate_image, default=[DefaultImages.default_image()], + show_default=True, help="Image used to register and run.", - ), + ) + ) + service_account: str = make_field( click.Option( param_decls=["--service-account", "service_account"], required=False, type=str, default="", help="Service account used when executing this workflow", - ), + ) + ) + wait_execution: bool = make_field( click.Option( param_decls=["--wait-execution", "wait_execution"], required=False, is_flag=True, 
default=False, + show_default=True, help="Whether to wait for the execution to finish", - ), + ) + ) + dump_snippet: bool = make_field( click.Option( param_decls=["--dump-snippet", "dump_snippet"], required=False, is_flag=True, default=False, + show_default=True, help="Whether to dump a code snippet instructing how to load the workflow execution using flyteremote", - ), + ) + ) + overwrite_cache: bool = make_field( click.Option( param_decls=["--overwrite-cache", "overwrite_cache"], required=False, is_flag=True, default=False, + show_default=True, help="Whether to overwrite the cache if it already exists", - ), + ) + ) + envvars: typing.Dict[str, str] = make_field( click.Option( - param_decls=["--envs", "envs"], + param_decls=["--envvars", "--env"], required=False, - type=JsonParamType(), - help="Environment variables to set in the container", - ), - ] + multiple=True, + type=str, + show_default=True, + callback=key_value_callback, + help="Environment variables to set in the container, of the format `ENV_NAME=ENV_VALUE`", + ) + ) + tags: typing.List[str] = make_field( + click.Option( + param_decls=["--tags", "--tag"], + required=False, + multiple=True, + type=str, + show_default=True, + help="Tags to set for the execution", + ) + ) + name: str = make_field( + click.Option( + param_decls=["--name"], + required=False, + type=str, + show_default=True, + help="Name to assign to this execution", + ) + ) + labels: typing.Dict[str, str] = make_field( + click.Option( + param_decls=["--labels", "--label"], + required=False, + multiple=True, + type=str, + show_default=True, + callback=key_value_callback, + help="Labels to be attached to the execution of the format `label_key=label_value`.", + ) + ) + annotations: typing.Dict[str, str] = make_field( + click.Option( + param_decls=["--annotations", "--annotation"], + required=False, + multiple=True, + type=str, + show_default=True, + callback=key_value_callback, + help="Annotations to be attached to the execution of the format `key=value`.", + ) + ) + raw_output_data_prefix: str = make_field( + click.Option( + param_decls=["--raw-output-data-prefix", "--raw-data-prefix"], + required=False, + type=str, + show_default=True, + help="File Path prefix to store raw output data." + " Examples are file://, s3://, gs:// etc as supported by fsspec." + " If not specified, raw data will be stored in default configured location in remote of locally" + " to temp file system." + + click.style( + "Note, this is not metadata, but only the raw data location " + "used to store Flytefile, Flytedirectory, Structuredataset," + " dataframes" + ), + ) + ) + max_parallelism: int = make_field( + click.Option( + param_decls=["--max-parallelism"], + required=False, + type=int, + show_default=True, + help="Number of nodes of a workflow that can be executed in parallel. If not specified," + " project/domain defaults are used. 
If 0 then it is unlimited.", + ) + ) + disable_notifications: bool = make_field( + click.Option( + param_decls=["--disable-notifications"], + required=False, + is_flag=True, + default=False, + show_default=True, + help="Whether to disable notifications for this execution.", + ) + ) + remote: bool = make_field( + click.Option( + param_decls=["--remote"], + required=False, + is_flag=True, + default=False, + is_eager=True, + show_default=True, + help="Whether to register and run the workflow on a Flyte deployment", + ) + ) + limit: int = make_field( + click.Option( + param_decls=["--limit", "limit"], + required=False, + type=int, + default=10, + show_default=True, + help="Use this to limit the number of launch plans retrieved from the backend, " + "if the `from-server` option is used", + ) + ) + cluster_pool: str = make_field( + click.Option( + param_decls=["--cluster-pool", "cluster_pool"], + required=False, + type=str, + default="", + help="Assign the newly created execution to a given cluster pool", + ) + ) + computed_params: RunLevelComputedParams = field(default_factory=RunLevelComputedParams) + _remote: typing.Optional[FlyteRemote] = None + + def remote_instance(self) -> FlyteRemote: + if self._remote is None: + self._remote = get_remote(self.config_file, self.project, self.domain) + return self._remote + + @property + def is_remote(self) -> bool: + return self.remote + + @classmethod + def from_dict(cls, d: typing.Dict[str, typing.Any]) -> "RunLevelParams": + return cls(**d) + + @classmethod + def options(cls) -> typing.List[click.Option]: + """ + Return the set of base parameters added to every pyflyte run workflow subcommand. + """ + return [get_option_from_metadata(f.metadata) for f in fields(cls) if f.metadata] def load_naive_entity(module_name: str, entity_name: str, project_root: str) -> typing.Union[WorkflowBase, PythonTask]: @@ -652,74 +343,296 @@ def get_entities_in_file(filename: pathlib.Path, should_delete: bool) -> Entitie return Entities(workflows, tasks) +def to_click_option( + ctx: click.Context, + flyte_ctx: FlyteContext, + input_name: str, + literal_var: Variable, + python_type: typing.Type, + default_val: typing.Any, + get_upload_url_fn: typing.Callable, + required: bool, +) -> click.Option: + """ + This handles converting workflow input types to supported click parameters with callbacks to initialize + the input values to their expected types. 
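For readers unfamiliar with the pattern used by RunLevelParams above, here is a minimal, self-contained sketch of how a click option can be carried in dataclass field metadata and later recovered to build the CLI parameter list. The helper and class names below are illustrative stand-ins for the make_field / get_option_from_metadata helpers added in this PR, not the actual implementation:

import typing
from dataclasses import dataclass, field, fields

import click


def make_option_field(o: click.Option):
    # Simplified stand-in for make_field: stash the click.Option in the field metadata.
    return field(default=o.default, metadata={"click.option": o})


@dataclass
class ExampleParams:
    project: str = make_option_field(click.Option(["-p", "--project"], default="flytesnacks"))
    domain: str = make_option_field(click.Option(["-d", "--domain"], default="development"))

    @classmethod
    def options(cls) -> typing.List[click.Option]:
        # Rebuild the CLI parameter list from the dataclass field metadata.
        return [f.metadata["click.option"] for f in fields(cls) if f.metadata]


print([o.name for o in ExampleParams.options()])  # ['project', 'domain']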
+ """ + run_level_params: RunLevelParams = ctx.obj + + literal_converter = FlyteLiteralConverter( + flyte_ctx, + literal_type=literal_var.type, + python_type=python_type, + get_upload_url_fn=get_upload_url_fn, + is_remote=run_level_params.is_remote, + remote_instance_accessor=run_level_params.remote_instance, + ) + + if literal_converter.is_bool() and not default_val: + default_val = False + + description_extra = "" + if literal_var.type.simple == SimpleType.STRUCT: + if default_val: + if type(default_val) == dict or type(default_val) == list: + default_val = json.dumps(default_val) + else: + default_val = cast(DataClassJsonMixin, default_val).to_json() + if literal_var.type.metadata: + description_extra = f": {json.dumps(literal_var.type.metadata)}" + + return click.Option( + param_decls=[f"--{input_name}"], + type=literal_converter.click_type, + is_flag=literal_converter.is_bool(), + default=default_val, + show_default=True, + required=required, + help=literal_var.description + description_extra, + callback=literal_converter.convert, + ) + + +def options_from_run_params(run_level_params: RunLevelParams) -> Options: + return Options( + labels=Labels(run_level_params.labels) if run_level_params.labels else None, + annotations=Annotations(run_level_params.annotations) if run_level_params.annotations else None, + raw_output_data_config=RawOutputDataConfig(output_location_prefix=run_level_params.raw_output_data_prefix) + if run_level_params.raw_output_data_prefix + else None, + max_parallelism=run_level_params.max_parallelism, + disable_notifications=run_level_params.disable_notifications, + security_context=security.SecurityContext( + run_as=security.Identity(k8s_service_account=run_level_params.service_account) + ) + if run_level_params.service_account + else None, + notifications=[], + ) + + +def run_remote( + remote: FlyteRemote, + entity: typing.Union[FlyteWorkflow, FlyteTask, FlyteLaunchPlan], + project: str, + domain: str, + inputs: typing.Dict[str, typing.Any], + run_level_params: RunLevelParams, + type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, +): + """ + Helper method that executes the given remote FlyteLaunchplan, FlyteWorkflow or FlyteTask + """ + + execution = remote.execute( + entity, + inputs=inputs, + project=project, + domain=domain, + name=run_level_params.name, + wait=run_level_params.wait_execution, + options=options_from_run_params(run_level_params), + type_hints=type_hints, + overwrite_cache=run_level_params.overwrite_cache, + envs=run_level_params.envvars, + tags=run_level_params.tags, + cluster_pool=run_level_params.cluster_pool, + ) + + console_url = remote.generate_console_url(execution) + s = ( + click.style("\n[✔] ", fg="green") + + "Go to " + + click.style(console_url, fg="cyan") + + " to see execution in the console." + ) + click.echo(s) + + if run_level_params.dump_snippet: + dump_flyte_remote_snippet(execution, project, domain) + + def run_command(ctx: click.Context, entity: typing.Union[PythonFunctionWorkflow, PythonTask]): """ Returns a function that is used to implement WorkflowCommand and execute a flyte workflow. """ def _run(*args, **kwargs): + """ + Click command function that is used to execute a flyte workflow from the given entity in the file. 
+ """ # By the time we get to this function, all the loading has already happened - run_level_params = ctx.obj[RUN_LEVEL_PARAMS_KEY] - project, domain = run_level_params.get("project"), run_level_params.get("domain") - inputs = {} - for input_name, _ in entity.python_interface.inputs.items(): - inputs[input_name] = kwargs.get(input_name) - - if not ctx.obj[REMOTE_FLAG_KEY]: - output = entity(**inputs) - click.echo(output) - if ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_FILE_NAME): - os.remove(ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_FILE_NAME)) - return - - remote = ctx.obj[FLYTE_REMOTE_INSTANCE_KEY] - config_file = ctx.obj.get(CTX_CONFIG_FILE) - - image_config = run_level_params.get("image_config") - image_config = patch_image_config(config_file, image_config) - - remote_entity = remote.register_script( - entity, - project=project, - domain=domain, - image_config=image_config, - destination_dir=run_level_params.get("destination_dir"), - source_path=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT_ROOT), - module_name=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_MODULE), - copy_all=ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_COPY_ALL), - ) + run_level_params: RunLevelParams = ctx.obj + if run_level_params.verbose: + click.echo(f"Running {entity.name} with {kwargs} and run_level_params {run_level_params}") + + click.secho(f"Running Execution on {'Remote' if run_level_params.is_remote else 'local'}.", fg="cyan") + try: + inputs = {} + for input_name, _ in entity.python_interface.inputs.items(): + inputs[input_name] = kwargs.get(input_name) + + if not run_level_params.is_remote: + output = entity(**inputs) + if inspect.iscoroutine(output): + # TODO: make eager mode workflows run with local-mode + output = asyncio.run(output) + click.echo(output) + return + + remote = run_level_params.remote_instance() + config_file = run_level_params.config_file + + image_config = run_level_params.image_config + image_config = patch_image_config(config_file, image_config) + + remote_entity = remote.register_script( + entity, + project=run_level_params.project, + domain=run_level_params.domain, + image_config=image_config, + destination_dir=run_level_params.destination_dir, + source_path=run_level_params.computed_params.project_root, + module_name=run_level_params.computed_params.module, + copy_all=run_level_params.copy_all, + ) + + run_remote( + remote, + remote_entity, + run_level_params.project, + run_level_params.domain, + inputs, + run_level_params, + type_hints=entity.python_interface.inputs, + ) + finally: + if run_level_params.computed_params.temp_file_name: + os.remove(run_level_params.computed_params.temp_file_name) - options = None - service_account = run_level_params.get("service_account") - if service_account: - # options are only passed for the execution. This is to prevent errors when registering a duplicate workflow - # It is assumed that the users expectations is to override the service account only for the execution - options = Options.default_from(k8s_service_account=service_account) - - execution = remote.execute( - remote_entity, - inputs=inputs, - project=project, - domain=domain, - name=run_level_params.get("name"), - wait=run_level_params.get("wait_execution"), - options=options, - type_hints=entity.python_interface.inputs, - overwrite_cache=run_level_params.get("overwrite_cache"), - envs=run_level_params.get("envs"), + return _run + + +class DynamicLaunchPlanCommand(click.RichCommand): + """ + This is a dynamic command that is created for each launch plan. This is used to execute a launch plan. 
+ It will fetch the launch plan from remote and create parameters from all the inputs of the launch plan. + """ + + def __init__(self, name: str, h: str, lp_name: str, **kwargs): + super().__init__(name=name, help=h, **kwargs) + self._lp_name = lp_name + self._lp = None + + def _fetch_launch_plan(self, ctx: click.Context) -> FlyteLaunchPlan: + if self._lp: + return self._lp + run_level_params: RunLevelParams = ctx.obj + r = run_level_params.remote_instance() + self._lp = r.fetch_launch_plan(run_level_params.project, run_level_params.domain, self._lp_name) + return self._lp + + def _get_params( + self, + ctx: click.Context, + inputs: typing.Dict[str, Variable], + native_inputs: typing.Dict[str, type], + fixed: typing.Dict[str, Literal], + defaults: typing.Dict[str, Parameter], + ) -> typing.List["click.Parameter"]: + params = [] + run_level_params: RunLevelParams = ctx.obj + r = run_level_params.remote_instance() + + get_upload_url_fn = functools.partial( + r.client.get_upload_signed_url, project=run_level_params.project, domain=run_level_params.domain + ) + flyte_ctx = context_manager.FlyteContextManager.current_context() + for name, var in inputs.items(): + if fixed and name in fixed: + continue + required = True + if defaults and name in defaults: + required = False + params.append( + to_click_option(ctx, flyte_ctx, name, var, native_inputs[name], None, get_upload_url_fn, required) + ) + return params + + def get_params(self, ctx: click.Context) -> typing.List["click.Parameter"]: + if not self.params: + self.params = [] + lp = self._fetch_launch_plan(ctx) + if lp.interface: + if lp.interface.inputs: + types = TypeEngine.guess_python_types(lp.interface.inputs) + self.params = self._get_params( + ctx, lp.interface.inputs, types, lp.fixed_inputs.literals, lp.default_inputs.parameters + ) + + return super().get_params(ctx) + + def invoke(self, ctx: click.Context) -> typing.Any: + """ + Default or None values should be ignored. Only values that are provided by the user should be passed to the + remote execution. + """ + run_level_params: RunLevelParams = ctx.obj + r = run_level_params.remote_instance() + lp = self._fetch_launch_plan(ctx) + run_remote( + r, + lp, + run_level_params.project, + run_level_params.domain, + ctx.params, + run_level_params, + type_hints=lp.python_interface.inputs if lp.python_interface else None, ) - console_url = remote.generate_console_url(execution) - click.secho(f"Go to {console_url} to see execution in the console.") - if run_level_params.get("dump_snippet"): - dump_flyte_remote_snippet(execution, project, domain) +class RemoteLaunchPlanGroup(click.RichGroup): + """ + click multicommand that retrieves launchplans from a remote flyte instance and executes them. 
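As a rough illustration of the pattern that DynamicLaunchPlanCommand and RemoteLaunchPlanGroup follow, the sketch below shows a click group whose sub-commands are discovered lazily. The launch plan names and the echo callback are made up; the real group pages names from the Flyte backend and builds a DynamicLaunchPlanCommand per name:

import click


class LazyLaunchPlanGroup(click.Group):
    # Illustrative only: the real RemoteLaunchPlanGroup lists launch plan names from the
    # Flyte backend and returns a DynamicLaunchPlanCommand for each one.
    def list_commands(self, ctx):
        return ["example_lp_a", "example_lp_b"]

    def get_command(self, ctx, name):
        # Build one command object on demand for whichever name the user invoked.
        return click.Command(name=name, callback=lambda: click.echo(f"would execute {name}"))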
+ """ - if ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_FILE_NAME): - os.remove(ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_FILE_NAME)) + COMMAND_NAME = "remote-launchplan" - return _run + def __init__(self): + super().__init__( + name="from-server", + help="Retrieve launchplans from a remote flyte instance and execute them.", + params=[ + click.Option( + ["--limit"], help="Limit the number of launchplans to retrieve.", default=10, show_default=True + ) + ], + ) + self._lps = [] + + def list_commands(self, ctx): + if self._lps or ctx.obj is None: + return self._lps + + run_level_params: RunLevelParams = ctx.obj + r = run_level_params.remote_instance() + progress = Progress(transient=True) + task = progress.add_task(f"[cyan]Gathering [{run_level_params.limit}] remote LaunchPlans...", total=None) + with progress: + progress.start_task(task) + try: + lps = r.client.list_launch_plan_ids_paginated( + project=run_level_params.project, domain=run_level_params.domain, limit=run_level_params.limit + ) + self._lps = [l.name for l in lps[0]] + return self._lps + except FlyteSystemException as e: + pretty_print_exception(e) + return [] + + def get_command(self, ctx, name): + return DynamicLaunchPlanCommand(name=name, h="Execute a launchplan from remote.", lp_name=name) class WorkflowCommand(click.RichGroup): @@ -739,11 +652,59 @@ def __init__(self, filename: str, *args, **kwargs): else: self._filename = pathlib.Path(filename).resolve() self._should_delete = False + self._entities = None def list_commands(self, ctx): + if self._entities: + return self._entities.all() entities = get_entities_in_file(self._filename, self._should_delete) + self._entities = entities return entities.all() + def _create_command( + self, + ctx: click.Context, + entity_name: str, + run_level_params: RunLevelParams, + loaded_entity: typing.Any, + is_workflow: bool, + ): + """ + Delegate that creates the command for a given entity. + """ + + # If this is a remote execution, which we should know at this point, then create the remote object + r = run_level_params.remote_instance() + get_upload_url_fn = functools.partial( + r.client.get_upload_signed_url, project=run_level_params.project, domain=run_level_params.domain + ) + + flyte_ctx = context_manager.FlyteContextManager.current_context() + + # Add options for each of the workflow inputs + params = [] + for input_name, input_type_val in loaded_entity.python_interface.inputs_with_defaults.items(): + literal_var = loaded_entity.interface.inputs.get(input_name) + python_type, default_val = input_type_val + required = type(None) not in get_args(python_type) and default_val is None + params.append( + to_click_option( + ctx, flyte_ctx, input_name, literal_var, python_type, default_val, get_upload_url_fn, required + ) + ) + + entity_type = "Workflow" if is_workflow else "Task" + h = f"{click.style(entity_type, bold=True)} ({run_level_params.computed_params.module}.{entity_name})" + if loaded_entity.__doc__: + h = h + click.style(f"{loaded_entity.__doc__}", dim=True) + cmd = click.RichCommand( + name=entity_name, + params=params, + callback=run_command(ctx, loaded_entity), + help=h, + ) + return cmd + def get_command(self, ctx, exe_entity): """ This command uses the filename with which this command was created, and the string name of the entity passed @@ -754,7 +715,9 @@ def get_command(self, ctx, exe_entity): function. 
:return: """ - + is_workflow = False + if self._entities: + is_workflow = exe_entity in self._entities.workflows rel_path = os.path.relpath(self._filename) if rel_path.startswith(".."): raise ValueError( @@ -769,35 +732,18 @@ def get_command(self, ctx, exe_entity): rel_path = self._filename.relative_to(project_root) module = os.path.splitext(rel_path)[0].replace(os.path.sep, ".") - ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_PROJECT_ROOT] = project_root - ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_MODULE] = module - if self._should_delete: - ctx.obj[RUN_LEVEL_PARAMS_KEY][CTX_FILE_NAME] = self._filename - entity = load_naive_entity(module, exe_entity, project_root) + run_level_params: RunLevelParams = ctx.obj - # If this is a remote execution, which we should know at this point, then create the remote object - p = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_PROJECT) - d = ctx.obj[RUN_LEVEL_PARAMS_KEY].get(CTX_DOMAIN) - r = get_and_save_remote_with_click_context(ctx, p, d) - get_upload_url_fn = functools.partial(r.client.get_upload_signed_url, project=p, domain=d) + # update computed params + run_level_params.computed_params.project_root = project_root + run_level_params.computed_params.module = module - flyte_ctx = context_manager.FlyteContextManager.current_context() + if self._should_delete: + run_level_params.computed_params.temp_file_name = self._filename - # Add options for each of the workflow inputs - params = [] - for input_name, input_type_val in entity.python_interface.inputs_with_defaults.items(): - literal_var = entity.interface.inputs.get(input_name) - python_type, default_val = input_type_val - params.append( - to_click_option(ctx, flyte_ctx, input_name, literal_var, python_type, default_val, get_upload_url_fn) - ) - cmd = click.Command( - name=exe_entity, - params=params, - callback=run_command(ctx, entity), - help=f"Run {module}.{exe_entity} in script mode", - ) - return cmd + entity = load_naive_entity(module, exe_entity, project_root) + + return self._create_command(ctx, exe_entity, run_level_params, entity, is_workflow) class RunCommand(click.RichGroup): @@ -806,16 +752,32 @@ class RunCommand(click.RichGroup): """ def __init__(self, *args, **kwargs): - params = get_workflow_command_base_params() - super().__init__(*args, params=params, **kwargs) + if "params" not in kwargs: + params = RunLevelParams.options() + kwargs["params"] = params + super().__init__(*args, **kwargs) + self._files = [] - def list_commands(self, ctx): - return [str(p) for p in pathlib.Path(".").glob("*.py") if str(p) != "__init__.py"] + def list_commands(self, ctx, add_remote: bool = True): + if self._files: + return self._files + self._files = [str(p) for p in pathlib.Path(".").glob("*.py") if str(p) != "__init__.py"] + self._files = sorted(self._files) + if add_remote: + self._files = self._files + [RemoteLaunchPlanGroup.COMMAND_NAME] + return self._files def get_command(self, ctx, filename): - if ctx.obj: - ctx.obj[RUN_LEVEL_PARAMS_KEY] = ctx.params - return WorkflowCommand(filename, name=filename, help="Run a [workflow|task] in a file using script mode") + if ctx.obj is None: + ctx.obj = {} + if not isinstance(ctx.obj, RunLevelParams): + params = {} + params.update(ctx.params) + params.update(ctx.obj) + ctx.obj = RunLevelParams.from_dict(params) + if filename == RemoteLaunchPlanGroup.COMMAND_NAME: + return RemoteLaunchPlanGroup() + return WorkflowCommand(filename, name=filename, help=f"Run a [workflow|task] from {filename}") _run_help = """ diff --git a/flytekit/clis/sdk_in_container/serve.py 
b/flytekit/clis/sdk_in_container/serve.py index c95754e6c6..145dc90212 100644 --- a/flytekit/clis/sdk_in_container/serve.py +++ b/flytekit/clis/sdk_in_container/serve.py @@ -1,10 +1,8 @@ from concurrent import futures import click -import grpc from flyteidl.service.agent_pb2_grpc import add_AsyncAgentServiceServicer_to_server - -from flytekit.extend.backend.agent_service import AgentService +from grpc import aio _serve_help = """Start a grpc server for the agent service.""" @@ -37,10 +35,26 @@ def serve(_: click.Context, port, worker, timeout): """ Start a grpc server for the agent service. """ + import asyncio + + asyncio.run(_start_grpc_server(port, worker, timeout)) + + +async def _start_grpc_server(port: int, worker: int, timeout: int): + click.secho("Starting up the server to expose the prometheus metrics...", fg="blue") + from flytekit.extend.backend.agent_service import AsyncAgentService + + try: + from prometheus_client import start_http_server + + start_http_server(9090) + except ImportError as e: + click.secho(f"Failed to start the prometheus server with error {e}", fg="red") click.secho("Starting the agent service...", fg="blue") - server = grpc.server(futures.ThreadPoolExecutor(max_workers=worker)) - add_AsyncAgentServiceServicer_to_server(AgentService(), server) + server = aio.server(futures.ThreadPoolExecutor(max_workers=worker)) + + add_AsyncAgentServiceServicer_to_server(AsyncAgentService(), server) server.add_insecure_port(f"[::]:{port}") - server.start() - server.wait_for_termination(timeout=timeout) + await server.start() + await server.wait_for_termination(timeout) diff --git a/flytekit/clis/sdk_in_container/utils.py b/flytekit/clis/sdk_in_container/utils.py index e69de29bb2..3f975913e0 100644 --- a/flytekit/clis/sdk_in_container/utils.py +++ b/flytekit/clis/sdk_in_container/utils.py @@ -0,0 +1,152 @@ +import os +import typing +from dataclasses import Field, dataclass, field +from types import MappingProxyType + +import grpc +import rich_click as click +from google.protobuf.json_format import MessageToJson + +from flytekit.clis.sdk_in_container.constants import CTX_VERBOSE +from flytekit.exceptions.base import FlyteException +from flytekit.exceptions.user import FlyteInvalidInputException +from flytekit.loggers import cli_logger + +project_option = click.Option( + param_decls=["-p", "--project"], + required=False, + type=str, + default=os.getenv("FLYTE_DEFAULT_PROJECT", "flytesnacks"), + show_default=True, + help="Project to register and run this workflow in. Can also be set through envvar " "``FLYTE_DEFAULT_PROJECT``", +) + +domain_option = click.Option( + param_decls=["-d", "--domain"], + required=False, + type=str, + default=os.getenv("FLYTE_DEFAULT_DOMAIN", "development"), + show_default=True, + help="Domain to register and run this workflow in, can also be set through envvar " "``FLYTE_DEFAULT_DOMAIN``", +) + +project_option_dec = click.option( + "-p", + "--project", + required=False, + type=str, + default=os.getenv("FLYTE_DEFAULT_PROJECT", "flytesnacks"), + show_default=True, + help="Project for workflow/launchplan. Can also be set through envvar " "``FLYTE_DEFAULT_PROJECT``", +) + +domain_option_dec = click.option( + "-d", + "--domain", + required=False, + type=str, + default=os.getenv("FLYTE_DEFAULT_DOMAIN", "development"), + show_default=True, + help="Domain for workflow/launchplan, can also be set through envvar " "``FLYTE_DEFAULT_DOMAIN``", +) + + +def validate_package(ctx, param, values): + """ + This method will validate the packages passed in by the user. 
It will check that the packages are in the correct + format, and will also split the packages if the user passed in a comma separated list. + """ + pkgs = [] + for val in values: + if "/" in val or "-" in val or "\\" in val: + raise click.BadParameter( + f"Illegal package value {val} for parameter: {param}. Expected for the form [a.b.c]" + ) + elif "," in val: + pkgs.extend(val.split(",")) + else: + pkgs.append(val) + cli_logger.debug(f"Using packages: {pkgs}") + return pkgs + + +def pretty_print_grpc_error(e: grpc.RpcError): + """ + This method will print the grpc error that us more human readable. + """ + if isinstance(e, grpc._channel._InactiveRpcError): # noqa + click.secho(f"RPC Failed, with Status: {e.code()}", fg="red", bold=True) + click.secho(f"\tdetails: {e.details()}", fg="magenta", bold=True) + click.secho(f"\tDebug string {e.debug_error_string()}", dim=True) + return + + +def pretty_print_exception(e: Exception): + """ + This method will print the exception in a nice way. It will also check if the exception is a grpc.RpcError and + print it in a human-readable way. + """ + if isinstance(e, click.exceptions.Exit): + raise e + + if isinstance(e, click.ClickException): + click.secho(e.message, fg="red") + raise e + + if isinstance(e, FlyteException): + click.secho(f"Failed with Exception Code: {e._ERROR_CODE}", fg="red") # noqa + if isinstance(e, FlyteInvalidInputException): + click.secho("Request rejected by the API, due to Invalid input.", fg="red") + click.secho(f"\tInput Request: {MessageToJson(e.request)}", dim=True) + + cause = e.__cause__ + if cause: + if isinstance(cause, grpc.RpcError): + pretty_print_grpc_error(cause) + else: + click.secho(f"Underlying Exception: {cause}") + return + + if isinstance(e, grpc.RpcError): + pretty_print_grpc_error(e) + return + + click.secho(f"Failed with Unknown Exception {type(e)} Reason: {e}", fg="red") # noqa + + +class ErrorHandlingCommand(click.RichGroup): + """ + Helper class that wraps the invoke method of a click command to catch exceptions and print them in a nice way. 
+ """ + + def invoke(self, ctx: click.Context) -> typing.Any: + try: + return super().invoke(ctx) + except Exception as e: + if CTX_VERBOSE in ctx.obj and ctx.obj[CTX_VERBOSE]: + click.secho("Verbose mode on") + raise e + pretty_print_exception(e) + raise SystemExit(e) from e + + +def make_field(o: click.Option) -> Field: + if o.multiple: + o.help = click.style("Multiple values allowed.", bold=True) + f"{o.help}" + return field(default_factory=lambda: o.default, metadata={"click.option": o}) + return field(default=o.default, metadata={"click.option": o}) + + +def get_option_from_metadata(metadata: MappingProxyType) -> click.Option: + return metadata["click.option"] + + +@dataclass +class PyFlyteParams: + config_file: typing.Optional[str] = None + verbose: bool = False + pkgs: typing.List[str] = field(default_factory=list) + + @classmethod + def from_dict(cls, d: typing.Dict[str, typing.Any]) -> "PyFlyteParams": + return cls(**d) diff --git a/flytekit/configuration/__init__.py b/flytekit/configuration/__init__.py index 5131e0378a..f842f1451b 100644 --- a/flytekit/configuration/__init__.py +++ b/flytekit/configuration/__init__.py @@ -144,7 +144,7 @@ from typing import Dict, List, Optional import yaml -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin from flytekit.configuration import internal as _internal from flytekit.configuration.default_images import DefaultImages @@ -164,9 +164,8 @@ SERIALIZED_CONTEXT_ENV_VAR = "_F_SS_C" -@dataclass_json @dataclass(init=True, repr=True, eq=True, frozen=True) -class Image(object): +class Image(DataClassJsonMixin): """ Image is a structured wrapper for task container images used in object serialization. @@ -224,15 +223,17 @@ def look_up_image_info(name: str, tag: str, optional_tag: bool = False) -> Image return Image(name=name, fqn=ref["name"], tag=ref["tag"]) -@dataclass_json @dataclass(init=True, repr=True, eq=True, frozen=True) -class ImageConfig(object): +class ImageConfig(DataClassJsonMixin): """ + We recommend you to use ImageConfig.auto(img_name=None) to create an ImageConfig. + For example, ImageConfig.auto(img_name=""ghcr.io/flyteorg/flytecookbook:v1.0.0"") will create an ImageConfig. + ImageConfig holds available images which can be used at registration time. A default image can be specified along with optional additional images. Each image in the config must have a unique name. Attributes: - default_image (str): The default image to be used as a container for task serialization. + default_image (Optional[Image]): The default image to be used as a container for task serialization. images (List[Image]): Optional, additional images which can be used in task container definitions. """ @@ -372,6 +373,7 @@ class PlatformConfig(object): :param insecure_skip_verify: Whether to skip SSL certificate verification :param console_endpoint: endpoint for console if different from Flyte backend :param command: This command is executed to return a token using an external process + :param proxy_command: This command is executed to return a token for proxy authorization using an external process :param client_id: This is the public identifier for the app which handles authorization for a Flyte deployment. More details here: https://www.oauth.com/oauth2-servers/client-registration/client-id-secret/. :param client_credentials_secret: Used for service auth, which is automatically called during pyflyte. 
This will @@ -389,6 +391,7 @@ class PlatformConfig(object): ca_cert_file_path: typing.Optional[str] = None console_endpoint: typing.Optional[str] = None command: typing.Optional[typing.List[str]] = None + proxy_command: typing.Optional[typing.List[str]] = None client_id: typing.Optional[str] = None client_credentials_secret: typing.Optional[str] = None scopes: List[str] = field(default_factory=list) @@ -412,24 +415,37 @@ def auto(cls, config_file: typing.Optional[typing.Union[str, ConfigFile]] = None ) kwargs = set_if_exists(kwargs, "ca_cert_file_path", _internal.Platform.CA_CERT_FILE_PATH.read(config_file)) kwargs = set_if_exists(kwargs, "command", _internal.Credentials.COMMAND.read(config_file)) + kwargs = set_if_exists(kwargs, "proxy_command", _internal.Credentials.PROXY_COMMAND.read(config_file)) kwargs = set_if_exists(kwargs, "client_id", _internal.Credentials.CLIENT_ID.read(config_file)) kwargs = set_if_exists( kwargs, "client_credentials_secret", _internal.Credentials.CLIENT_CREDENTIALS_SECRET.read(config_file) ) + is_client_secret = False client_credentials_secret = read_file_if_exists( _internal.Credentials.CLIENT_CREDENTIALS_SECRET_LOCATION.read(config_file) ) - if client_credentials_secret and client_credentials_secret.endswith("\n"): - logger.info("Newline stripped from client secret") - client_credentials_secret = client_credentials_secret.strip() + if client_credentials_secret: + is_client_secret = True + if client_credentials_secret.endswith("\n"): + logger.info("Newline stripped from client secret") + client_credentials_secret = client_credentials_secret.strip() kwargs = set_if_exists( kwargs, "client_credentials_secret", client_credentials_secret, ) + + client_credentials_secret_env_var = _internal.Credentials.CLIENT_CREDENTIALS_SECRET_ENV_VAR.read(config_file) + if client_credentials_secret_env_var: + client_credentials_secret = os.getenv(client_credentials_secret_env_var) + if client_credentials_secret: + is_client_secret = True + kwargs = set_if_exists(kwargs, "client_credentials_secret", client_credentials_secret) kwargs = set_if_exists(kwargs, "scopes", _internal.Credentials.SCOPES.read(config_file)) kwargs = set_if_exists(kwargs, "auth_mode", _internal.Credentials.AUTH_MODE.read(config_file)) + if is_client_secret: + kwargs = set_if_exists(kwargs, "auth_mode", AuthType.CLIENTSECRET.value) kwargs = set_if_exists(kwargs, "endpoint", _internal.Platform.URL.read(config_file)) kwargs = set_if_exists(kwargs, "console_endpoint", _internal.Platform.CONSOLE_ENDPOINT.read(config_file)) @@ -549,6 +565,30 @@ def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig: return GCSConfig(**kwargs) +@dataclass(init=True, repr=True, eq=True, frozen=True) +class AzureBlobStorageConfig(object): + """ + Any Azure Blob Storage specific configuration. 
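Assuming the new Azure configuration lands as shown below, a consumer would pick these values up through DataConfig. A short sketch (the attribute values depend on your config file or on FLYTE_AZURE_* environment variables; with nothing configured they simply come back as None):

from flytekit.configuration import DataConfig

# DataConfig.auto() now also hydrates the azure block alongside the existing s3/gcs sections.
data_cfg = DataConfig.auto()
print(data_cfg.azure.account_name, data_cfg.azure.tenant_id)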
+ """ + + account_name: typing.Optional[str] = None + account_key: typing.Optional[str] = None + tenant_id: typing.Optional[str] = None + client_id: typing.Optional[str] = None + client_secret: typing.Optional[str] = None + + @classmethod + def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> GCSConfig: + config_file = get_config_file(config_file) + kwargs = {} + kwargs = set_if_exists(kwargs, "account_name", _internal.AZURE.STORAGE_ACCOUNT_NAME.read(config_file)) + kwargs = set_if_exists(kwargs, "account_key", _internal.AZURE.STORAGE_ACCOUNT_KEY.read(config_file)) + kwargs = set_if_exists(kwargs, "tenant_id", _internal.AZURE.TENANT_ID.read(config_file)) + kwargs = set_if_exists(kwargs, "client_id", _internal.AZURE.CLIENT_ID.read(config_file)) + kwargs = set_if_exists(kwargs, "client_secret", _internal.AZURE.CLIENT_SECRET.read(config_file)) + return AzureBlobStorageConfig(**kwargs) + + @dataclass(init=True, repr=True, eq=True, frozen=True) class DataConfig(object): """ @@ -559,11 +599,13 @@ class DataConfig(object): s3: S3Config = S3Config() gcs: GCSConfig = GCSConfig() + azure: AzureBlobStorageConfig = AzureBlobStorageConfig() @classmethod def auto(cls, config_file: typing.Union[str, ConfigFile] = None) -> DataConfig: config_file = get_config_file(config_file) return DataConfig( + azure=AzureBlobStorageConfig.auto(config_file), s3=S3Config.auto(config_file), gcs=GCSConfig.auto(config_file), ) @@ -668,9 +710,8 @@ def for_endpoint( return c.with_params(platform=PlatformConfig.for_endpoint(endpoint, insecure), data_config=data_config) -@dataclass_json @dataclass -class EntrypointSettings(object): +class EntrypointSettings(DataClassJsonMixin): """ This object carries information about the path of the entrypoint command that will be invoked at runtime. This is where `pyflyte-execute` code can be found. This is useful for cases like pyspark execution. @@ -679,9 +720,8 @@ class EntrypointSettings(object): path: Optional[str] = None -@dataclass_json @dataclass -class FastSerializationSettings(object): +class FastSerializationSettings(DataClassJsonMixin): """ This object hold information about settings necessary to serialize an object so that it can be fast-registered. """ @@ -695,9 +735,8 @@ class FastSerializationSettings(object): # TODO: ImageConfig, python_interpreter, venv_root, fast_serialization_settings.destination_dir should be combined. -@dataclass_json -@dataclass() -class SerializationSettings(object): +@dataclass +class SerializationSettings(DataClassJsonMixin): """ These settings are provided while serializing a workflow and task, before registration. This is required to get runtime information at serialization time, as well as some defaults. 
diff --git a/flytekit/configuration/file.py b/flytekit/configuration/file.py index db4774e626..7fafc348f4 100644 --- a/flytekit/configuration/file.py +++ b/flytekit/configuration/file.py @@ -41,7 +41,7 @@ class LegacyConfigEntry(object): option: str type_: typing.Type = str - def get_env_name(self): + def get_env_name(self) -> str: return f"FLYTE_{self.section.upper()}_{self.option.upper()}" def read_from_env(self, transform: typing.Optional[typing.Callable] = None) -> typing.Optional[typing.Any]: @@ -97,7 +97,7 @@ def read_from_file( return None -def bool_transformer(config_val: typing.Any): +def bool_transformer(config_val: typing.Any) -> bool: if type(config_val) is str: return True if config_val and not config_val.lower() in ["false", "0", "off", "no"] else False else: diff --git a/flytekit/configuration/internal.py b/flytekit/configuration/internal.py index 4f993b4e11..b12103a3fd 100644 --- a/flytekit/configuration/internal.py +++ b/flytekit/configuration/internal.py @@ -57,6 +57,15 @@ class GCP(object): GSUTIL_PARALLELISM = ConfigEntry(LegacyConfigEntry(SECTION, "gsutil_parallelism", bool)) +class AZURE(object): + SECTION = "azure" + STORAGE_ACCOUNT_NAME = ConfigEntry(LegacyConfigEntry(SECTION, "storage_account_name")) + STORAGE_ACCOUNT_KEY = ConfigEntry(LegacyConfigEntry(SECTION, "storage_account_key")) + TENANT_ID = ConfigEntry(LegacyConfigEntry(SECTION, "tenant_id")) + CLIENT_ID = ConfigEntry(LegacyConfigEntry(SECTION, "client_id")) + CLIENT_SECRET = ConfigEntry(LegacyConfigEntry(SECTION, "client_secret")) + + class Credentials(object): SECTION = "credentials" COMMAND = ConfigEntry(LegacyConfigEntry(SECTION, "command", list), YamlConfigEntry("admin.command", list)) @@ -64,6 +73,13 @@ class Credentials(object): This command is executed to return a token using an external process. """ + PROXY_COMMAND = ConfigEntry( + LegacyConfigEntry(SECTION, "proxy_command", list), YamlConfigEntry("admin.proxyCommand", list) + ) + """ + This command is executed to return a token for authorization with a proxy in front of Flyte using an external process. + """ + CLIENT_ID = ConfigEntry(LegacyConfigEntry(SECTION, "client_id"), YamlConfigEntry("admin.clientId")) """ This is the public identifier for the app which handles authorization for a Flyte deployment. @@ -85,6 +101,14 @@ class Credentials(object): password from a mounted file. """ + CLIENT_CREDENTIALS_SECRET_ENV_VAR = ConfigEntry( + LegacyConfigEntry(SECTION, "client_secret_env_var"), YamlConfigEntry("admin.clientSecretEnvVar") + ) + """ + Used for basic auth, which is automatically called during pyflyte. This will allow the Flyte engine to read the + password from a mounted environment variable. 
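A hedged sketch of how the new client_secret_env_var / admin.clientSecretEnvVar setting is intended to be consumed, mirroring the PlatformConfig.auto change above; the environment variable name and secret value below are invented for illustration:

import os

# The operator exports the secret under a variable of their choosing (name invented here) ...
os.environ["MY_FLYTE_CLIENT_SECRET"] = "s3cr3t"

# ... and the config only names that variable, so the secret itself never lands on disk.
client_secret_env_var = "MY_FLYTE_CLIENT_SECRET"  # value of admin.clientSecretEnvVar
client_credentials_secret = os.getenv(client_secret_env_var)
print(client_credentials_secret is not None)  # True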
+ """ + SCOPES = ConfigEntry(LegacyConfigEntry(SECTION, "scopes", list), YamlConfigEntry("admin.scopes", list)) AUTH_MODE = ConfigEntry(LegacyConfigEntry(SECTION, "auth_mode"), YamlConfigEntry("admin.authType")) diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py new file mode 100644 index 0000000000..aafbaea3ad --- /dev/null +++ b/flytekit/core/array_node_map_task.py @@ -0,0 +1,376 @@ +# TODO: has to support the SupportsNodeCreation protocol +import functools +import hashlib +import logging +import os # TODO: use flytekit logger +from contextlib import contextmanager +from typing import Dict, List, Optional, Set, Union, cast + +from typing_extensions import Any + +from flytekit.configuration import SerializationSettings +from flytekit.core import tracker +from flytekit.core.base_task import PythonTask, TaskResolverMixin +from flytekit.core.constants import SdkTaskType +from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager +from flytekit.core.interface import transform_interface_to_list_interface +from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask +from flytekit.core.utils import timeit +from flytekit.exceptions import scopes as exception_scopes +from flytekit.models.array_job import ArrayJob +from flytekit.models.core.workflow import NodeMetadata +from flytekit.models.interface import Variable +from flytekit.models.task import Container, K8sPod, Sql, Task +from flytekit.tools.module_loader import load_object_from_module + + +class ArrayNodeMapTask(PythonTask): + def __init__( + self, + # TODO: add support for other Flyte entities + python_function_task: Union[PythonFunctionTask, PythonInstanceTask, functools.partial], + concurrency: Optional[int] = None, + min_successes: Optional[int] = None, + min_success_ratio: Optional[float] = None, + bound_inputs: Optional[Set[str]] = None, + **kwargs, + ): + """ + :param python_function_task: The task to be executed in parallel + :param concurrency: The number of parallel executions to run + :param min_successes: The minimum number of successful executions + :param min_success_ratio: The minimum ratio of successful executions + :param bound_inputs: The set of inputs that should be bound to the map task + :param kwargs: Additional keyword arguments to pass to the base class + """ + self._partial = None + if isinstance(python_function_task, functools.partial): + # TODO: We should be able to support partial tasks with lists as inputs + for arg in python_function_task.keywords.values(): + if isinstance(arg, list): + raise ValueError("Cannot use a partial task with lists as inputs") + self._partial = python_function_task + actual_task = self._partial.func + else: + actual_task = python_function_task + + # TODO: add support for other Flyte entities + if not (isinstance(actual_task, PythonFunctionTask) or isinstance(actual_task, PythonInstanceTask)): + raise ValueError("Only PythonFunctionTask and PythonInstanceTask are supported in map tasks.") + + n_outputs = len(actual_task.python_interface.outputs) + if n_outputs > 1: + raise ValueError("Only tasks with a single output are supported in map tasks.") + + self._bound_inputs: Set[str] = bound_inputs or set(bound_inputs) if bound_inputs else set() + if self._partial: + self._bound_inputs.update(self._partial.keywords.keys()) + + # Transform the interface to List[Optional[T]] in case `min_success_ratio` is set + output_as_list_of_optionals = min_success_ratio is not None and min_success_ratio != 1 
and n_outputs == 1 + collection_interface = transform_interface_to_list_interface( + actual_task.python_interface, self._bound_inputs, output_as_list_of_optionals + ) + + self._run_task: Union[PythonFunctionTask, PythonInstanceTask] = actual_task # type: ignore + if isinstance(actual_task, PythonInstanceTask): + mod = actual_task.task_type + f = actual_task.lhs + else: + _, mod, f, _ = tracker.extract_task_module(cast(PythonFunctionTask, actual_task).task_function) + h = hashlib.md5( + f"{collection_interface.__str__()}{concurrency}{min_successes}{min_success_ratio}".encode("utf-8") + ).hexdigest() + self._name = f"{mod}.map_{f}_{h}-arraynode" + + self._concurrency: Optional[int] = concurrency + self._min_successes: Optional[int] = min_successes + self._min_success_ratio: Optional[float] = min_success_ratio + self._collection_interface = collection_interface + + if "metadata" not in kwargs and actual_task.metadata: + kwargs["metadata"] = actual_task.metadata + if "security_ctx" not in kwargs and actual_task.security_context: + kwargs["security_ctx"] = actual_task.security_context + + super().__init__( + name=self.name, + interface=collection_interface, + task_type=SdkTaskType.PYTHON_TASK, + task_config=None, + task_type_version=1, + **kwargs, + ) + + @property + def name(self) -> str: + return self._name + + @property + def python_interface(self): + return self._collection_interface + + def construct_node_metadata(self) -> NodeMetadata: + # TODO: add support for other Flyte entities + return NodeMetadata( + name=self.name, + ) + + @property + def min_success_ratio(self) -> Optional[float]: + return self._min_success_ratio + + @property + def min_successes(self) -> Optional[int]: + return self._min_successes + + @property + def concurrency(self) -> Optional[int]: + return self._concurrency + + @property + def python_function_task(self) -> Union[PythonFunctionTask, PythonInstanceTask]: + return self._run_task + + @property + def bound_inputs(self) -> Set[str]: + return self._bound_inputs + + @contextmanager + def prepare_target(self): + """ + Alters the underlying run_task command to modify it for map task execution and then resets it after. + """ + self.python_function_task.set_command_fn(self.get_command) + try: + yield + finally: + self.python_function_task.reset_command_fn() + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + return ArrayJob(parallelism=self._concurrency, min_success_ratio=self._min_success_ratio).to_dict() + + def get_container(self, settings: SerializationSettings) -> Container: + with self.prepare_target(): + return self.python_function_task.get_container(settings) + + def get_k8s_pod(self, settings: SerializationSettings) -> K8sPod: + with self.prepare_target(): + return self.python_function_task.get_k8s_pod(settings) + + def get_sql(self, settings: SerializationSettings) -> Sql: + with self.prepare_target(): + return self.python_function_task.get_sql(settings) + + def get_command(self, settings: SerializationSettings) -> List[str]: + """ + TODO ADD bound variables to the resolver. Maybe we need a different resolver? 
+ """ + mt = ArrayNodeMapTaskResolver() + container_args = [ + "pyflyte-map-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", + "--experimental", + "--resolver", + mt.name(), + "--", + *mt.loader_args(settings, self), + ] + + # TODO: add support for ContainerTask + # if self._cmd_prefix: + # return self._cmd_prefix + container_args + return container_args + + def __call__(self, *args, **kwargs): + """ + This call method modifies the kwargs and adds kwargs from partial. + This is mostly done in the local_execute and compilation only. + At runtime, the map_task is created with all the inputs filled in. to support this, we have modified + the map_task interface in the constructor. + """ + if self._partial: + """If partial exists, then mix-in all partial values""" + kwargs = {**self._partial.keywords, **kwargs} + return super().__call__(*args, **kwargs) + + def execute(self, **kwargs) -> Any: + ctx = FlyteContextManager.current_context() + if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION: + return self._execute_map_task(ctx, **kwargs) + + return self._raw_execute(**kwargs) + + def _execute_map_task(self, _: FlyteContext, **kwargs) -> Any: + task_index = self._compute_array_job_index() + map_task_inputs = {} + for k in self.interface.inputs.keys(): + v = kwargs[k] + if isinstance(v, list) and k not in self.bound_inputs: + map_task_inputs[k] = v[task_index] + else: + map_task_inputs[k] = v + return exception_scopes.user_entry_point(self.python_function_task.execute)(**map_task_inputs) + + @staticmethod + def _compute_array_job_index() -> int: + """ + Computes the absolute index of the current array job. This is determined by summing the compute-environment-specific + environment variable and the offset (if one's set). The offset will be set and used when the user request that the + job runs in a number of slots less than the size of the input. + """ + return int(os.environ.get("BATCH_JOB_ARRAY_INDEX_OFFSET", "0")) + int( + os.environ.get(os.environ.get("BATCH_JOB_ARRAY_INDEX_VAR_NAME", "0"), "0") + ) + + @property + def _outputs_interface(self) -> Dict[Any, Variable]: + """ + We override this method from PythonTask because the dispatch_execute method uses this + interface to construct outputs. Each instance of an container_array task will however produce outputs + according to the underlying run_task interface and the array plugin handler will actually create a collection + from these individual outputs as the final output value. + """ + + ctx = FlyteContextManager.current_context() + if ctx.execution_state is not None and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: + # In workflow execution mode we actually need to use the parent (mapper) task output interface. + return self.interface.outputs + return self.python_function_task.interface.outputs + + def get_type_for_output_var(self, k: str, v: Any) -> type: + """ + We override this method from flytekit.core.base_task Task because the dispatch_execute method uses this + interface to construct outputs. Each instance of an container_array task will however produce outputs + according to the underlying run_task interface and the array plugin handler will actually create a collection + from these individual outputs as the final output value. 
+ """ + ctx = FlyteContextManager.current_context() + if ctx.execution_state and ctx.execution_state.is_local_execution(): + # In workflow execution mode we actually need to use the parent (mapper) task output interface. + return self._python_interface.outputs[k] + return self.python_function_task.python_interface.outputs[k] + + def _raw_execute(self, **kwargs) -> Any: + """ + This is called during locally run executions. Unlike array task execution on the Flyte platform, _raw_execute + produces the full output collection. + """ + outputs_expected = True + if not self.interface.outputs: + outputs_expected = False + outputs = [] + + any_input_key = ( + list(self.python_function_task.interface.inputs.keys())[0] + if self.python_function_task.interface.inputs.items() is not None + else None + ) + + for i in range(len(kwargs[any_input_key])): + single_instance_inputs = {} + for k in self.interface.inputs.keys(): + v = kwargs[k] + if isinstance(v, list) and k not in self._bound_inputs: + single_instance_inputs[k] = kwargs[k][i] + else: + single_instance_inputs[k] = kwargs[k] + o = exception_scopes.user_entry_point(self.python_function_task.execute)(**single_instance_inputs) + if outputs_expected: + outputs.append(o) + + return outputs + + +def map_task( + task_function: PythonFunctionTask, + concurrency: int = 0, + # TODO why no min_successes? + min_success_ratio: float = 1.0, + **kwargs, +): + """Map task that uses the ``ArrayNode`` construct.. + + .. important:: + + This is an experimental drop-in replacement for :py:func:`~flytekit.map_task`. + + :param task_function: This argument is implicitly passed and represents the repeatable function + :param concurrency: If specified, this limits the number of mapped tasks than can run in parallel to the given batch + size. If the size of the input exceeds the concurrency value, then multiple batches will be run serially until + all inputs are processed. If left unspecified, this means unbounded concurrency. + :param min_success_ratio: If specified, this determines the minimum fraction of total jobs which can complete + successfully before terminating this task and marking it successful. + """ + return ArrayNodeMapTask(task_function, concurrency=concurrency, min_success_ratio=min_success_ratio, **kwargs) + + +class ArrayNodeMapTaskResolver(tracker.TrackedInstance, TaskResolverMixin): + """ + Special resolver that is used for ArrayNodeMapTasks. + This exists because it is possible that ArrayNodeMapTasks are created using nested "partial" subtasks. + When a maptask is created its interface is interpolated from the interface of the subtask - the interpolation, + simply converts every input into a list/collection input. + + For example: + interface -> (i: int, j: str) -> str => map_task interface -> (i: List[int], j: List[str]) -> List[str] + + But in cases in which `j` is bound to a fixed value by using `functools.partial` we need a way to ensure that + the interface is not simply interpolated, but only the unbound inputs are interpolated. + + .. code-block:: python + + def foo((i: int, j: str) -> str: + ... + + mt = map_task(functools.partial(foo, j=10)) + + print(mt.interface) + + output: + + (i: List[int], j: str) -> List[str] + + But, at runtime this information is lost. 
To reconstruct this, we use ArrayNodeMapTaskResolver that records the "bound vars" + and then at runtime reconstructs the interface with this knowledge + """ + + def name(self) -> str: + return "ArrayNodeMapTaskResolver" + + @timeit("Load map task") + def load_task(self, loader_args: List[str], max_concurrency: int = 0) -> ArrayNodeMapTask: + """ + Loader args should be of the form + vars "var1,var2,.." resolver "resolver" [resolver_args] + """ + _, bound_vars, _, resolver, *resolver_args = loader_args + logging.info(f"MapTask found task resolver {resolver} and arguments {resolver_args}") + resolver_obj = load_object_from_module(resolver) + # Use the resolver to load the actual task object + _task_def = resolver_obj.load_task(loader_args=resolver_args) + bound_inputs = set(bound_vars.split(",")) + return ArrayNodeMapTask( + python_function_task=_task_def, max_concurrency=max_concurrency, bound_inputs=bound_inputs + ) + + def loader_args(self, settings: SerializationSettings, t: ArrayNodeMapTask) -> List[str]: # type:ignore + return [ + "vars", + f'{",".join(t.bound_inputs)}', + "resolver", + t.python_function_task.task_resolver.location, + *t.python_function_task.task_resolver.loader_args(settings, t.python_function_task), + ] + + def get_all_tasks(self) -> List[Task]: + raise NotImplementedError("MapTask resolver cannot return every instance of the map task") diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index e98d5d30e9..dc46e6bc4f 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -11,17 +11,18 @@ kwtypes PythonTask Task - TaskMetadata TaskResolverMixin IgnoreOutputs """ +import asyncio import collections import datetime +import inspect from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union, cast +from typing import Any, Coroutine, Dict, Generic, List, Optional, OrderedDict, Tuple, Type, TypeVar, Union, cast from flytekit.configuration import SerializationSettings from flytekit.core.context_manager import ( @@ -233,7 +234,9 @@ def get_input_types(self) -> Optional[Dict[str, type]]: """ return None - def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]: + def local_execute( + self, ctx: FlyteContext, **kwargs + ) -> Union[Tuple[Promise], Promise, VoidPromise, Coroutine, None]: """ This function is used only in the local execution path and is responsible for calling dispatch execute. Use this function when calling a task with native values (or Promises containing Flyte literals derived from @@ -283,6 +286,10 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr # Code is simpler with duplication and less metaprogramming, but introduces regressions # if one is changed and not the other. 
outputs_literal_map = self.sandbox_execute(ctx, input_literal_map) + + if inspect.iscoroutine(outputs_literal_map): + return outputs_literal_map + outputs_literals = outputs_literal_map.literals # TODO maybe this is the part that should be done for local execution, we pass the outputs to some special @@ -300,7 +307,7 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr vals = [Promise(var, outputs_literals[var]) for var in output_names] return create_task_output(vals, self.python_interface) - def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: + def __call__(self, *args: object, **kwargs: object) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: return flyte_entity_call_handler(self, *args, **kwargs) # type: ignore def compile(self, ctx: FlyteContext, *args, **kwargs): @@ -488,7 +495,7 @@ def construct_node_metadata(self) -> _workflow_model.NodeMetadata: interruptible=self.metadata.interruptible, ) - def compile(self, ctx: FlyteContext, *args, **kwargs): + def compile(self, ctx: FlyteContext, *args, **kwargs) -> Optional[Union[Tuple[Promise], Promise, VoidPromise]]: """ Generates a node that encapsulates this task in a workflow definition. """ @@ -498,9 +505,72 @@ def compile(self, ctx: FlyteContext, *args, **kwargs): def _outputs_interface(self) -> Dict[Any, Variable]: return self.interface.outputs # type: ignore + def _output_to_literal_map(self, native_outputs, exec_ctx): + expected_output_names = list(self._outputs_interface.keys()) + if len(expected_output_names) == 1: + # Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of + # length one. That convention is used for naming outputs - and single-length-NamedTuples are + # particularly troublesome but elegant handling of them is not a high priority + # Again, we're using the output_tuple_name as a proxy. + if self.python_interface.output_tuple_name and isinstance(native_outputs, tuple): + native_outputs_as_map = {expected_output_names[0]: native_outputs[0]} + else: + native_outputs_as_map = {expected_output_names[0]: native_outputs} + elif len(expected_output_names) == 0: + native_outputs_as_map = {} + else: + native_outputs_as_map = {expected_output_names[i]: native_outputs[i] for i, _ in enumerate(native_outputs)} + + # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption + # built into the IDL that all the values of a literal map are of the same type. 
+ with timeit("Translate the output to literals"): + literals = {} + for i, (k, v) in enumerate(native_outputs_as_map.items()): + literal_type = self._outputs_interface[k].type + py_type = self.get_type_for_output_var(k, v) + + if isinstance(v, tuple): + raise TypeError(f"Output({k}) in task '{self.name}' received a tuple {v}, instead of {py_type}") + try: + literals[k] = TypeEngine.to_literal(exec_ctx, v, py_type, literal_type) + except Exception as e: + # only show the name of output key if it's user-defined (by default Flyte names these as "o") + key = k if k != f"o{i}" else i + msg = f"Failed to convert outputs of task '{self.name}' at position {key}:\n {e}" + logger.error(msg) + raise TypeError(msg) from e + + return _literal_models.LiteralMap(literals=literals), native_outputs_as_map + + def _write_decks(self, native_inputs, native_outputs_as_map, ctx, new_user_params): + if self._disable_deck is False: + from flytekit.deck.deck import Deck, _output_deck + + INPUT = "input" + OUTPUT = "output" + + input_deck = Deck(INPUT) + for k, v in native_inputs.items(): + input_deck.append(TypeEngine.to_html(ctx, v, self.get_type_for_input_var(k, v))) + + output_deck = Deck(OUTPUT) + for k, v in native_outputs_as_map.items(): + output_deck.append(TypeEngine.to_html(ctx, v, self.get_type_for_output_var(k, v))) + + if ctx.execution_state and ctx.execution_state.is_local_execution(): + # When we run the workflow remotely, flytekit outputs decks at the end of _dispatch_execute + _output_deck(self.name.split(".")[-1], new_user_params) + + async def _async_execute(self, native_inputs, native_outputs, ctx, exec_ctx, new_user_params): + native_outputs = await native_outputs + native_outputs = self.post_execute(new_user_params, native_outputs) + literals_map, native_outputs_as_map = self._output_to_literal_map(native_outputs, exec_ctx) + self._write_decks(native_inputs, native_outputs_as_map, ctx, new_user_params) + return literals_map + def dispatch_execute( self, ctx: FlyteContext, input_literal_map: _literal_models.LiteralMap - ) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec]: + ) -> Union[_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec, Coroutine]: """ This method translates Flyte's Type system based input values and invokes the actual call to the executor This method is also invoked during runtime. @@ -510,10 +580,8 @@ def dispatch_execute( may be none * ``DynamicJobSpec`` is returned when a dynamic workflow is executed """ - # Invoked before the task is executed new_user_params = self.pre_execute(ctx.user_space_params) - from flytekit.deck.deck import _output_deck # Create another execution context with the new user params, but let's keep the same working dir with FlyteContextManager.with_context( @@ -543,6 +611,23 @@ def dispatch_execute( logger.exception(f"Exception when executing {e}") raise e + if inspect.iscoroutine(native_outputs): + # If native outputs is a coroutine, then this is an eager workflow. + if exec_ctx.execution_state: + if exec_ctx.execution_state.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION: + # Just return task outputs as a coroutine if the eager workflow is being executed locally, + # outside of a workflow. This preserves the expectation that the eager workflow is an async + # function. + return native_outputs + elif exec_ctx.execution_state.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION: + # If executed inside of a workflow being executed locally, then run the coroutine to get the + # actual results. 
+ return asyncio.run( + self._async_execute(native_inputs, native_outputs, ctx, exec_ctx, new_user_params) + ) + + return self._async_execute(native_inputs, native_outputs, ctx, exec_ctx, new_user_params) + logger.debug("Task executed successfully in user level") # Lets run the post_execute method. This may result in a IgnoreOutputs Exception, which is # bubbled up to be handled at the callee layer. @@ -551,68 +636,13 @@ def dispatch_execute( # Short circuit the translation to literal map because what's returned may be a dj spec (or an # already-constructed LiteralMap if the dynamic task was a no-op), not python native values # dynamic_execute returns a literal map in local execute so this also gets triggered. - if isinstance(native_outputs, _literal_models.LiteralMap) or isinstance( - native_outputs, _dynamic_job.DynamicJobSpec - ): + if isinstance(native_outputs, (_literal_models.LiteralMap, _dynamic_job.DynamicJobSpec)): return native_outputs - expected_output_names = list(self._outputs_interface.keys()) - if len(expected_output_names) == 1: - # Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of - # length one. That convention is used for naming outputs - and single-length-NamedTuples are - # particularly troublesome but elegant handling of them is not a high priority - # Again, we're using the output_tuple_name as a proxy. - if self.python_interface.output_tuple_name and isinstance(native_outputs, tuple): - native_outputs_as_map = {expected_output_names[0]: native_outputs[0]} - else: - native_outputs_as_map = {expected_output_names[0]: native_outputs} - elif len(expected_output_names) == 0: - native_outputs_as_map = {} - else: - native_outputs_as_map = { - expected_output_names[i]: native_outputs[i] for i, _ in enumerate(native_outputs) - } - - # We manually construct a LiteralMap here because task inputs and outputs actually violate the assumption - # built into the IDL that all the values of a literal map are of the same type. 
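The coroutine branch above exists for the eager workflows added later in this diff (``flytekit/experimental/eager_function.py``). A rough sketch of what such a workflow looks like, using illustrative task names:

.. code-block:: python

    from flytekit import task
    from flytekit.experimental import eager

    @task
    def add_one(x: int) -> int:
        return x + 1

    @eager
    async def eager_wf(x: int) -> int:
        # the eager body is an async coroutine; when called locally outside a
        # workflow, dispatch_execute returns it unawaited so the caller can
        # drive it with asyncio
        out = await add_one(x=x)
        return out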
- with timeit("Translate the output to literals"): - literals = {} - for i, (k, v) in enumerate(native_outputs_as_map.items()): - literal_type = self._outputs_interface[k].type - py_type = self.get_type_for_output_var(k, v) - - if isinstance(v, tuple): - raise TypeError(f"Output({k}) in task '{self.name}' received a tuple {v}, instead of {py_type}") - try: - literals[k] = TypeEngine.to_literal(exec_ctx, v, py_type, literal_type) - except Exception as e: - # only show the name of output key if it's user-defined (by default Flyte names these as "o") - key = k if k != f"o{i}" else i - msg = f"Failed to convert outputs of task '{self.name}' at position {key}:\n {e}" - logger.error(msg) - raise TypeError(msg) from e - - if self._disable_deck is False: - from flytekit.deck.deck import Deck - - INPUT = "input" - OUTPUT = "output" - - input_deck = Deck(INPUT) - for k, v in native_inputs.items(): - input_deck.append(TypeEngine.to_html(ctx, v, self.get_type_for_input_var(k, v))) - - output_deck = Deck(OUTPUT) - for k, v in native_outputs_as_map.items(): - output_deck.append(TypeEngine.to_html(ctx, v, self.get_type_for_output_var(k, v))) - - if ctx.execution_state and ctx.execution_state.is_local_execution(): - # When we run the workflow remotely, flytekit outputs decks at the end of _dispatch_execute - _output_deck(self.name.split(".")[-1], new_user_params) - - outputs_literal_map = _literal_models.LiteralMap(literals=literals) + literals_map, native_outputs_as_map = self._output_to_literal_map(native_outputs, exec_ctx) + self._write_decks(native_inputs, native_outputs_as_map, ctx, new_user_params) # After the execute has been successfully completed - return outputs_literal_map + return literals_map def pre_execute(self, user_params: Optional[ExecutionParameters]) -> Optional[ExecutionParameters]: # type: ignore """ diff --git a/flytekit/core/condition.py b/flytekit/core/condition.py index 37c4afc88f..bc7b4df865 100644 --- a/flytekit/core/condition.py +++ b/flytekit/core/condition.py @@ -271,6 +271,13 @@ def __init__( self._output_promise: Optional[Union[Tuple[Promise], Promise]] = None self._err: Optional[str] = None self._stmt = stmt + self._output_node = None + + @property + def output_node(self) -> Optional[Node]: + # This is supposed to hold a pointer to the node that created this case. + # It is set in the then() call. but the value will not be set if it's a VoidPromise or None was returned. + return self._output_node @property def expr(self) -> Optional[Union[ComparisonExpression, ConjunctionExpression]]: @@ -289,6 +296,21 @@ def then( self, p: Union[Promise, Tuple[Promise]] ) -> Optional[Union[Condition, Promise, Tuple[Promise], VoidPromise]]: self._output_promise = p + if isinstance(p, Promise): + if not p.is_ready: + self._output_node = p.ref.node # type: ignore + elif isinstance(p, VoidPromise): + if p.ref is not None: + self._output_node = p.ref.node + elif hasattr(p, "_fields"): + # This condition detects the NamedTuple case and iterates through the fields to find one that has a node + # which should be the first one. 
+ for f in p._fields: # type: ignore + prom = getattr(p, f) + if not prom.is_ready: + self._output_node = prom.ref.node + break + # We can always mark branch as completed return self._cs.end_branch() @@ -391,6 +413,8 @@ def transform_to_conj_expr( def transform_to_operand(v: Union[Promise, Literal]) -> Tuple[_core_cond.Operand, Optional[Promise]]: if isinstance(v, Promise): return _core_cond.Operand(var=create_branch_node_promise_var(v.ref.node_id, v.var)), v + if v.scalar.none_type: + return _core_cond.Operand(scalar=v.scalar), None return _core_cond.Operand(primitive=v.scalar.primitive), None @@ -415,7 +439,8 @@ def transform_to_boolexpr( def to_case_block(c: Case) -> Tuple[Union[_core_wf.IfBlock], typing.List[Promise]]: expr, promises = transform_to_boolexpr(cast(Union[ComparisonExpression, ConjunctionExpression], c.expr)) - n = c.output_promise.ref.node # type: ignore + if c.output_promise is not None: + n = c.output_node return _core_wf.IfBlock(condition=expr, then_node=n), promises @@ -438,7 +463,7 @@ def to_ifelse_block(node_id: str, cs: ConditionalSection) -> Tuple[_core_wf.IfEl node = None err = None if last_case.output_promise is not None: - node = last_case.output_promise.ref.node # type: ignore + node = last_case.output_node else: err = Error(failed_node_id=node_id, message=last_case.err if last_case.err else "Condition failed") return ( diff --git a/flytekit/core/constants.py b/flytekit/core/constants.py index 5f5d892e17..c655b5791f 100644 --- a/flytekit/core/constants.py +++ b/flytekit/core/constants.py @@ -9,6 +9,7 @@ class SdkTaskType(object): PYTHON_TASK = "python-task" DYNAMIC_TASK = "dynamic-task" CONTAINER_ARRAY_TASK = "container_array" + EXPERIMENTAL_ARRAY_NODE_TASK = "array_node" SPARK_TASK = "spark" # Hive is multi-step operation: diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index aa6b0e3e4d..de85c0be97 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -349,9 +349,12 @@ def __getattr__(self, item: str) -> _GroupSecrets: """ return self._GroupSecrets(item, self) - def get(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str: + def get( + self, group: str, key: Optional[str] = None, group_version: Optional[str] = None, encode_mode: str = "r" + ) -> str: """ Retrieves a secret using the resolution order -> Env followed by file. If not found raises a ValueError + param encode_mode, defines the mode to open files, it can either be "r" to read file, or "rb" to read binary file """ self.check_group_key(group) env_var = self.get_secrets_env_var(group, key, group_version) @@ -360,10 +363,11 @@ def get(self, group: str, key: Optional[str] = None, group_version: Optional[str if v is not None: return v if os.path.exists(fpath): - with open(fpath, "r") as f: + with open(fpath, encode_mode) as f: return f.read().strip() raise ValueError( - f"Unable to find secret for key {key} in group {group} " f"in Env Var:{env_var} and FilePath: {fpath}" + f"Please make sure to add secret_requests=[Secret(group={group}, key={key})] in @task. Unable to find secret for key {key} in group {group} " + f"in Env Var:{env_var} and FilePath: {fpath}" ) def get_secrets_env_var(self, group: str, key: Optional[str] = None, group_version: Optional[str] = None) -> str: @@ -485,6 +489,9 @@ class Mode(Enum): # or propeller. 
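The ``none_type`` handling added to ``transform_to_operand`` above pairs with the ``is_none()`` helper introduced on promises further down in this diff; together they allow a conditional to branch on an optional input. A hedged sketch (task and branch names are made up):

.. code-block:: python

    from typing import Optional
    from flytekit import conditional, task, workflow

    @task
    def say(msg: str) -> str:
        return msg

    @workflow
    def wf(x: Optional[int] = None) -> str:
        return (
            conditional("x_check")
            .if_(x.is_none())  # compiles to an EQ comparison against a none-typed scalar
            .then(say(msg="no value"))
            .else_()
            .then(say(msg="got a value"))
        )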
LOCAL_TASK_EXECUTION = 3 + # This is the mode that is used to indicate a dynamic task + DYNAMIC_TASK_EXECUTION = 4 + mode: Optional[ExecutionState.Mode] working_dir: Union[os.PathLike, str] engine_dir: Optional[Union[os.PathLike, str]] diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index fa29d5e128..f7e04d6403 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -41,7 +41,7 @@ _ANON = "anon" -def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False): +def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False) -> Dict[str, Any]: kwargs: Dict[str, Any] = { "cache_regions": True, } @@ -61,6 +61,41 @@ def s3_setup_args(s3_cfg: configuration.S3Config, anonymous: bool = False): return kwargs +def azure_setup_args(azure_cfg: configuration.AzureBlobStorageConfig, anonymous: bool = False) -> Dict[str, Any]: + kwargs: Dict[str, Any] = {} + + if azure_cfg.account_name: + kwargs["account_name"] = azure_cfg.account_name + if azure_cfg.account_key: + kwargs["account_key"] = azure_cfg.account_key + if azure_cfg.client_id: + kwargs["client_id"] = azure_cfg.client_id + if azure_cfg.client_secret: + kwargs["client_secret"] = azure_cfg.client_secret + if azure_cfg.tenant_id: + kwargs["tenant_id"] = azure_cfg.tenant_id + kwargs[_ANON] = anonymous + return kwargs + + +def get_fsspec_storage_options( + protocol: str, data_config: typing.Optional[DataConfig] = None, anonymous: bool = False, **kwargs +) -> Dict[str, Any]: + data_config = data_config or DataConfig.auto() + + if protocol == "file": + return {"auto_mkdir": True, **kwargs} + if protocol == "s3": + return {**s3_setup_args(data_config.s3, anonymous=anonymous), **kwargs} + if protocol == "gs": + if anonymous: + kwargs["token"] = _ANON + return kwargs + if protocol in ("abfs", "abfss"): + return {**azure_setup_args(data_config.azure, anonymous=anonymous), **kwargs} + return {} + + class FileAccessProvider(object): """ This is the class that is available through the FlyteContext and can be used for persisting data to the remote @@ -106,25 +141,15 @@ def data_config(self) -> DataConfig: def get_filesystem( self, protocol: typing.Optional[str] = None, anonymous: bool = False, **kwargs - ) -> typing.Optional[fsspec.AbstractFileSystem]: + ) -> fsspec.AbstractFileSystem: if not protocol: return self._default_remote - if protocol == "file": - kwargs["auto_mkdir"] = True - elif protocol == "s3": - s3kwargs = s3_setup_args(self._data_config.s3, anonymous=anonymous) - s3kwargs.update(kwargs) - return fsspec.filesystem(protocol, **s3kwargs) # type: ignore - elif protocol == "gs": - if anonymous: - kwargs["token"] = _ANON - return fsspec.filesystem(protocol, **kwargs) # type: ignore - - # Preserve old behavior of returning None for file systems that don't have an explicit anonymous option. 
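The refactor above centralizes per-protocol fsspec options in ``get_fsspec_storage_options`` so other call sites can reuse it. A small sketch of calling it directly, assuming ``s3fs`` is installed and using placeholder credentials:

.. code-block:: python

    import fsspec
    from flytekit.configuration import DataConfig, S3Config
    from flytekit.core.data_persistence import get_fsspec_storage_options

    cfg = DataConfig(s3=S3Config(access_key_id="minio", secret_access_key="miniostorage"))

    # translate flytekit's S3 config into the kwargs fsspec expects
    opts = get_fsspec_storage_options("s3", data_config=cfg)
    fs = fsspec.filesystem("s3", **opts)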
- if anonymous: - return None - return fsspec.filesystem(protocol, **kwargs) # type: ignore + storage_options = get_fsspec_storage_options( + protocol=protocol, anonymous=anonymous, data_config=self._data_config, **kwargs + ) + + return fsspec.filesystem(protocol, **storage_options) def get_filesystem_for_path(self, path: str = "", anonymous: bool = False, **kwargs) -> fsspec.AbstractFileSystem: protocol = get_protocol(path) @@ -183,7 +208,7 @@ def exists(self, path: str) -> bool: return anon_fs.exists(path) raise oe - def get(self, from_path: str, to_path: str, recursive: bool = False): + def get(self, from_path: str, to_path: str, recursive: bool = False, **kwargs): file_system = self.get_filesystem_for_path(from_path) if recursive: from_path, to_path = self.recursive_paths(from_path, to_path) @@ -194,13 +219,13 @@ def get(self, from_path: str, to_path: str, recursive: bool = False): return shutil.copytree( self.strip_file_header(from_path), self.strip_file_header(to_path), dirs_exist_ok=True ) - return file_system.get(from_path, to_path, recursive=recursive) + return file_system.get(from_path, to_path, recursive=recursive, **kwargs) except OSError as oe: logger.debug(f"Error in getting {from_path} to {to_path} rec {recursive} {oe}") file_system = self.get_filesystem(get_protocol(from_path), anonymous=True) if file_system is not None: logger.debug(f"Attempting anonymous get with {file_system}") - return file_system.get(from_path, to_path, recursive=recursive) + return file_system.get(from_path, to_path, recursive=recursive, **kwargs) raise oe def put(self, from_path: str, to_path: str, recursive: bool = False, **kwargs): @@ -287,7 +312,7 @@ def upload_directory(self, local_path: str, remote_path: str): """ return self.put_data(local_path, remote_path, is_multipart=True) - def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False): + def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False, **kwargs): """ :param remote_path: :param local_path: @@ -296,7 +321,7 @@ def get_data(self, remote_path: str, local_path: str, is_multipart: bool = False try: pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) with timeit(f"Download data to local from {remote_path}"): - self.get(remote_path, to_path=local_path, recursive=is_multipart) + self.get(remote_path, to_path=local_path, recursive=is_multipart, **kwargs) except Exception as ex: raise FlyteAssertion( f"Failed to get data from {remote_path} to {local_path} (recursive={is_multipart}).\n\n" diff --git a/flytekit/core/dynamic_workflow_task.py b/flytekit/core/dynamic_workflow_task.py index 8f429dcf6c..4eb1e9906b 100644 --- a/flytekit/core/dynamic_workflow_task.py +++ b/flytekit/core/dynamic_workflow_task.py @@ -20,7 +20,7 @@ dynamic = functools.partial(task.task, execution_mode=PythonFunctionTask.ExecutionBehavior.DYNAMIC) dynamic.__doc__ = """ Please first see the comments for :py:func:`flytekit.task` and :py:func:`flytekit.workflow`. This ``dynamic`` -concept is an amalgamation of both and enables the user to pursue some :std:ref:`pretty incredible ` +concept is an amalgamation of both and enables the user to pursue some :std:ref:`pretty incredible ` constructs. In short, a task's function is run at execution time only, and a workflow function is run at compilation time only (local @@ -28,7 +28,7 @@ body is run to produce a workflow. 
It is almost as if the decorator changed from ``@task`` to ``@workflow`` except workflows cannot make use of their inputs like native Python values whereas dynamic workflows can. The resulting workflow is passed back to the Flyte engine and is -run as a :std:ref:`subworkflow `. Simple usage +run as a :std:ref:`subworkflow `. Simple usage .. code-block:: @@ -49,5 +49,5 @@ def my_dynamic_subwf(a: int, b: int) -> int: x = t1(a=a) return t2(b=b, x=x) -See the :std:ref:`cookbook ` for a longer discussion. +See the :std:ref:`cookbook ` for a longer discussion. """ # noqa: W293 diff --git a/flytekit/core/gate.py b/flytekit/core/gate.py index a09a8d82ce..241eaf2f7e 100644 --- a/flytekit/core/gate.py +++ b/flytekit/core/gate.py @@ -6,6 +6,7 @@ import click +from flytekit.core import constants from flytekit.core import interface as flyte_interface from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.promise import Promise, VoidPromise, flyte_entity_call_handler @@ -94,18 +95,25 @@ def construct_node_metadata(self) -> _workflow_model.NodeMetadata: # This is to satisfy the LocallyExecutable protocol def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]: if self.sleep_duration: - print(f"Mock sleeping for {self.sleep_duration}") + click.echo( + f'{click.style("[Sleep Gate]", fg="yellow")} ' + f'{click.style(f"Simulating Sleep for {self.sleep_duration}", fg="cyan")}' + ) return VoidPromise(self.name) # Trigger stdin if self.input_type: - msg = f"Execution stopped for gate {self.name}...\n" + msg = click.style("[Input Gate] ", fg="yellow") + click.style( + f"Waiting for input @{self.name} of type {self.input_type}", fg="cyan" + ) literal = parse_stdin_to_literal(ctx, self.input_type, msg) p = Promise(var="o0", val=literal) return p # Assume this is an approval operation since that's the only remaining option. - msg = f"Pausing execution for {self.name}, literal value is:\n{typing.cast(Promise, self._upstream_item).val}\nContinue?" + msg = click.style("[Approval Gate] ", fg="yellow") + click.style( + f"@{self.name} Approve {typing.cast(Promise, self._upstream_item).val.value}?", fg="cyan" + ) proceed = click.confirm(msg, default=True) if proceed: # We need to return a promise here, and a promise is what should've been passed in by the call in approve() @@ -172,6 +180,8 @@ def approve(upstream_item: Union[Tuple[Promise], Promise, VoidPromise], name: st ctx = FlyteContextManager.current_context() upstream_item = typing.cast(Promise, upstream_item) if ctx.compilation_state is not None and ctx.compilation_state.mode == 1: + if upstream_item.ref.node_id == constants.GLOBAL_INPUT_NODE_ID: + raise ValueError("Workflow inputs cannot be passed to approval nodes.") if not upstream_item.ref.node.flyte_entity.python_interface: raise ValueError( f"Upstream node doesn't have a Python interface. Node entity is: " diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 5dbca3a893..a548f6a49b 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -288,7 +288,7 @@ def transform_interface_to_list_interface( """ Takes a single task interface and interpolates it to an array interface - to allow performing distributed python map like functions - :param interface: Interface to be upgraded toa list interface + :param interface: Interface to be upgraded to a list interface :param bound_inputs: fixed inputs that should not upgraded to a list and will be maintained as scalars. 
""" map_inputs = transform_types_to_list_of_type(interface.inputs, bound_inputs) diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index 5a544bc316..e47b731ac6 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -147,6 +147,7 @@ def set_command_prefix(self, cmd: typing.Optional[typing.List[str]]): @contextmanager def prepare_target(self): """ + TODO: why do we do this? Alters the underlying run_task command to modify it for map task execution and then resets it after. """ self._run_task.set_command_fn(self.get_command) diff --git a/flytekit/core/node.py b/flytekit/core/node.py index bf5c97ba60..1038c00521 100644 --- a/flytekit/core/node.py +++ b/flytekit/core/node.py @@ -12,6 +12,32 @@ from flytekit.models.task import Resources as _resources_model +def assert_not_promise(v: Any, location: str): + """ + This function will raise an exception if the value is a promise. This should be used to ensure that we don't + accidentally use a promise in a place where we don't support it. + """ + from flytekit.core.promise import Promise + + if isinstance(v, Promise): + raise AssertionError(f"Cannot use a promise in the {location} Value: {v}") + + +def assert_no_promises_in_resources(resources: _resources_model): + """ + This function will raise an exception if any of the resources have promises in them. This is because we don't + support promises in resources / runtime overriding of resources through input values. + """ + if resources is None: + return + if resources.requests is not None: + for r in resources.requests: + assert_not_promise(r.value, "resources.requests") + if resources.limits is not None: + for r in resources.limits: + assert_not_promise(r.value, "resources.limits") + + class Node(object): """ This class will hold all the things necessary to make an SdkNode but we won't make one until we know things like @@ -86,7 +112,10 @@ def with_overrides(self, *args, **kwargs): if "node_name" in kwargs: # Convert the node name into a DNS-compliant. 
# https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#dns-subdomain-names - self._id = _dnsify(kwargs["node_name"]) + v = kwargs["node_name"] + assert_not_promise(v, "node_name") + self._id = _dnsify(v) + if "aliases" in kwargs: alias_dict = kwargs["aliases"] if not isinstance(alias_dict, dict): @@ -94,6 +123,7 @@ def with_overrides(self, *args, **kwargs): self._aliases = [] for k, v in alias_dict.items(): self._aliases.append(_workflow_model.Alias(var=k, alias=v)) + if "requests" in kwargs or "limits" in kwargs: requests = kwargs.get("requests") if requests and not isinstance(requests, Resources): @@ -101,8 +131,10 @@ def with_overrides(self, *args, **kwargs): limits = kwargs.get("limits") if limits and not isinstance(limits, Resources): raise AssertionError("limits should be specified as flytekit.Resources") + resources = convert_resources_to_resource_model(requests=requests, limits=limits) + assert_no_promises_in_resources(resources) + self._resources = resources - self._resources = convert_resources_to_resource_model(requests=requests, limits=limits) if "timeout" in kwargs: timeout = kwargs["timeout"] if timeout is None: @@ -115,21 +147,31 @@ def with_overrides(self, *args, **kwargs): raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds") if "retries" in kwargs: retries = kwargs["retries"] + assert_not_promise(retries, "retries") self._metadata._retries = ( _literal_models.RetryStrategy(0) if retries is None else _literal_models.RetryStrategy(retries) ) + if "interruptible" in kwargs: + v = kwargs["interruptible"] + assert_not_promise(v, "interruptible") self._metadata._interruptible = kwargs["interruptible"] + if "name" in kwargs: self._metadata._name = kwargs["name"] + if "task_config" in kwargs: logger.warning("This override is beta. We may want to revisit this in the future.") new_task_config = kwargs["task_config"] if not isinstance(new_task_config, type(self.flyte_entity._task_config)): raise ValueError("can't change the type of the task config") self.flyte_entity._task_config = new_task_config + if "container_image" in kwargs: - self.flyte_entity._container_image = kwargs["container_image"] + v = kwargs["container_image"] + assert_not_promise(v, "container_image") + self.flyte_entity._container_image = v + return self diff --git a/flytekit/core/node_creation.py b/flytekit/core/node_creation.py index c2de88599e..705188c348 100644 --- a/flytekit/core/node_creation.py +++ b/flytekit/core/node_creation.py @@ -24,7 +24,7 @@ def create_node( """ This is the function you want to call if you need to specify dependencies between tasks that don't consume and/or don't produce outputs. For example, if you have t1() and t2(), both of which do not take in nor produce any - outputs, how do you specify that t2 should run before t1? + outputs, how do you specify that t2 should run before t1? :: t1_node = create_node(t1) t2_node = create_node(t2) @@ -33,36 +33,23 @@ def create_node( # OR t2_node >> t1_node - This works for tasks that take inputs as well, say a ``t3(in1: int)`` + This works for tasks that take inputs as well, say a ``t3(in1: int)`` :: t3_node = create_node(t3, in1=some_int) # basically calling t3(in1=some_int) - You can still use this method to handle setting certain overrides + You can still use this method to handle setting certain overrides :: t3_node = create_node(t3, in1=some_int).with_overrides(...) - Outputs, if there are any, will be accessible. 
A `t4() -> (int, str)` + Outputs, if there are any, will be accessible. A `t4() -> (int, str)` :: t4_node = create_node(t4) - in compilation node.o0 has the promise. + In compilation node.o0 has the promise. :: t5(in1=t4_node.o0) - in local workflow execution, what is the node? Can it just be the named tuple? - t5(in1=t4_node.o0) - - @workflow - def wf(): - create_node(sub_wf) - create_node(wf2) - - @dynamic - def sub_wf(): - create_node(other_sub) - create_node(task) - If t1 produces only one output, note that in local execution, you still get a wrapper object that - needs to be dereferenced by the output name. + needs to be dereferenced by the output name. :: t1_node = create_node(t1) t2(t1_node.o0) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 35b72b5a56..5d598a017a 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -1,8 +1,9 @@ from __future__ import annotations import collections +import inspect from enum import Enum -from typing import Any, Dict, List, Optional, Set, Tuple, Union, cast +from typing import Any, Coroutine, Dict, List, Optional, Set, Tuple, Union, cast from typing_extensions import Protocol, get_args @@ -136,12 +137,26 @@ def __init__(self, lhs: Union["Promise", Any], op: ComparisonOps, rhs: Union["Pr self._lhs = lhs if lhs.is_ready: if lhs.val.scalar is None or lhs.val.scalar.primitive is None: - raise ValueError("Only primitive values can be used in comparison") + union = lhs.val.scalar.union + if union and union.value.scalar: + if union.value.scalar.primitive or union.value.scalar.none_type: + self._lhs = union.value + else: + raise ValueError("Only primitive values can be used in comparison") + else: + raise ValueError("Only primitive values can be used in comparison") if isinstance(rhs, Promise): self._rhs = rhs if rhs.is_ready: if rhs.val.scalar is None or rhs.val.scalar.primitive is None: - raise ValueError("Only primitive values can be used in comparison") + union = rhs.val.scalar.union + if union and union.value.scalar: + if union.value.scalar.primitive or union.value.scalar.none_type: + self._rhs = union.value + else: + raise ValueError("Only primitive values can be used in comparison") + else: + raise ValueError("Only primitive values can be used in comparison") if self._lhs is None: self._lhs = type_engine.TypeEngine.to_literal(FlyteContextManager.current_context(), lhs, type(lhs), None) if self._rhs is None: @@ -162,11 +177,15 @@ def op(self) -> ComparisonOps: def eval(self) -> bool: if isinstance(self.lhs, Promise): lhs = self.lhs.eval() + elif self.lhs.scalar.none_type: + lhs = None else: lhs = get_primitive_val(self.lhs.scalar.primitive) if isinstance(self.rhs, Promise): rhs = self.rhs.eval() + elif self.rhs.scalar.none_type: + rhs = None else: rhs = get_primitive_val(self.rhs.scalar.primitive) @@ -350,9 +369,12 @@ def is_(self, v: bool) -> ComparisonExpression: def is_false(self) -> ComparisonExpression: return self.is_(False) - def is_true(self): + def is_true(self) -> ComparisonExpression: return self.is_(True) + def is_none(self) -> ComparisonExpression: + return ComparisonExpression(self, ComparisonOps.EQ, None) + def __eq__(self, other) -> ComparisonExpression: # type: ignore return ComparisonExpression(self, ComparisonOps.EQ, other) @@ -499,13 +521,6 @@ def with_overrides(self, *args, **kwargs): val.with_overrides(*args, **kwargs) return self - @property - def ref(self): - for p in promises: - if p.ref: - return p.ref - return None - def runs_before(self, other: Any): """ This function is just here to 
allow local workflow execution to run. See the corresponding function in @@ -969,7 +984,7 @@ def local_execution_mode(self) -> ExecutionState.Mode: def flyte_entity_call_handler( entity: SupportsNodeCreation, *args, **kwargs -) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: +) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, Coroutine, None]: """ This function is the call handler for tasks, workflows, and launch plans (which redirects to the underlying workflow). The logic is the same for all three, but we did not want to create base class, hence this separate @@ -1026,7 +1041,7 @@ def flyte_entity_call_handler( return create_task_output(vals, cast(SupportsNodeCreation, entity).python_interface) else: return None - return cast(LocallyExecutable, entity).local_execute(child_ctx, **kwargs) + return cast(LocallyExecutable, entity).local_execute(ctx, **kwargs) else: mode = cast(LocallyExecutable, entity).local_execution_mode() with FlyteContextManager.with_context( @@ -1042,6 +1057,9 @@ def flyte_entity_call_handler( else: raise Exception(f"Received an output when workflow local execution expected None. Received: {result}") + if inspect.iscoroutine(result): + return result + if (1 < expected_outputs == len(cast(Tuple[Promise], result))) or ( result is not None and expected_outputs == 1 ): diff --git a/flytekit/core/python_auto_container.py b/flytekit/core/python_auto_container.py index 47da6a9729..5335410a79 100644 --- a/flytekit/core/python_auto_container.py +++ b/flytekit/core/python_auto_container.py @@ -237,14 +237,8 @@ def load_task(self, loader_args: List[str]) -> PythonAutoContainerTask: return task_def def loader_args(self, settings: SerializationSettings, task: PythonAutoContainerTask) -> List[str]: # type:ignore - from flytekit.core.python_function_task import PythonFunctionTask - - if isinstance(task, PythonFunctionTask): - _, m, t, _ = extract_task_module(task.task_function) - return ["task-module", m, "task-name", t] - if isinstance(task, TrackedInstance): - _, m, t, _ = extract_task_module(task) - return ["task-module", m, "task-name", t] + _, m, t, _ = extract_task_module(task) + return ["task-module", m, "task-name", t] def get_all_tasks(self) -> List[PythonAutoContainerTask]: # type: ignore raise Exception("should not be needed") diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index f1318941fa..e1e80a4227 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -13,7 +13,6 @@ """ - from abc import ABC from collections import OrderedDict from enum import Enum @@ -93,6 +92,7 @@ def my_func(a: int) -> str: class ExecutionBehavior(Enum): DEFAULT = 1 DYNAMIC = 2 + EAGER = 3 def __init__( self, @@ -155,6 +155,15 @@ def execution_mode(self) -> ExecutionBehavior: def task_function(self): return self._task_function + @property + def name(self) -> str: + """ + Returns the name of the task. + """ + if self.instantiated_in and self.instantiated_in not in self._name: + return f"{self.instantiated_in}.{self._name}" + return self._name + def execute(self, **kwargs) -> Any: """ This method will be invoked to execute the task. 
If you do decide to override this method you must also @@ -162,6 +171,11 @@ def execute(self, **kwargs) -> Any: """ if self.execution_mode == self.ExecutionBehavior.DEFAULT: return exception_scopes.user_entry_point(self._task_function)(**kwargs) + elif self.execution_mode == self.ExecutionBehavior.EAGER: + # if the task is a coroutine function, inject the context object so that the async_entity + # has access to the FlyteContext. + kwargs["async_ctx"] = FlyteContextManager.current_context() + return exception_scopes.user_entry_point(self._task_function)(**kwargs) elif self.execution_mode == self.ExecutionBehavior.DYNAMIC: return self.dynamic_execute(self._task_function, **kwargs) @@ -188,7 +202,12 @@ def compile_into_workflow( else: cs = ctx.compilation_state.with_params(prefix="d") - with FlyteContextManager.with_context(ctx.with_compilation_state(cs)): + updated_ctx = ctx.with_compilation_state(cs) + if self.execution_mode == self.ExecutionBehavior.DYNAMIC: + es = ctx.new_execution_state().with_params(mode=ExecutionState.Mode.DYNAMIC_TASK_EXECUTION) + updated_ctx = updated_ctx.with_execution_state(es) + + with FlyteContextManager.with_context(updated_ctx): # TODO: Resolve circular import from flytekit.tools.translator import get_serializable diff --git a/flytekit/core/resources.py b/flytekit/core/resources.py index 4cf2523f6a..62b880f6ed 100644 --- a/flytekit/core/resources.py +++ b/flytekit/core/resources.py @@ -32,6 +32,19 @@ class Resources(object): storage: Optional[str] = None ephemeral_storage: Optional[str] = None + def __post_init__(self): + def _check_none_or_str(value): + if value is None: + return + if not isinstance(value, str): + raise AssertionError(f"{value} should be a string") + + _check_none_or_str(self.cpu) + _check_none_or_str(self.mem) + _check_none_or_str(self.gpu) + _check_none_or_str(self.storage) + _check_none_or_str(self.ephemeral_storage) + @dataclass class ResourceSpec(object): diff --git a/flytekit/core/tracker.py b/flytekit/core/tracker.py index 1123c57d25..f2e6437bed 100644 --- a/flytekit/core/tracker.py +++ b/flytekit/core/tracker.py @@ -1,17 +1,30 @@ import importlib -import importlib as _importlib import inspect -import inspect as _inspect import os +import sys import typing +from pathlib import Path from types import ModuleType -from typing import Callable, Tuple, Union +from typing import Callable, Optional, Tuple, Union from flytekit.configuration.feature_flags import FeatureFlags from flytekit.exceptions import system as _system_exceptions from flytekit.loggers import logger +def import_module_from_file(module_name, file): + try: + spec = importlib.util.spec_from_file_location(module_name, file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + except AssertionError: + # handle where we can't determine the module of functions within the module + return importlib.import_module(module_name) + except Exception as exc: + raise ModuleNotFoundError(f"Module from file {file} cannot be loaded") from exc + + class InstanceTrackingMeta(type): """ Please see the original class :py:class`flytekit.common.mixins.registerable._InstanceTracker` also and also look @@ -22,18 +35,49 @@ class InstanceTrackingMeta(type): variable that the instance was assigned to. 
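As an aside on the ``Resources.__post_init__`` validation added above: non-string values are now rejected at construction time rather than at serialization. A minimal sketch:

.. code-block:: python

    from flytekit import Resources

    ok = Resources(cpu="1", mem="2Gi")   # fine: values are strings

    try:
        Resources(cpu=1)                 # integer value: rejected by __post_init__
    except AssertionError as e:
        print(e)                         # "1 should be a string"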
""" + @staticmethod + def _get_module_from_main(globals) -> Optional[str]: + curdir = Path.cwd() + file = globals.get("__file__") + if file is None: + return None + + file = Path(file) + try: + file_relative = file.relative_to(curdir) + except ValueError: + return None + + module_components = [*file_relative.with_suffix("").parts] + module_name = ".".join(module_components) + if len(module_components) == 0: + return None + + # make sure current directory is in the PYTHONPATH. + sys.path.insert(0, str(curdir)) + return import_module_from_file(module_name, file) + @staticmethod def _find_instance_module(): - frame = _inspect.currentframe() + frame = inspect.currentframe() while frame: if frame.f_code.co_name == "" and "__name__" in frame.f_globals: - return frame.f_globals["__name__"] + if frame.f_globals["__name__"] != "__main__": + return frame.f_globals["__name__"], frame.f_globals["__file__"] + # if the remote_deploy command is invoked in the same module as where + # the app is defined, get the module from the file name + mod = InstanceTrackingMeta._get_module_from_main(frame.f_globals) + if mod is None: + return None, None + return mod.__name__, mod.__file__ frame = frame.f_back - return None + return None, None def __call__(cls, *args, **kwargs): o = super(InstanceTrackingMeta, cls).__call__(*args, **kwargs) - o._instantiated_in = InstanceTrackingMeta._find_instance_module() + mod_name, mod_file = InstanceTrackingMeta._find_instance_module() + o._instantiated_in = mod_name + o._module_file = mod_file return o @@ -51,6 +95,7 @@ class TrackedInstance(metaclass=InstanceTrackingMeta): def __init__(self, *args, **kwargs): self._instantiated_in = None + self._module_file = None self._lhs = None super().__init__(*args, **kwargs) @@ -77,7 +122,7 @@ def find_lhs(self) -> str: raise _system_exceptions.FlyteSystemException(f"Object {self} does not have an _instantiated in") logger.debug(f"Looking for LHS for {self} from {self._instantiated_in}") - m = _importlib.import_module(self._instantiated_in) + m = importlib.import_module(self._instantiated_in) for k in dir(m): try: if getattr(m, k) is self: @@ -92,6 +137,28 @@ def find_lhs(self) -> str: # continue looping through m. 
logger.warning("Caught ValueError {} while attempting to auto-assign name".format(err)) + # try to find object in module when the tracked instance is defined in the __main__ module + module = import_module_from_file(self._instantiated_in, self._module_file) + + def _candidate_name_matches(candidate) -> bool: + if not hasattr(candidate, "name") or not hasattr(self, "name"): + return False + return candidate.name == self.name + + for k in dir(module): + try: + candidate = getattr(module, k) + # consider the variable equivalent to self if it's of the same type, name + if ( + type(candidate) == type(self) + and _candidate_name_matches(candidate) + and candidate.instantiated_in == self.instantiated_in + ): + self._lhs = k + return k + except ValueError as err: + logger.warning(f"Caught ValueError {err} while attempting to auto-assign name") + logger.error(f"Could not find LHS for {self} in {self._instantiated_in}") raise _system_exceptions.FlyteSystemException(f"Error looking for LHS in {self._instantiated_in}") @@ -216,6 +283,13 @@ def get_absolute_module_name(self, path: str, package_root: typing.Optional[str] _mod_sanitizer = _ModuleSanitizer() +def _task_module_from_callable(f: Callable): + mod = inspect.getmodule(f) + mod_name = getattr(mod, "__name__", f.__module__) + name = f.__name__.split(".")[-1] + return mod, mod_name, name + + def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str, str, str]: """ Returns the task-name, absolute module and the string name of the callable. @@ -224,21 +298,20 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str, """ if isinstance(f, TrackedInstance): - mod = importlib.import_module(f.instantiated_in) - mod_name = mod.__name__ - name = f.lhs - # We cannot get the sourcefile for an instance, so we replace it with the module - g = mod - inspect_file = inspect.getfile(g) + if hasattr(f, "task_function"): + mod, mod_name, name = _task_module_from_callable(f.task_function) + elif f.instantiated_in: + mod = importlib.import_module(f.instantiated_in) + mod_name = mod.__name__ + name = f.lhs else: - mod = inspect.getmodule(f) # type: ignore - if mod is None: - raise AssertionError(f"Unable to determine module of {f}") - mod_name = mod.__name__ - name = f.__name__.split(".")[-1] - inspect_file = inspect.getfile(f) + mod, mod_name, name = _task_module_from_callable(f) + + if mod is None: + raise AssertionError(f"Unable to determine module of {f}") if mod_name == "__main__": + inspect_file = inspect.getfile(f) # type: ignore return name, "", name, os.path.abspath(inspect_file) mod_name = get_full_module_path(mod, mod_name) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 8966564b28..9c48908f98 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -12,7 +12,7 @@ import typing from abc import ABC, abstractmethod from functools import lru_cache -from typing import Dict, NamedTuple, Optional, Type, cast +from typing import Dict, List, NamedTuple, Optional, Type, cast from dataclasses_json import DataClassJsonMixin, dataclass_json from google.protobuf import json_format as _json_format @@ -22,6 +22,7 @@ from google.protobuf.message import Message from google.protobuf.struct_pb2 import Struct from marshmallow_enum import EnumField, LoadDumpOptions +from mashumaro.mixins.json import DataClassJSONMixin from typing_extensions import Annotated, get_args, get_origin from flytekit.core.annotation import FlyteAnnotation @@ -53,6 +54,38 @@ T = typing.TypeVar("T") 
DEFINITIONS = "definitions" +TITLE = "title" + + +class BatchSize: + """ + This is used to annotate a FlyteDirectory when we want to download/upload the contents of the directory in batches. For example, + + @task + def t1(directory: Annotated[FlyteDirectory, BatchSize(10)]) -> Annotated[FlyteDirectory, BatchSize(100)]: + ... + return FlyteDirectory(...) + + In the above example flytekit will download all files from the input `directory` in chunks of 10, i.e. first it + downloads 10 files, loads them to memory, then writes those 10 to local disk, then it loads the next 10, so on + and so forth. Similarly, for outputs, in this case flytekit is going to upload the resulting directory in chunks of + 100. + """ + + def __init__(self, val: int): + self._val = val + + @property + def val(self) -> int: + return self._val + + +def get_batch_size(t: Type) -> Optional[int]: + if is_annotated(t): + for annotation in get_args(t)[1:]: + if isinstance(annotation, BatchSize): + return annotation.val + return None class TypeTransformerFailedError(TypeError, AssertionError, ValueError): @@ -220,8 +253,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: class DataclassTransformer(TypeTransformer[object]): """ - The Dataclass Transformer, provides a type transformer for arbitrary Python dataclasses, that have - @dataclass and @dataclass_json decorators. + The Dataclass Transformer provides a type transformer for dataclasses_json dataclasses. The Dataclass is converted to and from json and is transported between tasks using the proto.Structpb representation Also the type declaration will try to extract the JSON Schema for the object if possible and pass it with the @@ -233,9 +265,8 @@ class DataclassTransformer(TypeTransformer[object]): .. code-block:: python - @dataclass_json @dataclass - class Test(): + class Test(DataClassJsonMixin): a: int b: str @@ -270,9 +301,8 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T): if type(v) == expected_type: return - # @dataclass_json # @dataclass - # class Foo(object): + # class Foo(DataClassJsonMixin): # a: int = 0 # # @task @@ -316,21 +346,28 @@ def get_literal_type(self, t: Type[T]) -> LiteralType: f"Type {t} cannot be parsed." ) - if not issubclass(t, DataClassJsonMixin): + if not issubclass(t, DataClassJsonMixin) and not issubclass(t, DataClassJSONMixin): raise AssertionError( - f"Dataclass {t} should be decorated with @dataclass_json to be " f"serialized correctly" + f"Dataclass {t} should be decorated with @dataclass_json or mixin with DataClassJSONMixin to be " + f"serialized correctly" ) schema = None try: - s = cast(DataClassJsonMixin, self._get_origin_type_in_annotation(t)).schema() - for _, v in s.fields.items(): - # marshmallow-jsonschema only supports enums loaded by name. - # https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228 - if isinstance(v, EnumField): - v.load_by = LoadDumpOptions.name - from marshmallow_jsonschema import JSONSchema - - schema = JSONSchema().dump(s) + if issubclass(t, DataClassJsonMixin): + s = cast(DataClassJsonMixin, self._get_origin_type_in_annotation(t)).schema() + for _, v in s.fields.items(): + # marshmallow-jsonschema only supports enums loaded by name. 
+ # https://github.com/fuhrysteve/marshmallow-jsonschema/blob/81eada1a0c42ff67de216923968af0a6b54e5dcb/marshmallow_jsonschema/base.py#L228 + if isinstance(v, EnumField): + v.load_by = LoadDumpOptions.name + # check if DataClass mixin + from marshmallow_jsonschema import JSONSchema + + schema = JSONSchema().dump(s) + else: # DataClassJSONMixin + from mashumaro.jsonschema import build_json_schema + + schema = build_json_schema(cast(DataClassJSONMixin, self._get_origin_type_in_annotation(t))).to_dict() except Exception as e: # https://github.com/lovasoa/marshmallow_dataclass/issues/13 logger.warning( @@ -347,14 +384,18 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp f"{type(python_val)} is not of type @dataclass, only Dataclasses are supported for " f"user defined datatypes in Flytekit" ) - if not issubclass(type(python_val), DataClassJsonMixin): + if not issubclass(type(python_val), DataClassJsonMixin) and not issubclass( + type(python_val), DataClassJSONMixin + ): raise TypeTransformerFailedError( - f"Dataclass {python_type} should be decorated with @dataclass_json to be " f"serialized correctly" + f"Dataclass {python_type} should be decorated with @dataclass_json or inherit DataClassJSONMixin to be " + f"serialized correctly" ) self._serialize_flyte_type(python_val, python_type) - return Literal( - scalar=Scalar(generic=_json_format.Parse(cast(DataClassJsonMixin, python_val).to_json(), _struct.Struct())) - ) + + json_str = python_val.to_json() # type: ignore + + return Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) # type: ignore def _get_origin_type_in_annotation(self, python_type: Type[T]) -> Type[T]: # dataclass will try to hash python type when calling dataclass.schema(), but some types in the annotation is @@ -429,10 +470,10 @@ def _serialize_flyte_type(self, python_val: T, python_type: Type[T]) -> typing.A or issubclass(python_type, StructuredDataset) ): lv = TypeEngine.to_literal(FlyteContext.current_context(), python_val, python_type, None) - # dataclass_json package will extract the "path" from FlyteFile, FlyteDirectory, and write it to a + # dataclasses_json package will extract the "path" from FlyteFile, FlyteDirectory, and write it to a # JSON which will be stored in IDL. The path here should always be a remote path, but sometimes the # path in FlyteFile and FlyteDirectory could be a local path. Therefore, reset the python value here, - # so that dataclass_json can always get a remote path. + # so that dataclasses_json can always get a remote path. # In other words, the file transformer has special code that handles the fact that if remote_source is # set, then the real uri in the literal should be the remote source, not the path (which may be an # auto-generated random local path). 
To be sure we're writing the right path to the json, use the uri @@ -596,15 +637,18 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: if not dataclasses.is_dataclass(expected_python_type): raise TypeTransformerFailedError( f"{expected_python_type} is not of type @dataclass, only Dataclasses are supported for " - f"user defined datatypes in Flytekit" + "user defined datatypes in Flytekit" ) - if not issubclass(expected_python_type, DataClassJsonMixin): + if not issubclass(expected_python_type, DataClassJsonMixin) and not issubclass( + expected_python_type, DataClassJSONMixin + ): raise TypeTransformerFailedError( - f"Dataclass {expected_python_type} should be decorated with @dataclass_json to be " + f"Dataclass {expected_python_type} should be decorated with @dataclass_json or mixin with DataClassJSONMixin to be " f"serialized correctly" ) json_str = _json_format.MessageToJson(lv.scalar.generic) - dc = cast(DataClassJsonMixin, expected_python_type).from_json(json_str) + dc = expected_python_type.from_json(json_str) # type: ignore + dc = self._fix_structured_dataset_type(expected_python_type, dc) return self._fix_dataclass_int(expected_python_type, self._deserialize_flyte_type(dc, expected_python_type)) @@ -615,10 +659,15 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: @lru_cache(typed=True) def guess_python_type(self, literal_type: LiteralType) -> Type[T]: # type: ignore if literal_type.simple == SimpleType.STRUCT: - if literal_type.metadata is not None and DEFINITIONS in literal_type.metadata: - schema_name = literal_type.metadata["$ref"].split("/")[-1] - return convert_json_schema_to_python_class(literal_type.metadata[DEFINITIONS], schema_name) - + if literal_type.metadata is not None: + if DEFINITIONS in literal_type.metadata: + schema_name = literal_type.metadata["$ref"].split("/")[-1] + return convert_marshmallow_json_schema_to_python_class( + literal_type.metadata[DEFINITIONS], schema_name + ) + elif TITLE in literal_type.metadata: + schema_name = literal_type.metadata[TITLE] + return convert_mashumaro_json_schema_to_python_class(literal_type.metadata, schema_name) raise ValueError(f"Dataclass transformer cannot reverse {literal_type}") @@ -744,7 +793,15 @@ def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: python_type = args[0] - if python_type in cls._REGISTRY: + # this makes sure that if it's a list/dict of annotated types, we hit the unwrapping code in step 2 + # see test_list_of_annotated in test_structured_dataset.py + if ( + (not hasattr(python_type, "__origin__")) + or ( + hasattr(python_type, "__origin__") + and (python_type.__origin__ is not list and python_type.__origin__ is not dict) + ) + ) and python_type in cls._REGISTRY: return cls._REGISTRY[python_type] # Step 2 @@ -862,7 +919,7 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type "actual attribute that you want to use. 
For example, in NamedTuple('OP', x=int) then" "return v.x, instead of v, even if this has a single element" ) - if python_val is None and expected.union_type is None: + if python_val is None and expected and expected.union_type is None: raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}") transformer = cls.get_transformer(python_type) if transformer.type_assertions_enabled: @@ -1519,39 +1576,97 @@ def to_literal( def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T: return expected_python_type(lv.scalar.primitive.string_value) # type: ignore + def guess_python_type(self, literal_type: LiteralType) -> Type[enum.Enum]: + if literal_type.enum_type: + return enum.Enum("DynamicEnum", {f"{i}": i for i in literal_type.enum_type.values}) # type: ignore + raise ValueError(f"Enum transformer cannot reverse {literal_type}") -def convert_json_schema_to_python_class(schema: dict, schema_name) -> Type[dataclasses.dataclass()]: # type: ignore - """ - Generate a model class based on the provided JSON Schema - :param schema: dict representing valid JSON schema - :param schema_name: dataclass name of return type - """ + +def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: typing.Any): + attribute_list = [] + for property_key, property_val in schema["properties"].items(): + if property_val.get("anyOf"): + property_type = property_val["anyOf"][0]["type"] + elif property_val.get("enum"): + property_type = "enum" + else: + property_type = property_val["type"] + # Handle list + if property_type == "array": + attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) # type: ignore + # Handle dataclass and dict + elif property_type == "object": + if property_val.get("anyOf"): + sub_schemea = property_val["anyOf"][0] + sub_schemea_name = sub_schemea["title"] + attribute_list.append( + (property_key, convert_mashumaro_json_schema_to_python_class(sub_schemea, sub_schemea_name)) + ) + elif property_val.get("additionalProperties"): + attribute_list.append( + (property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore + ) + else: + sub_schemea_name = property_val["title"] + attribute_list.append( + (property_key, convert_mashumaro_json_schema_to_python_class(property_val, sub_schemea_name)) + ) + elif property_type == "enum": + attribute_list.append([property_key, str]) # type: ignore + # Handle int, float, bool or str + else: + attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore + return attribute_list + + +def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typing.Any): attribute_list = [] for property_key, property_val in schema[schema_name]["properties"].items(): property_type = property_val["type"] # Handle list if property_val["type"] == "array": - attribute_list.append((property_key, typing.List[_get_element_type(property_val["items"])])) # type: ignore + attribute_list.append((property_key, List[_get_element_type(property_val["items"])])) # type: ignore[misc,index] # Handle dataclass and dict elif property_type == "object": if property_val.get("$ref"): name = property_val["$ref"].split("/")[-1] - attribute_list.append((property_key, convert_json_schema_to_python_class(schema, name))) + attribute_list.append((property_key, convert_marshmallow_json_schema_to_python_class(schema, name))) elif property_val.get("additionalProperties"): attribute_list.append( - 
(property_key, typing.Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore + (property_key, Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore[misc,index] ) else: - attribute_list.append((property_key, typing.Dict[str, _get_element_type(property_val)])) # type: ignore + attribute_list.append((property_key, Dict[str, _get_element_type(property_val)])) # type: ignore[misc,index] # Handle int, float, bool or str else: attribute_list.append([property_key, _get_element_type(property_val)]) # type: ignore + return attribute_list + + +def convert_marshmallow_json_schema_to_python_class(schema: dict, schema_name: typing.Any) -> Type[dataclasses.dataclass()]: # type: ignore + """ + Generate a model class based on the provided JSON Schema + :param schema: dict representing valid JSON schema + :param schema_name: dataclass name of return type + """ + attribute_list = generate_attribute_list_from_dataclass_json(schema, schema_name) + return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) + + +def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typing.Any) -> Type[dataclasses.dataclass()]: # type: ignore + """ + Generate a model class based on the provided JSON Schema + :param schema: dict representing valid JSON schema + :param schema_name: dataclass name of return type + """ + + attribute_list = generate_attribute_list_from_dataclass_json_mixin(schema, schema_name) return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list)) def _get_element_type(element_property: typing.Dict[str, str]) -> Type: - element_type = element_property["type"] + element_type = [e_property["type"] for e_property in element_property["anyOf"]] if element_property.get("anyOf") else element_property["type"] # type: ignore element_format = element_property["format"] if "format" in element_property else None if type(element_type) == list: @@ -1663,6 +1778,18 @@ def _register_default_type_transformers(): ) ) + TypeEngine.register( + SimpleTransformer( + "date", + _datetime.date, + _type_models.LiteralType(simple=_type_models.SimpleType.DATETIME), + lambda x: Literal( + scalar=Scalar(primitive=Primitive(datetime=_datetime.datetime.combine(x, _datetime.time.min))) + ), # convert datetime to date + lambda x: x.scalar.primitive.datetime.date(), # get date from datetime + ) + ) + TypeEngine.register( SimpleTransformer( "none", diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 7ced5940fc..e1416d4a7f 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -1,10 +1,12 @@ from __future__ import annotations +import asyncio +import inspect import typing from dataclasses import dataclass from enum import Enum from functools import update_wrapper -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast, overload +from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union, cast, overload from flytekit.core import constants as _common_constants from flytekit.core.base_task import PythonTask @@ -262,7 +264,7 @@ def construct_node_metadata(self) -> _workflow_model.NodeMetadata: interruptible=self.workflow_metadata_defaults.interruptible, ) - def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, None]: + def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, Tuple, Coroutine, None]: """ Workflow needs to fill in default arguments before invoking the call handler. 
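Circling back to the ``date`` transformer registered above: ``datetime.date`` values round-trip through Flyte's datetime literal (midnight on the way in, ``.date()`` on the way out). A minimal sketch:

.. code-block:: python

    import datetime
    from flytekit import task, workflow

    @task
    def next_day(d: datetime.date) -> datetime.date:
        return d + datetime.timedelta(days=1)

    @workflow
    def wf() -> datetime.date:
        # the date is carried as a datetime at midnight and converted back on read
        return next_day(d=datetime.date(2023, 1, 1))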
""" @@ -294,6 +296,11 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr kwargs_literals = {k: Promise(var=k, val=v) for k, v in literal_map.items()} self.compile() function_outputs = self.execute(**kwargs_literals) + + if inspect.iscoroutine(function_outputs): + # handle coroutines for eager workflows + function_outputs = asyncio.run(function_outputs) + # First handle the empty return case. # A workflow function may return a task that doesn't return anything # def wf(): @@ -787,7 +794,7 @@ def workflow( your typical Python values. So even though you may have a task ``t1() -> int``, when ``a = t1()`` is called, ``a`` will not be an integer so if you try to ``range(a)`` you'll get an error. - Please see the :std:doc:`user guide ` for more usage examples. + Please see the :ref:`user guide ` for more usage examples. :param _workflow_function: This argument is implicitly passed and represents the decorated function. :param failure_policy: Use the options in flytekit.WorkflowFailurePolicy diff --git a/flytekit/deck/deck.py b/flytekit/deck/deck.py index dd8ef59163..3ce9d058a4 100644 --- a/flytekit/deck/deck.py +++ b/flytekit/deck/deck.py @@ -50,7 +50,6 @@ def t1() -> str: @task() def t2() -> Annotated[pd.DataFrame, TopFrameRenderer(10)]: return iris_df - """ def __init__(self, name: str, html: Optional[str] = ""): diff --git a/flytekit/exceptions/system.py b/flytekit/exceptions/system.py index 63c43e8879..63fe55f0b9 100644 --- a/flytekit/exceptions/system.py +++ b/flytekit/exceptions/system.py @@ -37,3 +37,7 @@ def __init__(self, task_module, task_name=None, additional_msg=None): class FlyteSystemAssertion(FlyteSystemException, AssertionError): _ERROR_CODE = "SYSTEM:AssertionError" + + +class FlyteAgentNotFound(FlyteSystemException, AssertionError): + _ERROR_CODE = "SYSTEM:AgentNotFound" diff --git a/flytekit/experimental/__init__.py b/flytekit/experimental/__init__.py new file mode 100644 index 0000000000..2780211c8f --- /dev/null +++ b/flytekit/experimental/__init__.py @@ -0,0 +1,4 @@ +"""Experimental features of flytekit.""" + +from flytekit.core.array_node_map_task import map_task # noqa: F401 +from flytekit.experimental.eager_function import EagerException, eager diff --git a/flytekit/experimental/eager_function.py b/flytekit/experimental/eager_function.py new file mode 100644 index 0000000000..264d0d641a --- /dev/null +++ b/flytekit/experimental/eager_function.py @@ -0,0 +1,624 @@ +import asyncio +import inspect +import signal +from contextlib import asynccontextmanager +from datetime import datetime, timedelta +from functools import partial, wraps +from typing import List, Optional + +from flytekit import Deck, Secret, current_context +from flytekit.configuration import DataConfig, PlatformConfig, S3Config +from flytekit.core.base_task import PythonTask +from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager +from flytekit.core.python_function_task import PythonFunctionTask +from flytekit.core.task import task +from flytekit.core.workflow import WorkflowBase +from flytekit.loggers import logger +from flytekit.models.core.execution import WorkflowExecutionPhase +from flytekit.remote import FlyteRemote + +FLYTE_SANDBOX_INTERNAL_ENDPOINT = "flyte-sandbox-grpc.flyte:8089" +FLYTE_SANDBOX_MINIO_ENDPOINT = "http://flyte-sandbox-minio.flyte:9000" + +NODE_HTML_TEMPLATE = """ + + + + +

+<h3>{entity_type}: {entity_name}</h3>
+
+<p>
+  <strong>Execution:</strong>
+  <a target="_blank" href="{url}">{execution_name}</a>
+</p>
+
+<details>
+<summary>Inputs</summary>
+<pre>{inputs}</pre>
+</details>
+
+<details>
+<summary>Outputs</summary>
+<pre>{outputs}</pre>
+</details>
+""" + + +class EagerException(Exception): + """Raised when a node in an eager workflow encounters an error. + + This exception should be used in an :py:func:`@eager ` workflow function to + catch exceptions that are raised by tasks or subworkflows. + + .. code-block:: python + + from flytekit import task + from flytekit.experimental import eager, EagerException + + @task + def add_one(x: int) -> int: + if x < 0: + raise ValueError("x must be positive") + return x + 1 + + @task + def double(x: int) -> int: + return x * 2 + + @eager + async def eager_workflow(x: int) -> int: + try: + out = await add_one(x=x) + except EagerException: + # The ValueError error is caught + # and raised as an EagerException + raise + return await double(x=out) + """ + + +class AsyncEntity: + """A wrapper around a Flyte entity (task, workflow, launch plan) that allows it to be executed asynchronously.""" + + def __init__( + self, + entity, + remote: Optional[FlyteRemote], + ctx: FlyteContext, + async_stack: "AsyncStack", + timeout: Optional[timedelta] = None, + poll_interval: Optional[timedelta] = None, + local_entrypoint: bool = False, + ): + self.entity = entity + self.ctx = ctx + self.async_stack = async_stack + self.execution_state = self.ctx.execution_state.mode + self.remote = remote + self.local_entrypoint = local_entrypoint + if self.remote is not None: + logger.debug(f"Using remote config: {self.remote.config}") + else: + logger.debug("Not using remote, executing locally") + self._timeout = timeout + self._poll_interval = poll_interval + self._execution = None + + async def __call__(self, **kwargs): + logger.debug(f"Calling {self.entity}: {self.entity.name}") + + # ensure async context is provided + if "async_ctx" in kwargs: + kwargs.pop("async_ctx") + + if getattr(self.entity, "execution_mode", None) == PythonFunctionTask.ExecutionBehavior.DYNAMIC: + raise EagerException( + "Eager workflows currently do not work with dynamic workflows. " + "If you need to use a subworkflow, use a static @workflow or nested @eager workflow." 
+ ) + + if not self.local_entrypoint and self.ctx.execution_state.is_local_execution(): + # If running as a local workflow execution, just execute the python function + try: + if isinstance(self.entity, WorkflowBase): + out = self.entity._workflow_function(**kwargs) + if inspect.iscoroutine(out): + # need to handle invocation of AsyncEntity tasks within the workflow + out = await out + return out + elif isinstance(self.entity, PythonTask): + # invoke the task-decorated entity + out = self.entity(**kwargs) + if inspect.iscoroutine(out): + out = await out + return out + else: + raise ValueError(f"Entity type {type(self.entity)} not supported for local execution") + except Exception as exc: + raise EagerException( + f"Error executing {type(self.entity)} {self.entity.name} with {type(exc)}: {exc}" + ) from exc + + # this is a hack to handle the case when the task.name doesn't contain the fully + # qualified module name + entity_name = ( + f"{self.entity._instantiated_in}.{self.entity.name}" + if self.entity._instantiated_in not in self.entity.name + else self.entity.name + ) + + if isinstance(self.entity, WorkflowBase): + remote_entity = self.remote.fetch_workflow(name=entity_name) + elif isinstance(self.entity, PythonTask): + remote_entity = self.remote.fetch_task(name=entity_name) + else: + raise ValueError(f"Entity type {type(self.entity)} not supported for local execution") + + execution = self.remote.execute(remote_entity, inputs=kwargs, type_hints=self.entity.python_interface.inputs) + self._execution = execution + + url = self.remote.generate_console_url(execution) + msg = f"Running flyte {type(self.entity)} {entity_name} on remote cluster: {url}" + if self.local_entrypoint: + logger.info(msg) + else: + logger.debug(msg) + + node = AsyncNode(self, entity_name, execution, url) + self.async_stack.set_node(node) + + poll_interval = self._poll_interval or timedelta(seconds=30) + time_to_give_up = datetime.max if self._timeout is None else datetime.utcnow() + self._timeout + + while datetime.utcnow() < time_to_give_up: + execution = self.remote.sync(execution) + if execution.closure.phase in {WorkflowExecutionPhase.FAILED}: + raise EagerException(f"Error executing {self.entity.name} with error: {execution.closure.error}") + elif execution.is_done: + break + await asyncio.sleep(poll_interval.total_seconds()) + + outputs = {} + for key, type_ in self.entity.python_interface.outputs.items(): + outputs[key] = execution.outputs.get(key, as_type=type_) + + if len(outputs) == 1: + out, *_ = outputs.values() + return out + return outputs + + async def terminate(self): + execution = self.remote.sync(self._execution) + logger.debug(f"Cleaning up execution: {execution}") + if not execution.is_done: + self.remote.terminate( + execution, + f"Execution terminated by eager workflow execution {self.async_stack.parent_execution_id}.", + ) + + poll_interval = self._poll_interval or timedelta(seconds=6) + time_to_give_up = datetime.max if self._timeout is None else datetime.utcnow() + self._timeout + + while datetime.utcnow() < time_to_give_up: + execution = self.remote.sync(execution) + if execution.is_done: + break + await asyncio.sleep(poll_interval.total_seconds()) + + return True + + +class AsyncNode: + """A node in the async callstack.""" + + def __init__(self, async_entity, entity_name, execution=None, url=None): + self.entity_name = entity_name + self.async_entity = async_entity + self.execution = execution + self._url = url + + @property + def url(self) -> str: + # make sure that internal flyte 
sandbox endpoint is replaced with localhost endpoint when rendering the urls + # for flyte decks + endpoint_root = FLYTE_SANDBOX_INTERNAL_ENDPOINT.replace("http://", "") + if endpoint_root in self._url: + return self._url.replace(endpoint_root, "localhost:30080") + return self._url + + @property + def entity_type(self) -> str: + if ( + isinstance(self.async_entity.entity, PythonTask) + and getattr(self.async_entity.entity, "execution_mode", None) == PythonFunctionTask.ExecutionBehavior.EAGER + ): + return "Eager Workflow" + elif isinstance(self.async_entity.entity, PythonTask): + return "Task" + elif isinstance(self.async_entity.entity, WorkflowBase): + return "Workflow" + return str(type(self.async_entity.entity)) + + def __repr__(self): + ex_id = self.execution.id + execution_id = None if self.execution is None else f"{ex_id.project}:{ex_id.domain}:{ex_id.name}" + return ( + "" + + @property + def call_stack(self) -> List[AsyncNode]: + return self._call_stack + + def set_node(self, node: AsyncNode): + self._call_stack.append(node) + + +async def render_deck(async_stack): + """Render the callstack as a deck presentation to be shown after eager workflow execution.""" + + def get_io(dict_like): + try: + return {k: dict_like.get(k) for k in dict_like} + except Exception: + return dict_like + + output = "

<h2>Nodes</h2><hr>
" + for node in async_stack.call_stack: + node_inputs = get_io(node.execution.inputs) + if node.execution.closure.phase in {WorkflowExecutionPhase.FAILED}: + node_outputs = None + else: + node_outputs = get_io(node.execution.outputs) + + output = f"{output}\n" + NODE_HTML_TEMPLATE.format( + entity_type=node.entity_type, + entity_name=node.entity_name, + execution_name=node.execution.id.name, + url=node.url, + inputs=node_inputs, + outputs=node_outputs, + ) + + Deck("eager workflow", output) + + +@asynccontextmanager +async def eager_context( + fn, + remote: Optional[FlyteRemote], + ctx: FlyteContext, + async_stack: AsyncStack, + timeout: Optional[timedelta] = None, + poll_interval: Optional[timedelta] = None, + local_entrypoint: bool = False, +): + """This context manager overrides all tasks in the global namespace with async versions.""" + + _original_cache = {} + + # override tasks with async version + for k, v in fn.__globals__.items(): + if isinstance(v, (PythonTask, WorkflowBase)): + _original_cache[k] = v + fn.__globals__[k] = AsyncEntity(v, remote, ctx, async_stack, timeout, poll_interval, local_entrypoint) + + try: + yield + finally: + # restore old tasks + for k, v in _original_cache.items(): + fn.__globals__[k] = v + + +async def node_cleanup_async(sig, loop, async_stack: AsyncStack): + """Clean up subtasks when eager workflow parent is done. + + This applies either if the eager workflow completes successfully, fails, or is cancelled by the user. + """ + logger.debug(f"Cleaning up async nodes on signal: {sig}") + terminations = [] + for node in async_stack.call_stack: + terminations.append(node.async_entity.terminate()) + results = await asyncio.gather(*terminations) + logger.debug(f"Successfully terminated subtasks {results}") + + +def node_cleanup(sig, frame, loop, async_stack: AsyncStack): + """Clean up subtasks when eager workflow parent is done. + + This applies either if the eager workflow completes successfully, fails, or is cancelled by the user. + """ + logger.debug(f"Cleaning up async nodes on signal: {sig}") + terminations = [] + for node in async_stack.call_stack: + terminations.append(node.async_entity.terminate()) + results = asyncio.gather(*terminations) + results = asyncio.run(results) + logger.debug(f"Successfully terminated subtasks {results}") + loop.close() + + +def eager( + _fn=None, + *, + remote: Optional[FlyteRemote] = None, + client_secret_group: Optional[str] = None, + client_secret_key: Optional[str] = None, + timeout: Optional[timedelta] = None, + poll_interval: Optional[timedelta] = None, + local_entrypoint: bool = False, + **kwargs, +): + """Eager workflow decorator. + + :param remote: A :py:class:`~flytekit.remote.FlyteRemote` object to use for executing Flyte entities. + :param client_secret_group: The client secret group to use for this workflow. + :param client_secret_key: The client secret key to use for this workflow. + :param timeout: The timeout duration specifying how long to wait for a task/workflow execution within the eager + workflow to complete or terminate. By default, the eager workflow will wait indefinitely until complete. + :param poll_interval: The poll interval for checking if a task/workflow execution within the eager workflow has + finished. If not specified, the default poll interval is 6 seconds. + :param local_entrypoint: If True, the eager workflow will can be executed locally but use the provided + :py:func:`~flytekit.remote.FlyteRemote` object to create task/workflow executions. 
This is useful for local + testing against a remote Flyte cluster. + :param kwargs: keyword-arguments forwarded to :py:func:`~flytekit.task`. + + This type of workflow will execute all flyte entities within it eagerly, meaning that all python constructs can be + used inside of an ``@eager``-decorated function. This is because eager workflows use a + :py:class:`~flytekit.remote.remote.FlyteRemote` object to kick off executions when a flyte entity needs to produce a + value. + + For example: + + .. code-block:: python + + from flytekit import task + from flytekit.experimental import eager + + @task + def add_one(x: int) -> int: + return x + 1 + + @task + def double(x: int) -> int: + return x * 2 + + @eager + async def eager_workflow(x: int) -> int: + out = await add_one(x=x) + return await double(x=out) + + # run locally with asyncio + if __name__ == "__main__": + import asyncio + + result = asyncio.run(eager_workflow(x=1)) + print(f"Result: {result}") # "Result: 4" + + Unlike :py:func:`dynamic workflows `, eager workflows are not compiled into a workflow spec, but + uses python's `async `__ capabilities to execute flyte entities. + + .. note:: + + Eager workflows only support `@task`, `@workflow`, and `@eager` entities. Dynamic workflows and launchplans are + currently not supported. + + Note that for the ``@eager`` function is an ``async`` function. Under the hood, tasks and workflows called inside + an ``@eager`` workflow are executed asynchronously. This means that task and workflow calls will return an awaitable, + which need to be awaited. + + .. important:: + + A ``client_secret_group`` and ``client_secret_key`` is needed for authenticating via + :py:class:`~flytekit.remote.remote.FlyteRemote` using the ``client_credentials`` authentication, which is + configured via :py:class:`~flytekit.configuration.PlatformConfig`. + + .. code-block:: python + + from flytekit.remote import FlyteRemote + from flytekit.configuration import Config + + @eager( + remote=FlyteRemote(config=Config.auto(config_file="config.yaml")), + client_secret_group="my_client_secret_group", + client_secret_key="my_client_secret_key", + ) + async def eager_workflow(x: int) -> int: + out = await add_one(x) + return await double(one) + + Where ``config.yaml`` contains is a flytectl-compatible config file. + For more details, see `here `__. + + When using a sandbox cluster started with ``flytectl demo start``, however, the ``client_secret_group`` + and ``client_secret_key`` are not needed, : + + .. code-block:: python + + @eager(remote=FlyteRemote(config=Config.for_sandbox())) + async def eager_workflow(x: int) -> int: + ... + + .. important:: + + When using ``local_entrypoint=True`` you also need to specify the ``remote`` argument. In this case, the eager + workflow runtime will be local, but all task/subworkflow invocations will occur on the specified Flyte cluster. + This argument is primarily used for testing and debugging eager workflow logic locally. 
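+
+    For illustration, a minimal ``local_entrypoint`` sketch (the config file name is a
+    placeholder, and ``add_one``/``double`` are the example tasks defined above):
+
+    .. code-block:: python
+
+        @eager(
+            remote=FlyteRemote(config=Config.auto(config_file="config.yaml")),
+            local_entrypoint=True,
+        )
+        async def eager_workflow(x: int) -> int:
+            out = await add_one(x=x)
+            return await double(x=out)
+
+        # The eager workflow body runs locally, while add_one and double are
+        # executed on the configured remote Flyte cluster.
+        asyncio.run(eager_workflow(x=1))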
+ + """ + + if _fn is None: + return partial( + eager, + remote=remote, + client_secret_group=client_secret_group, + client_secret_key=client_secret_key, + local_entrypoint=local_entrypoint, + **kwargs, + ) + + if local_entrypoint and remote is None: + raise ValueError("Must specify remote argument if local_entrypoint is True") + + @wraps(_fn) + async def wrapper(*args, **kws): + # grab the "async_ctx" argument injected by PythonFunctionTask.execute + logger.debug("Starting") + _remote = remote + + # locally executed nested eager workflows won't have async_ctx injected into the **kws input + ctx = kws.pop("async_ctx", None) + task_id, execution_id = None, None + if ctx: + exec_params = ctx.user_space_params + task_id = exec_params.task_id + execution_id = exec_params.execution_id + + async_stack = AsyncStack(task_id, execution_id) + _remote = _prepare_remote(_remote, ctx, client_secret_group, client_secret_key, local_entrypoint) + + # make sure sub-nodes as cleaned up on termination signal + loop = asyncio.get_event_loop() + node_cleanup_partial = partial(node_cleanup_async, async_stack=async_stack) + cleanup_fn = partial(asyncio.ensure_future, node_cleanup_partial(signal.SIGTERM, loop)) + signal.signal(signal.SIGTERM, partial(node_cleanup, loop=loop, async_stack=async_stack)) + + async with eager_context(_fn, _remote, ctx, async_stack, timeout, poll_interval, local_entrypoint): + try: + if _remote is not None: + with _remote.remote_context(): + out = await _fn(*args, **kws) + else: + out = await _fn(*args, **kws) + # need to await for _fn to complete, then invoke the deck + await render_deck(async_stack) + return out + finally: + # in case the cleanup function hasn't been called yet, call it at the end of the eager workflow + await cleanup_fn() + + secret_requests = kwargs.pop("secret_requests", None) or [] + if client_secret_group is not None and client_secret_key is not None: + secret_requests.append(Secret(group=client_secret_group, key=client_secret_key)) + + return task( + wrapper, + secret_requests=secret_requests, + disable_deck=False, + execution_mode=PythonFunctionTask.ExecutionBehavior.EAGER, + **kwargs, + ) + + +def _prepare_remote( + remote: Optional[FlyteRemote], + ctx: FlyteContext, + client_secret_group: Optional[str] = None, + client_secret_key: Optional[str] = None, + local_entrypoint: bool = False, +) -> Optional[FlyteRemote]: + """Prepare FlyteRemote object for accessing Flyte cluster in a task running on the same cluster.""" + + is_local_execution_mode = ctx.execution_state.mode in { + ExecutionState.Mode.LOCAL_TASK_EXECUTION, + ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION, + } + + if remote is not None and local_entrypoint and is_local_execution_mode: + # when running eager workflows as a local entrypoint, we don't have to modify the remote object + # because we can assume that the user is running this from their local machine and can do browser-based + # authentication. + logger.info("Running eager workflow as local entrypoint") + return remote + + if remote is None or is_local_execution_mode: + # if running the "eager workflow" (which is actually task) locally, run the task as a function, + # which doesn't need a remote object + return None + + # Handle the case where this the task is running in a Flyte cluster and needs to access the cluster itself + # via FlyteRemote. 
+ if remote.config.platform.endpoint.startswith("localhost"): + # replace sandbox endpoints with internal dns, since localhost won't exist within the Flyte cluster + return _internal_demo_remote(remote) + return _internal_remote(remote, client_secret_group, client_secret_key) + + +def _internal_demo_remote(remote: FlyteRemote) -> FlyteRemote: + """Derives a FlyteRemote object from a sandbox yaml configuration, modifying parts to make it work internally.""" + # replace sandbox endpoints with internal dns, since localhost won't exist within the Flyte cluster + return FlyteRemote( + config=remote.config.with_params( + platform=PlatformConfig( + endpoint=FLYTE_SANDBOX_INTERNAL_ENDPOINT, + insecure=True, + auth_mode="Pkce", + client_id=remote.config.platform.client_id, + ), + data_config=DataConfig( + s3=S3Config( + endpoint=FLYTE_SANDBOX_MINIO_ENDPOINT, + access_key_id=remote.config.data_config.s3.access_key_id, + secret_access_key=remote.config.data_config.s3.secret_access_key, + ), + ), + ), + default_domain=remote.default_domain, + default_project=remote.default_project, + ) + + +def _internal_remote( + remote: FlyteRemote, + client_secret_group: str, + client_secret_key: str, +) -> FlyteRemote: + """Derives a FlyteRemote object from a yaml configuration file, modifying parts to make it work internally.""" + assert client_secret_group is not None, "secret_group must be defined when using a remote cluster" + assert client_secret_key is not None, "secret_key must be defined a remote cluster" + secrets_manager = current_context().secrets + client_secret = secrets_manager.get(client_secret_group, client_secret_key) + # get the raw output prefix from the context that's set from the pyflyte-execute entrypoint + # (see flytekit/bin/entrypoint.py) + ctx = FlyteContextManager.current_context() + return FlyteRemote( + config=remote.config.with_params( + platform=PlatformConfig( + endpoint=remote.config.platform.endpoint, + insecure=remote.config.platform.insecure, + auth_mode="client_credentials", + client_id=remote.config.platform.client_id, + client_credentials_secret=remote.config.platform.client_credentials_secret or client_secret, + ), + ), + default_domain=remote.default_domain, + default_project=remote.default_project, + data_upload_location=ctx.file_access.raw_output_prefix, + ) diff --git a/flytekit/extend/__init__.py b/flytekit/extend/__init__.py index 7223d13523..07e92e4c24 100644 --- a/flytekit/extend/__init__.py +++ b/flytekit/extend/__init__.py @@ -1,7 +1,7 @@ """ -===================== +================== Extending Flytekit -===================== +================== .. 
currentmodule:: flytekit.extend @@ -12,13 +12,10 @@ get_serializable context_manager - SQLTask IgnoreOutputs - PythonTask ExecutionState Image ImageConfig - SerializationSettings Interface Promise TaskPlugins diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index 55f71959fe..73737f3a6c 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -1,54 +1,124 @@ +import asyncio +import typing + import grpc from flyteidl.admin.agent_pb2 import ( - PERMANENT_FAILURE, CreateTaskRequest, CreateTaskResponse, DeleteTaskRequest, DeleteTaskResponse, GetTaskRequest, GetTaskResponse, - Resource, ) from flyteidl.service.agent_pb2_grpc import AsyncAgentServiceServicer +from prometheus_client import Counter, Summary from flytekit import logger +from flytekit.exceptions.system import FlyteAgentNotFound from flytekit.extend.backend.base_agent import AgentRegistry from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate +metric_prefix = "flyte_agent_" +create_operation = "create" +get_operation = "get" +delete_operation = "delete" -class AgentService(AsyncAgentServiceServicer): - def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse: - try: - tmp = TaskTemplate.from_flyte_idl(request.template) - inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None - agent = AgentRegistry.get_agent(context, tmp.type) - if agent is None: - return CreateTaskResponse() - return agent.create(context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp) - except Exception as e: - logger.error(f"failed to create task with error {e}") - context.set_code(grpc.StatusCode.INTERNAL) - context.set_details(f"failed to create task with error {e}") +# Follow the naming convention. 
https://prometheus.io/docs/practices/naming/ +request_success_count = Counter( + f"{metric_prefix}requests_success_total", "Total number of successful requests", ["task_type", "operation"] +) +request_failure_count = Counter( + f"{metric_prefix}requests_failure_total", + "Total number of failed requests", + ["task_type", "operation", "error_code"], +) - def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse: - try: - agent = AgentRegistry.get_agent(context, request.task_type) - if agent is None: - return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE)) - return agent.get(context=context, resource_meta=request.resource_meta) - except Exception as e: - logger.error(f"failed to get task with error {e}") - context.set_code(grpc.StatusCode.INTERNAL) - context.set_details(f"failed to get task with error {e}") +request_latency = Summary( + f"{metric_prefix}request_latency_seconds", "Time spent processing agent request", ["task_type", "operation"] +) + +input_literal_size = Summary(f"{metric_prefix}input_literal_bytes", "Size of input literal", ["task_type"]) + + +def agent_exception_handler(func): + async def wrapper( + self, + request: typing.Union[CreateTaskRequest, GetTaskRequest, DeleteTaskRequest], + context: grpc.ServicerContext, + *args, + **kwargs, + ): + if isinstance(request, CreateTaskRequest): + task_type = request.template.type + operation = create_operation + if request.inputs: + input_literal_size.labels(task_type=task_type).observe(request.inputs.ByteSize()) + elif isinstance(request, GetTaskRequest): + task_type = request.task_type + operation = get_operation + elif isinstance(request, DeleteTaskRequest): + task_type = request.task_type + operation = delete_operation + else: + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + return - def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse: try: - agent = AgentRegistry.get_agent(context, request.task_type) - if agent is None: - return DeleteTaskResponse() - return agent.delete(context=context, resource_meta=request.resource_meta) + with request_latency.labels(task_type=task_type, operation=operation).time(): + res = await func(self, request, context, *args, **kwargs) + request_success_count.labels(task_type=task_type, operation=operation).inc() + return res + except FlyteAgentNotFound: + error_message = f"Cannot find agent for task type: {task_type}." + logger.error(error_message) + context.set_code(grpc.StatusCode.NOT_FOUND) + context.set_details(error_message) + request_failure_count.labels(task_type=task_type, operation=operation, error_code="404").inc() except Exception as e: - logger.error(f"failed to delete task with error {e}") + error_message = f"failed to {operation} {task_type} task with error {e}." 
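+            # Any other exception is reported as a gRPC INTERNAL error and counted
+            # in the failure metric with error_code="500".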
+ logger.error(error_message) context.set_code(grpc.StatusCode.INTERNAL) - context.set_details(f"failed to delete task with error {e}") + context.set_details(error_message) + request_failure_count.labels(task_type=task_type, operation=operation, error_code="500").inc() + + return wrapper + + +class AsyncAgentService(AsyncAgentServiceServicer): + @agent_exception_handler + async def CreateTask(self, request: CreateTaskRequest, context: grpc.ServicerContext) -> CreateTaskResponse: + tmp = TaskTemplate.from_flyte_idl(request.template) + inputs = LiteralMap.from_flyte_idl(request.inputs) if request.inputs else None + agent = AgentRegistry.get_agent(tmp.type) + + logger.info(f"{tmp.type} agent start creating the job") + if agent.asynchronous: + return await agent.async_create( + context=context, inputs=inputs, output_prefix=request.output_prefix, task_template=tmp + ) + return await asyncio.get_running_loop().run_in_executor( + None, + agent.create, + context, + request.output_prefix, + tmp, + inputs, + ) + + @agent_exception_handler + async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) -> GetTaskResponse: + agent = AgentRegistry.get_agent(request.task_type) + logger.info(f"{agent.task_type} agent start checking the status of the job") + if agent.asynchronous: + return await agent.async_get(context=context, resource_meta=request.resource_meta) + return await asyncio.get_running_loop().run_in_executor(None, agent.get, context, request.resource_meta) + + @agent_exception_handler + async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse: + agent = AgentRegistry.get_agent(request.task_type) + logger.info(f"{agent.task_type} agent start deleting the job") + if agent.asynchronous: + return await agent.async_delete(context=context, resource_meta=request.resource_meta) + return await asyncio.get_running_loop().run_in_executor(None, agent.delete, context, request.resource_meta) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index cf9b901c74..84838559ec 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -1,7 +1,12 @@ +import asyncio +import signal +import sys import time import typing -from abc import ABC, abstractmethod +from abc import ABC from collections import OrderedDict +from functools import partial +from types import FrameType import grpc from flyteidl.admin.agent_pb2 import ( @@ -17,10 +22,12 @@ from flyteidl.core.tasks_pb2 import TaskTemplate from rich.progress import Progress +import flytekit from flytekit import FlyteContext, logger from flytekit.configuration import ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.type_engine import TypeEngine +from flytekit.exceptions.system import FlyteAgentNotFound from flytekit.models.literals import LiteralMap @@ -35,8 +42,16 @@ class AgentBase(ABC): will look up the agent based on the task type. Every task type can only have one agent. """ - def __init__(self, task_type: str): + def __init__(self, task_type: str, asynchronous=True): self._task_type = task_type + self._asynchronous = asynchronous + + @property + def asynchronous(self) -> bool: + """ + asynchronous is a flag to indicate whether the agent is asynchronous or not. 
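+
+        When True, the agent service awaits the ``async_create``/``async_get``/``async_delete``
+        methods; when False, the blocking ``create``/``get``/``delete`` methods are run in the
+        default executor instead.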
+ """ + return self._asynchronous @property def task_type(self) -> str: @@ -45,7 +60,6 @@ def task_type(self) -> str: """ return self._task_type - @abstractmethod def create( self, context: grpc.ServicerContext, @@ -56,23 +70,42 @@ def create( """ Return a Unique ID for the task that was created. It should return error code if the task creation failed. """ - pass + raise NotImplementedError - @abstractmethod def get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: + raise NotImplementedError + + def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + """ + Delete the task. This call should be idempotent. + """ + raise NotImplementedError + + async def async_create( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, + ) -> CreateTaskResponse: + """ + Return a Unique ID for the task that was created. It should return error code if the task creation failed. + """ + raise NotImplementedError + + async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: """ Return the status of the task, and return the outputs in some cases. For example, bigquery job can't write the structured dataset to the output location, so it returns the output literals to the propeller, and the propeller will write the structured dataset to the blob store. """ - pass + raise NotImplementedError - @abstractmethod - def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: """ Delete the task. This call should be idempotent. """ - pass + raise NotImplementedError class AgentRegistry(object): @@ -91,12 +124,9 @@ def register(agent: AgentBase): logger.info(f"Registering an agent for task type {agent.task_type}") @staticmethod - def get_agent(context: grpc.ServicerContext, task_type: str) -> typing.Optional[AgentBase]: + def get_agent(task_type: str) -> typing.Optional[AgentBase]: if task_type not in AgentRegistry._REGISTRY: - logger.error(f"Cannot find agent for task type [{task_type}]") - context.set_code(grpc.StatusCode.NOT_FOUND) - context.set_details(f"Cannot find the agent for task type [{task_type}]") - return None + raise FlyteAgentNotFound(f"Cannot find agent for task type: {task_type}.") return AgentRegistry._REGISTRY[task_type] @@ -105,9 +135,9 @@ def convert_to_flyte_state(state: str) -> State: Convert the state from the agent to the state in flyte. """ state = state.lower() - if state in ["failed"]: + if state in ["failed", "timedout", "canceled"]: return RETRYABLE_FAILURE - elif state in ["done", "succeeded"]: + elif state in ["done", "succeeded", "success"]: return SUCCEEDED elif state in ["running"]: return RUNNING @@ -121,46 +151,89 @@ def is_terminal_state(state: State) -> bool: return state in [SUCCEEDED, RETRYABLE_FAILURE, PERMANENT_FAILURE] +def get_agent_secret(secret_key: str) -> str: + return flytekit.current_context().secrets.get(secret_key) + + class AsyncAgentExecutorMixin: """ This mixin class is used to run the agent task locally, and it's only used for local execution. Task should inherit from this class if the task can be run in the agent. 
""" - def execute(self, **kwargs) -> typing.Any: - from unittest.mock import MagicMock + _is_canceled = None + _agent = None + _entity = None + def execute(self, **kwargs) -> typing.Any: from flytekit.tools.translator import get_serializable - entity = typing.cast(PythonTask, self) - m: OrderedDict = OrderedDict() - dummy_context = MagicMock(spec=grpc.ServicerContext) - cp_entity = get_serializable(m, settings=SerializationSettings(ImageConfig()), entity=entity) - agent = AgentRegistry.get_agent(dummy_context, cp_entity.template.type) + self._entity = typing.cast(PythonTask, self) + task_template = get_serializable(OrderedDict(), SerializationSettings(ImageConfig()), self._entity).template + self._agent = AgentRegistry.get_agent(task_template.type) - if agent is None: - raise Exception("Cannot run the task locally, please mock.") - literals = {} + res = asyncio.run(self._create(task_template, kwargs)) + res = asyncio.run(self._get(resource_meta=res.resource_meta)) + + if res.resource.state != SUCCEEDED: + raise Exception(f"Failed to run the task {self._entity.name}") + + return LiteralMap.from_flyte_idl(res.resource.outputs) + + async def _create( + self, task_template: TaskTemplate, inputs: typing.Dict[str, typing.Any] = None + ) -> CreateTaskResponse: ctx = FlyteContext.current_context() - for k, v in kwargs.items(): - literals[k] = TypeEngine.to_literal(ctx, v, type(v), entity.interface.inputs[k].type) + grpc_ctx = _get_grpc_context() + + # Convert python inputs to literals + literals = {} + for k, v in inputs.items(): + literals[k] = TypeEngine.to_literal(ctx, v, type(v), self._entity.interface.inputs[k].type) inputs = LiteralMap(literals) if literals else None output_prefix = ctx.file_access.get_random_local_directory() - cp_entity = get_serializable(m, settings=SerializationSettings(ImageConfig()), entity=entity) - res = agent.create(dummy_context, output_prefix, cp_entity.template, inputs) + + if self._agent.asynchronous: + res = await self._agent.async_create(grpc_ctx, output_prefix, task_template, inputs) + else: + res = self._agent.create(grpc_ctx, output_prefix, task_template, inputs) + + signal.signal(signal.SIGINT, partial(self.signal_handler, res.resource_meta)) # type: ignore + return res + + async def _get(self, resource_meta: bytes) -> GetTaskResponse: state = RUNNING - metadata = res.resource_meta + grpc_ctx = _get_grpc_context() + progress = Progress(transient=True) - task = progress.add_task(f"[cyan]Running Task {entity.name}...", total=None) + task = progress.add_task(f"[cyan]Running Task {self._entity.name}...", total=None) with progress: while not is_terminal_state(state): progress.start_task(task) time.sleep(1) - res = agent.get(dummy_context, metadata) + if self._agent.asynchronous: + res = await self._agent.async_get(grpc_ctx, resource_meta) + if self._is_canceled: + await self._is_canceled + sys.exit(1) + else: + res = self._agent.get(grpc_ctx, resource_meta) state = res.resource.state logger.info(f"Task state: {state}") + return res - if state != SUCCEEDED: - raise Exception(f"Failed to run the task {entity.name}") + def signal_handler(self, resource_meta: bytes, signum: int, frame: FrameType) -> typing.Any: + grpc_ctx = _get_grpc_context() + if self._agent.asynchronous: + if self._is_canceled is None: + self._is_canceled = asyncio.create_task(self._agent.async_delete(grpc_ctx, resource_meta)) + else: + self._agent.delete(grpc_ctx, resource_meta) + sys.exit(1) - return LiteralMap.from_flyte_idl(res.resource.outputs) + +def _get_grpc_context(): + from 
unittest.mock import MagicMock + + grpc_ctx = MagicMock(spec=grpc.ServicerContext) + return grpc_ctx diff --git a/flytekit/extras/pytorch/checkpoint.py b/flytekit/extras/pytorch/checkpoint.py index c7561f13f4..9e5841a1d5 100644 --- a/flytekit/extras/pytorch/checkpoint.py +++ b/flytekit/extras/pytorch/checkpoint.py @@ -4,7 +4,7 @@ from typing import Any, Callable, Dict, NamedTuple, Optional, Type, Union import torch -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin from typing_extensions import Protocol from flytekit.core.context_manager import FlyteContext @@ -20,9 +20,8 @@ class IsDataclass(Protocol): __post_init__: Optional[Callable] -@dataclass_json @dataclass -class PyTorchCheckpoint: +class PyTorchCheckpoint(DataClassJsonMixin): """ This class is helpful to save a checkpoint. """ diff --git a/flytekit/extras/sklearn/native.py b/flytekit/extras/sklearn/native.py index 59ecca70c5..dd231c15be 100644 --- a/flytekit/extras/sklearn/native.py +++ b/flytekit/extras/sklearn/native.py @@ -1,5 +1,5 @@ import pathlib -from typing import Generic, Type, TypeVar +from typing import Type, TypeVar import joblib import sklearn @@ -13,7 +13,7 @@ T = TypeVar("T") -class SklearnTypeTransformer(TypeTransformer, Generic[T]): +class SklearnTypeTransformer(TypeTransformer[T]): def get_literal_type(self, t: Type[T]) -> LiteralType: return LiteralType( blob=_core_types.BlobType( @@ -75,5 +75,15 @@ class SklearnEstimatorTransformer(SklearnTypeTransformer[sklearn.base.BaseEstima def __init__(self): super().__init__(name="Sklearn Estimator", t=sklearn.base.BaseEstimator) + def guess_python_type(self, literal_type: LiteralType) -> Type[sklearn.base.BaseEstimator]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format == self.SKLEARN_FORMAT + ): + return sklearn.base.BaseEstimator + + raise ValueError(f"Transformer {self} cannot reverse {literal_type}") + TypeEngine.register(SklearnEstimatorTransformer()) diff --git a/flytekit/extras/sqlite3/task.py b/flytekit/extras/sqlite3/task.py index ef8013a5da..6bcbae5d4f 100644 --- a/flytekit/extras/sqlite3/task.py +++ b/flytekit/extras/sqlite3/task.py @@ -43,7 +43,7 @@ class SQLite3Config(object): Args: uri: default FlyteFile that will be downloaded on execute compressed: Boolean that indicates if the given file is a compressed archive. Supported file types are - [zip, tar, gztar, bztar, xztar] + [zip, tar, gztar, bztar, xztar] """ uri: str @@ -65,7 +65,7 @@ class SQLite3Task(PythonCustomizedContainerTask[SQLite3Config], SQLTask[SQLite3C :language: python :dedent: 4 - See the :std:ref:`cookbook ` for additional usage examples and + See the :ref:`integrations guide ` for additional usage examples and the base class :py:class:`flytekit.extend.PythonCustomizedContainerTask` as well. 
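+
+    A minimal sketch (the database URI and query below are placeholders):
+
+    .. code-block:: python
+
+        from flytekit import kwtypes
+        from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task
+
+        sample_task = SQLite3Task(
+            name="sqlite3.sample",
+            query_template="select * from tracks limit {{ .inputs.limit }}",
+            inputs=kwtypes(limit=int),
+            task_config=SQLite3Config(uri="https://example.com/database.zip", compressed=True),
+        )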
""" diff --git a/flytekit/extras/tensorflow/record.py b/flytekit/extras/tensorflow/record.py index 17e7c37ddd..0258d379ac 100644 --- a/flytekit/extras/tensorflow/record.py +++ b/flytekit/extras/tensorflow/record.py @@ -3,7 +3,7 @@ from typing import Optional, Tuple, Type, Union import tensorflow as tf -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin from tensorflow.python.data.ops.readers import TFRecordDatasetV2 from typing_extensions import Annotated, get_args, get_origin @@ -16,9 +16,8 @@ from flytekit.types.file import TFRecordFile -@dataclass_json @dataclass -class TFRecordDatasetConfig: +class TFRecordDatasetConfig(DataClassJsonMixin): """ TFRecordDatasetConfig can be used while creating tf.data.TFRecordDataset comprising record of one or more TFRecord files. diff --git a/flytekit/image_spec/image_spec.py b/flytekit/image_spec/image_spec.py index 9ab3c2b8a5..64e701d5e7 100644 --- a/flytekit/image_spec/image_spec.py +++ b/flytekit/image_spec/image_spec.py @@ -97,7 +97,7 @@ def exist(self) -> bool: tag = calculate_hash_from_image_spec(self) # if docker engine is not running locally container_registry = DOCKER_HUB - if "/" in self.registry: + if self.registry and "/" in self.registry: container_registry = self.registry.split("/")[0] if container_registry == DOCKER_HUB: url = f"https://hub.docker.com/v2/repositories/{self.registry}/{self.name}/tags/{tag}" diff --git a/flytekit/interaction/click_types.py b/flytekit/interaction/click_types.py new file mode 100644 index 0000000000..9f03ce8c84 --- /dev/null +++ b/flytekit/interaction/click_types.py @@ -0,0 +1,462 @@ +import datetime +import json +import logging +import os +import pathlib +import typing +from dataclasses import dataclass +from typing import cast + +import cloudpickle +import rich_click as click +import yaml +from dataclasses_json import DataClassJsonMixin +from pytimeparse import parse +from typing_extensions import get_args + +from flytekit import Blob, BlobMetadata, BlobType, FlyteContext, FlyteContextManager, Literal, LiteralType, Scalar +from flytekit.core.data_persistence import FileAccessProvider +from flytekit.core.type_engine import TypeEngine +from flytekit.models import literals +from flytekit.models.literals import LiteralCollection, LiteralMap, Primitive, Union, Void +from flytekit.models.types import SimpleType +from flytekit.remote import FlyteRemote +from flytekit.tools import script_mode +from flytekit.types.pickle.pickle import FlytePickleTransformer + + +def remove_prefix(text, prefix): + if text.startswith(prefix): + return text[len(prefix) :] + return text + + +@dataclass +class Directory(object): + dir_path: str + local_file: typing.Optional[pathlib.Path] = None + local: bool = True + + +class DirParamType(click.ParamType): + name = "directory path" + + def convert( + self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] + ) -> typing.Any: + if FileAccessProvider.is_remote(value): + return Directory(dir_path=value, local=False) + p = pathlib.Path(value) + if p.exists() and p.is_dir(): + files = list(p.iterdir()) + if len(files) != 1: + raise ValueError( + f"Currently only directories containing one file are supported, found [{len(files)}] files found in {p.resolve()}" + ) + return Directory(dir_path=str(p), local_file=files[0].resolve()) + raise click.BadParameter(f"parameter should be a valid directory path, {value}") + + +@dataclass +class FileParam(object): + filepath: str + local: bool = True + + +class 
FileParamType(click.ParamType): + name = "file path" + + def convert( + self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] + ) -> typing.Any: + if FileAccessProvider.is_remote(value): + return FileParam(filepath=value, local=False) + p = pathlib.Path(value) + if p.exists() and p.is_file(): + return FileParam(filepath=str(p.resolve())) + raise click.BadParameter(f"parameter should be a valid file path, {value}") + + +class PickleParamType(click.ParamType): + name = "pickle" + + def convert( + self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] + ) -> typing.Any: + uri = FlyteContextManager.current_context().file_access.get_random_local_path() + with open(uri, "w+b") as outfile: + cloudpickle.dump(value, outfile) + return FileParam(filepath=str(pathlib.Path(uri).resolve())) + + +class DateTimeType(click.DateTime): + _NOW_FMT = "now" + _ADDITONAL_FORMATS = [_NOW_FMT] + + def __init__(self): + super().__init__() + self.formats.extend(self._ADDITONAL_FORMATS) + + def convert( + self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] + ) -> typing.Any: + if value in self._ADDITONAL_FORMATS: + if value == self._NOW_FMT: + return datetime.datetime.now() + return super().convert(value, param, ctx) + + +class DurationParamType(click.ParamType): + name = "[1:24 | :22 | 1 minute | 10 days | ...]" + + def convert( + self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] + ) -> typing.Any: + if value is None: + raise click.BadParameter("None value cannot be converted to a Duration type.") + return datetime.timedelta(seconds=parse(value)) + + +class JsonParamType(click.ParamType): + name = "json object OR json/yaml file path" + + def convert( + self, value: typing.Any, param: typing.Optional[click.Parameter], ctx: typing.Optional[click.Context] + ) -> typing.Any: + if value is None: + raise click.BadParameter("None value cannot be converted to a Json type.") + if type(value) == dict or type(value) == list: + return value + try: + return json.loads(value) + except Exception: # noqa + try: + # We failed to load the json, so we'll try to load it as a file + if os.path.exists(value): + # if the value is a yaml file, we'll try to load it as yaml + if value.endswith(".yaml") or value.endswith(".yml"): + with open(value, "r") as f: + return yaml.safe_load(f) + with open(value, "r") as f: + return json.load(f) + raise + except json.JSONDecodeError as e: + raise click.BadParameter(f"parameter {param} should be a valid json object, {value}, error: {e}") + + +@dataclass +class DefaultConverter(object): + click_type: click.ParamType + primitive_type: typing.Optional[str] = None + scalar_type: typing.Optional[str] = None + + def convert(self, value: typing.Any, python_type_hint: typing.Optional[typing.Type] = None) -> Scalar: + if self.primitive_type: + return Scalar(primitive=Primitive(**{self.primitive_type: value})) + if self.scalar_type: + return Scalar(**{self.scalar_type: value}) + + raise NotImplementedError("Not implemented yet!") + + +class FlyteLiteralConverter(object): + name = "literal_type" + + SIMPLE_TYPE_CONVERTER: typing.Dict[SimpleType, DefaultConverter] = { + SimpleType.FLOAT: DefaultConverter(click.FLOAT, primitive_type="float_value"), + SimpleType.INTEGER: DefaultConverter(click.INT, primitive_type="integer"), + SimpleType.STRING: DefaultConverter(click.STRING, primitive_type="string_value"), + 
SimpleType.BOOLEAN: DefaultConverter(click.BOOL, primitive_type="boolean"), + SimpleType.DURATION: DefaultConverter(DurationParamType(), primitive_type="duration"), + SimpleType.DATETIME: DefaultConverter(click.DateTime(), primitive_type="datetime"), + } + + def __init__( + self, + flyte_ctx: FlyteContext, + literal_type: LiteralType, + python_type: typing.Type, + get_upload_url_fn: typing.Callable, + is_remote: bool, + remote_instance_accessor: typing.Callable[[], FlyteRemote] = None, + ): + self._is_remote = is_remote + self._literal_type = literal_type + self._python_type = python_type + self._create_upload_fn = get_upload_url_fn + self._flyte_ctx = flyte_ctx + self._click_type = click.UNPROCESSED + self._remote_instance_accessor = remote_instance_accessor + + if self._literal_type.simple: + if self._literal_type.simple == SimpleType.STRUCT: + self._click_type = JsonParamType() + self._click_type.name = f"JSON object {self._python_type.__name__}" + elif self._literal_type.simple not in self.SIMPLE_TYPE_CONVERTER: + raise NotImplementedError(f"Type {self._literal_type.simple} is not supported in pyflyte run") + else: + self._converter = self.SIMPLE_TYPE_CONVERTER[self._literal_type.simple] + self._click_type = self._converter.click_type + + if self._literal_type.enum_type: + self._converter = self.SIMPLE_TYPE_CONVERTER[SimpleType.STRING] + self._click_type = click.Choice(self._literal_type.enum_type.values) + + if self._literal_type.structured_dataset_type: + self._click_type = DirParamType() + + if self._literal_type.collection_type or self._literal_type.map_value_type: + self._click_type = JsonParamType() + if self._literal_type.collection_type: + self._click_type.name = "json list" + else: + self._click_type.name = "json dictionary" + + if self._literal_type.blob: + if self._literal_type.blob.dimensionality == BlobType.BlobDimensionality.SINGLE: + if self._literal_type.blob.format == FlytePickleTransformer.PYTHON_PICKLE_FORMAT: + self._click_type = PickleParamType() + else: + self._click_type = FileParamType() + else: + self._click_type = DirParamType() + + @property + def click_type(self) -> click.ParamType: + return self._click_type + + def is_bool(self) -> bool: + if self._literal_type.simple: + return self._literal_type.simple == SimpleType.BOOLEAN + return False + + def get_uri_for_dir( + self, ctx: typing.Optional[click.Context], value: Directory, remote_filename: typing.Optional[str] = None + ): + uri = value.dir_path + + if self._is_remote and value.local: + md5, _ = script_mode.hash_file(value.local_file) + if not remote_filename: + remote_filename = value.local_file.name + remote = self._remote_instance_accessor() + _, native_url = remote.upload_file(value.local_file) + uri = native_url[: -len(remote_filename)] + + return uri + + def convert_to_structured_dataset( + self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: Directory + ) -> Literal: + + uri = self.get_uri_for_dir(ctx, value, "00000.parquet") + + lit = Literal( + scalar=Scalar( + structured_dataset=literals.StructuredDataset( + uri=uri, + metadata=literals.StructuredDatasetMetadata( + structured_dataset_type=self._literal_type.structured_dataset_type + ), + ), + ), + ) + + return lit + + def convert_to_blob( + self, + ctx: typing.Optional[click.Context], + param: typing.Optional[click.Parameter], + value: typing.Union[Directory, FileParam], + ) -> Literal: + if isinstance(value, Directory): + uri = self.get_uri_for_dir(ctx, value) + else: + uri = value.filepath + if 
self._is_remote and value.local: + fp = pathlib.Path(value.filepath) + remote = self._remote_instance_accessor() + _, uri = remote.upload_file(fp) + + lit = Literal( + scalar=Scalar( + blob=Blob( + metadata=BlobMetadata(type=self._literal_type.blob), + uri=uri, + ), + ), + ) + + return lit + + def convert_to_union( + self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: typing.Any + ) -> Literal: + lt = self._literal_type + + # handle case where Union type has NoneType and the value is None + has_none_type = any(v.simple == 0 for v in self._literal_type.union_type.variants) + if has_none_type and value is None: + return Literal(scalar=Scalar(none_type=Void())) + + for i in range(len(self._literal_type.union_type.variants)): + variant = self._literal_type.union_type.variants[i] + python_type = get_args(self._python_type)[i] + converter = FlyteLiteralConverter( + self._flyte_ctx, + variant, + python_type, + self._create_upload_fn, + self._is_remote, + self._remote_instance_accessor, + ) + try: + # Here we use click converter to convert the input in command line to native python type, + # and then use flyte converter to convert it to literal. + python_val = converter._click_type.convert(value, param, ctx) + literal = converter.convert_to_literal(ctx, param, python_val) + return Literal(scalar=Scalar(union=Union(literal, variant))) + except (Exception or AttributeError) as e: + logging.debug(f"Failed to convert python type {python_type} to literal type {variant}", e) + raise ValueError(f"Failed to convert python type {self._python_type} to literal type {lt}") + + def convert_to_list( + self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: list + ) -> Literal: + """ + Convert a python list into a Flyte Literal + """ + if not value: + raise click.BadParameter("Expected non-empty list") + if not isinstance(value, list): + raise click.BadParameter(f"Expected json list '[...]', parsed value is {type(value)}") + converter = FlyteLiteralConverter( + self._flyte_ctx, + self._literal_type.collection_type, + type(value[0]), + self._create_upload_fn, + self._is_remote, + self._remote_instance_accessor, + ) + lt = Literal(collection=LiteralCollection([])) + for v in value: + click_val = converter._click_type.convert(v, param, ctx) + lt.collection.literals.append(converter.convert_to_literal(ctx, param, click_val)) + return lt + + def convert_to_map( + self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: dict + ) -> Literal: + """ + Convert a python dict into a Flyte Literal. + It is assumed that the click parameter type is a JsonParamType. The map is also assumed to be univariate. 
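+        For example, a command-line value of ``'{"a": 1, "b": 2}'`` is parsed by
+        ``JsonParamType`` into a dict and converted entry by entry, using the type of the
+        first value for every element.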
+ """ + if not value: + raise click.BadParameter("Expected non-empty dict") + if not isinstance(value, dict): + raise click.BadParameter(f"Expected json dict '{{...}}', parsed value is {type(value)}") + converter = FlyteLiteralConverter( + self._flyte_ctx, + self._literal_type.map_value_type, + type(value[list(value.keys())[0]]), + self._create_upload_fn, + self._is_remote, + self._remote_instance_accessor, + ) + lt = Literal(map=LiteralMap({})) + for k, v in value.items(): + click_val = converter._click_type.convert(v, param, ctx) + lt.map.literals[k] = converter.convert_to_literal(ctx, param, click_val) + return lt + + def convert_to_struct( + self, + ctx: typing.Optional[click.Context], + param: typing.Optional[click.Parameter], + value: typing.Union[dict, typing.Any], + ) -> Literal: + """ + Convert the loaded json object to a Flyte Literal struct type. + """ + if type(value) != self._python_type: + if is_pydantic_basemodel(self._python_type): + o = self._python_type.parse_raw(json.dumps(value)) # type: ignore + else: + o = cast(DataClassJsonMixin, self._python_type).from_json(json.dumps(value)) + else: + o = value + return TypeEngine.to_literal(self._flyte_ctx, o, self._python_type, self._literal_type) + + def convert_to_literal( + self, ctx: typing.Optional[click.Context], param: typing.Optional[click.Parameter], value: typing.Any + ) -> Literal: + if self._literal_type.structured_dataset_type: + return self.convert_to_structured_dataset(ctx, param, value) + + if self._literal_type.blob: + return self.convert_to_blob(ctx, param, value) + + if self._literal_type.collection_type: + return self.convert_to_list(ctx, param, value) + + if self._literal_type.map_value_type: + return self.convert_to_map(ctx, param, value) + + if self._literal_type.union_type: + return self.convert_to_union(ctx, param, value) + + if self._literal_type.simple or self._literal_type.enum_type: + if self._literal_type.simple and self._literal_type.simple == SimpleType.STRUCT: + return self.convert_to_struct(ctx, param, value) + return Literal(scalar=self._converter.convert(value, self._python_type)) + + if self._literal_type.schema: + raise DeprecationWarning("Schema Types are not supported in pyflyte run. Use StructuredDataset instead.") + + raise NotImplementedError( + f"CLI parsing is not available for Python Type:`{self._python_type}`, LiteralType:`{self._literal_type}`." + ) + + def convert( + self, ctx: click.Context, param: typing.Optional[click.Parameter], value: typing.Any + ) -> typing.Union[Literal, typing.Any]: + """ + Convert the value to a Flyte Literal or a python native type. This is used by click to convert the input. + """ + try: + lit = self.convert_to_literal(ctx, param, value) + if not self._is_remote: + return TypeEngine.to_python_value(self._flyte_ctx, lit, self._python_type) + return lit + except click.BadParameter: + raise + except Exception as e: + raise click.BadParameter(f"Failed to convert param {param}, {value} to {self._python_type}") from e + + +def is_pydantic_basemodel(python_type: typing.Type) -> bool: + """ + Checks if the python type is a pydantic BaseModel + """ + try: + import pydantic + except ImportError: + return False + else: + return issubclass(python_type, pydantic.BaseModel) + + +def key_value_callback(_: typing.Any, param: str, values: typing.List[str]) -> typing.Optional[typing.Dict[str, str]]: + """ + Callback for click to parse key-value pairs. 
+ """ + if not values: + return None + result = {} + for v in values: + if "=" not in v: + raise click.BadParameter(f"Expected key-value pair of the form key=value, got {v}") + k, v = v.split("=", 1) + result[k.strip()] = v.strip() + return result diff --git a/flytekit/interaction/parse_stdin.py b/flytekit/interaction/parse_stdin.py index ec051d73ce..19c9aef377 100644 --- a/flytekit/interaction/parse_stdin.py +++ b/flytekit/interaction/parse_stdin.py @@ -6,31 +6,23 @@ from flytekit.core.context_manager import FlyteContext from flytekit.core.type_engine import TypeEngine -from flytekit.loggers import logger from flytekit.models.literals import Literal -# TODO: Move the improved click parsing here. https://github.com/flyteorg/flyte/issues/3124 -def parse_stdin_to_literal(ctx: FlyteContext, t: typing.Type, message_prefix: typing.Optional[str]) -> Literal: +def parse_stdin_to_literal(ctx: FlyteContext, t: typing.Type, message: typing.Optional[str]) -> Literal: + """ + Parses the user input from stdin and converts it to a literal of the given type. + """ + from flytekit.interaction.click_types import FlyteLiteralConverter - message = message_prefix or "" - message += f"Please enter value for type {t} to continue" - if issubclass(t, bool): - user_input = click.prompt(message, type=bool) - l = TypeEngine.to_literal(ctx, user_input, bool, TypeEngine.to_literal_type(bool)) # noqa - elif issubclass(t, int): - user_input = click.prompt(message, type=int) - l = TypeEngine.to_literal(ctx, user_input, int, TypeEngine.to_literal_type(int)) # noqa - elif issubclass(t, float): - user_input = click.prompt(message, type=float) - l = TypeEngine.to_literal(ctx, user_input, float, TypeEngine.to_literal_type(float)) # noqa - elif issubclass(t, str): - user_input = click.prompt(message, type=str) - l = TypeEngine.to_literal(ctx, user_input, str, TypeEngine.to_literal_type(str)) # noqa - else: - # Todo: We should implement the rest by way of importing the code in pyflyte run - # that parses text from the command line - raise Exception("Only bool, int/float, or strings are accepted for now.") - - logger.debug(f"Parsed literal {l} from user input {user_input}") - return l + literal_type = TypeEngine.to_literal_type(t) + literal_converter = FlyteLiteralConverter( + ctx, + literal_type=literal_type, + python_type=t, + get_upload_url_fn=lambda: None, + is_remote=False, + remote_instance_accessor=None, + ) + user_input = click.prompt(message, type=literal_converter.click_type) + return literal_converter.convert_to_literal(click.Context(click.Command(None)), None, user_input) diff --git a/flytekit/models/core/condition.py b/flytekit/models/core/condition.py index 845b3b4f79..27e0bc505b 100644 --- a/flytekit/models/core/condition.py +++ b/flytekit/models/core/condition.py @@ -134,15 +134,17 @@ def from_flyte_idl(cls, pb2_object): class Operand(_common.FlyteIdlEntity): - def __init__(self, primitive=None, var=None): + def __init__(self, primitive=None, var=None, scalar=None): """ Defines an operand to a comparison expression. 
- :param flytekit.models.literals.Primitive primitive: - :param Text var: + :param flytekit.models.literals.Primitive primitive: A primitive value + :param Text var: A variable name + :param flytekit.models.literals.Scalar scalar: A scalar value """ self._primitive = primitive self._var = var + self._scalar = scalar @property def primitive(self): @@ -160,6 +162,14 @@ def var(self): return self._var + @property + def scalar(self): + """ + :rtype: flytekit.models.literals.Scalar + """ + + return self._scalar + def to_flyte_idl(self): """ :rtype: flyteidl.core.condition_pb2.Operand @@ -167,6 +177,7 @@ def to_flyte_idl(self): return _condition.Operand( primitive=self.primitive.to_flyte_idl() if self.primitive else None, var=self.var if self.var else None, + scalar=self.scalar.to_flyte_idl() if self.scalar else None, ) @classmethod @@ -176,6 +187,7 @@ def from_flyte_idl(cls, pb2_object): if pb2_object.HasField("primitive") else None, var=pb2_object.var if pb2_object.HasField("var") else None, + scalar=_literals.Scalar.from_flyte_idl(pb2_object.scalar) if pb2_object.HasField("scalar") else None, ) diff --git a/flytekit/models/core/workflow.py b/flytekit/models/core/workflow.py index 1af53b3a53..e60038c0f6 100644 --- a/flytekit/models/core/workflow.py +++ b/flytekit/models/core/workflow.py @@ -17,6 +17,7 @@ class IfBlock(_common.FlyteIdlEntity): def __init__(self, condition, then_node): """ Defines a condition and the execution unit that should be executed if the condition is satisfied. + :param flytekit.models.core.condition.BooleanExpression condition: :param Node then_node: """ @@ -72,6 +73,7 @@ def __init__(self, case, other=None, else_node=None, error=None): def case(self): """ First condition to evaluate. + :rtype: IfBlock """ @@ -81,6 +83,7 @@ def case(self): def other(self): """ Additional branches to evaluate. + :rtype: list[IfBlock] """ @@ -90,6 +93,7 @@ def other(self): def else_node(self): """ The node to execute in case none of the branches were taken. + :rtype: Node """ @@ -99,6 +103,7 @@ def else_node(self): def error(self): """ An error to throw in case none of the branches were taken. + :rtype: flytekit.models.types.Error """ @@ -130,6 +135,7 @@ def __init__(self, if_else: IfElseBlock): """ BranchNode is a special node that alter the flow of the workflow graph. It allows the control flow to branch at runtime based on a series of conditions that get evaluated on various parameters (e.g. inputs, primtives). + :param IfElseBlock if_else: """ @@ -193,7 +199,7 @@ def retries(self): @property def interruptible(self): """ - :rtype: flytekit.models. + :rtype: flytekit.models """ return self._interruptible @@ -223,6 +229,7 @@ class SignalCondition(_common.FlyteIdlEntity): def __init__(self, signal_id: str, type: type_models.LiteralType, output_variable_name: str): """ Represents a dependency on an signal from a user. + :param signal_id: The node id of the signal, also the signal name. :param type: """ @@ -260,6 +267,7 @@ class ApproveCondition(_common.FlyteIdlEntity): def __init__(self, signal_id: str): """ Represents a dependency on an signal from a user. + :param signal_id: The node id of the signal, also the signal name. 
""" self._signal_id = signal_id @@ -340,6 +348,39 @@ def from_flyte_idl(cls, pb2_object: _core_workflow.GateNode) -> "GateNode": ) +class ArrayNode(_common.FlyteIdlEntity): + def __init__(self, node: "Node", parallelism=None, min_successes=None, min_success_ratio=None) -> None: + """ + TODO: docstring + """ + self._node = node + self._parallelism = parallelism + # TODO either min_successes or min_success_ratio should be set + self._min_successes = min_successes + self._min_success_ratio = min_success_ratio + + @property + def node(self) -> "Node": + return self._node + + def to_flyte_idl(self) -> _core_workflow.ArrayNode: + return _core_workflow.ArrayNode( + node=self._node.to_flyte_idl() if self._node is not None else None, + parallelism=self._parallelism, + min_successes=self._min_successes, + min_success_ratio=self._min_success_ratio, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object) -> "ArrayNode": + return cls( + Node.from_flyte_idl(pb2_object.node), + pb2_object.parallelism, + pb2_object.min_successes, + pb2_object.min_success_ratio, + ) + + class Node(_common.FlyteIdlEntity): def __init__( self, @@ -352,6 +393,7 @@ def __init__( workflow_node=None, branch_node=None, gate_node: typing.Optional[GateNode] = None, + array_node: typing.Optional[ArrayNode] = None, ): """ A Workflow graph Node. One unit of execution in the graph. Each node can be linked to a Task, @@ -383,12 +425,14 @@ def __init__( self._workflow_node = workflow_node self._branch_node = branch_node self._gate_node = gate_node + self._array_node = array_node @property def id(self): """ A workflow-level unique identifier that identifies this node in the workflow. "inputs" and "outputs" are reserved node ids that cannot be used by other nodes. + :rtype: Text """ return self._id @@ -397,6 +441,7 @@ def id(self): def metadata(self): """ Extra metadata about the node. + :rtype: NodeMetadata """ return self._metadata @@ -406,6 +451,7 @@ def inputs(self): """ Specifies how to bind the underlying interface's inputs. All required inputs specified in the underlying interface must be fulfilled. + :rtype: list[flytekit.models.literals.Binding] """ return self._inputs @@ -416,6 +462,7 @@ def upstream_node_ids(self): [Optional] Specifies execution dependency for this node ensuring it will only get scheduled to run after all its upstream nodes have completed. This node will have an implicit dependency on any node that appears in inputs field. + :rtype: list[Text] """ return self._upstream_node_ids @@ -426,6 +473,7 @@ def output_aliases(self): [Optional] A node can define aliases for a subset of its outputs. This is particularly useful if different nodes need to conform to the same interface (e.g. all branches in a branch node). Downstream nodes must refer to this node's outputs using the alias if one is specified. + :rtype: list[Alias] """ return self._output_aliases @@ -434,6 +482,7 @@ def output_aliases(self): def task_node(self): """ [Optional] Information about the Task to execute in this node. + :rtype: TaskNode """ return self._task_node @@ -442,6 +491,7 @@ def task_node(self): def workflow_node(self): """ [Optional] Information about the Workflow to execute in this mode. + :rtype: WorkflowNode """ return self._workflow_node @@ -450,6 +500,7 @@ def workflow_node(self): def branch_node(self): """ [Optional] Information about the branch node to evaluate in this node. 
+ :rtype: BranchNode """ return self._branch_node @@ -458,6 +509,10 @@ def branch_node(self): def gate_node(self) -> typing.Optional[GateNode]: return self._gate_node + @property + def array_node(self) -> typing.Optional[ArrayNode]: + return self._array_node + @property def target(self): """ @@ -479,6 +534,7 @@ def to_flyte_idl(self): workflow_node=self.workflow_node.to_flyte_idl() if self.workflow_node is not None else None, branch_node=self.branch_node.to_flyte_idl() if self.branch_node is not None else None, gate_node=self.gate_node.to_flyte_idl() if self.gate_node else None, + array_node=self.array_node.to_flyte_idl() if self.array_node else None, ) @classmethod @@ -501,6 +557,7 @@ def from_flyte_idl(cls, pb2_object): if pb2_object.HasField("branch_node") else None, gate_node=GateNode.from_flyte_idl(pb2_object.gate_node) if pb2_object.HasField("gate_node") else None, + array_node=ArrayNode.from_flyte_idl(pb2_object.array_node) if pb2_object.HasField("array_node") else None, ) @@ -529,11 +586,11 @@ class TaskNode(_common.FlyteIdlEntity): def __init__(self, reference_id, overrides: typing.Optional[TaskNodeOverrides] = None): """ Refers to the task that the Node is to execute. - NB: This is currently a oneof in protobuf, but there's only one option currently. This code should be updated - when more options are available. + This is currently a oneof in protobuf, but there's only one option currently. + This code should be updated when more options are available. :param flytekit.models.core.identifier.Identifier reference_id: A globally unique identifier for the task. - :param flyteidl.core.workflow_pb2.TaskNodeOverrides + :param flyteidl.core.workflow_pb2.TaskNodeOverrides: """ self._reference_id = reference_id self._overrides = overrides @@ -541,7 +598,8 @@ def __init__(self, reference_id, overrides: typing.Optional[TaskNodeOverrides] = @property def reference_id(self): """ - A globally unique identifier for the task. This should map to the identifier in Flyte Admin. + A globally unique identifier for the task. This should map to the identifier in Flyte Admin. + :rtype: flytekit.models.core.identifier.Identifier """ return self._reference_id @@ -577,10 +635,10 @@ def from_flyte_idl(cls, pb2_object): class WorkflowNode(_common.FlyteIdlEntity): def __init__(self, launchplan_ref=None, sub_workflow_ref=None): """ - Refers to a the workflow the node is to execute. One of the references must be supplied. + Refers to a the workflow the node is to execute. One of the references must be supplied. :param flytekit.models.core.identifier.Identifier launchplan_ref: [Optional] A globally unique identifier for - the launch plan. Should map to Admin. + the launch plan. Should map to Admin. :param flytekit.models.core.identifier.Identifier sub_workflow_ref: [Optional] Reference to a subworkflow, that should be defined with the compiler context. """ @@ -591,6 +649,7 @@ def __init__(self, launchplan_ref=None, sub_workflow_ref=None): def launchplan_ref(self): """ [Optional] A globally unique identifier for the launch plan. Should map to Admin. + :rtype: flytekit.models.core.identifier.Identifier """ return self._launchplan_ref @@ -599,6 +658,7 @@ def launchplan_ref(self): def sub_workflow_ref(self): """ [Optional] Reference to a subworkflow, that should be defined with the compiler context. 
+ :rtype: flytekit.models.core.identifier.Identifier """ return self._sub_workflow_ref @@ -623,6 +683,7 @@ def to_flyte_idl(self): def from_flyte_idl(cls, pb2_object): """ :param flyteidl.core.workflow_pb2.WorkflowNode pb2_object: + :rtype: WorkflowNode """ if pb2_object.HasField("launchplan_ref"): @@ -654,6 +715,7 @@ class OnFailurePolicy(object): def __init__(self, on_failure=None): """ Metadata for the workflow. + :param on_failure flytekit.models.core.workflow.WorkflowMetadata.OnFailurePolicy: [Optional] The execution policy when the workflow detects a failure. """ self._on_failure = on_failure @@ -678,6 +740,7 @@ def to_flyte_idl(self): def from_flyte_idl(cls, pb2_object): """ :param flyteidl.core.workflow_pb2.WorkflowMetadata pb2_object: + :rtype: WorkflowMetadata """ return cls( @@ -708,6 +771,7 @@ def to_flyte_idl(self): def from_flyte_idl(cls, pb2_object): """ :param flyteidl.core.workflow_pb2.WorkflowMetadataDefaults pb2_object: + :rtype: WorkflowMetadata """ return cls(interruptible=pb2_object.interruptible) @@ -758,6 +822,7 @@ def __init__( def id(self): """ This is an autogenerated id by the system. The id is globally unique across Flyte. + :rtype: flytekit.models.core.identifier.Identifier """ return self._id @@ -766,6 +831,7 @@ def id(self): def metadata(self): """ This contains information on how to run the workflow. + :rtype: WorkflowMetadata """ return self._metadata @@ -774,6 +840,7 @@ def metadata(self): def metadata_defaults(self): """ This contains information on how to run the workflow. + :rtype: WorkflowMetadataDefaults """ return self._metadata_defaults @@ -783,6 +850,7 @@ def interface(self): """ Defines a strongly typed interface for the Workflow (inputs, outputs). This can include some optional parameters. + :rtype: flytekit.models.interface.TypedInterface """ return self._interface @@ -791,7 +859,8 @@ def interface(self): def nodes(self): """ A list of nodes. In addition, "globals" is a special reserved node id that can be used to consume - workflow inputs + workflow inputs. + :rtype: list[Node] """ return self._nodes @@ -803,6 +872,7 @@ def outputs(self): pull node outputs or specify literals. All workflow outputs specified in the interface field must be bound in order for the workflow to be validated. A workflow has an implicit dependency on all of its nodes to execute successfully in order to bind final outputs. + :rtype: list[flytekit.models.literals.Binding] """ return self._outputs @@ -813,6 +883,7 @@ def failure_node(self): Node failure_node: A catch-all node. This node is executed whenever the execution engine determines the workflow has failed. The interface of this node must match the Workflow interface with an additional input named "error" of type pb.lyft.flyte.core.Error. + :rtype: Node """ return self._failure_node @@ -835,6 +906,7 @@ def to_flyte_idl(self): def from_flyte_idl(cls, pb2_object): """ :param flyteidl.core.workflow_pb2.WorkflowTemplate pb2_object: + :rtype: WorkflowTemplate """ return cls( @@ -863,6 +935,7 @@ def __init__(self, var, alias): def var(self): """ Must match one of the output variable names on a node. + :rtype: Text """ return self._var @@ -871,6 +944,7 @@ def var(self): def alias(self): """ A workflow-level unique alias that downstream nodes can refer to in their input. 
+ :rtype: Text """ return self._alias @@ -885,6 +959,7 @@ def to_flyte_idl(self): def from_flyte_idl(cls, pb2_object): """ :param flyteidl.core.workflow_pb2.Alias pb2_object: + :return: Alias """ return cls(pb2_object.var, pb2_object.alias) diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index 23b4baab01..b76c01c967 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -5,6 +5,7 @@ from typing import Optional import flyteidl +import flyteidl.admin.cluster_assignment_pb2 as _cluster_assignment_pb2 import flyteidl.admin.execution_pb2 as _execution_pb2 import flyteidl.admin.node_execution_pb2 as _node_execution_pb2 import flyteidl.admin.task_execution_pb2 as _task_execution_pb2 @@ -178,6 +179,8 @@ def __init__( security_context: Optional[security.SecurityContext] = None, overwrite_cache: Optional[bool] = None, envs: Optional[_common_models.Envs] = None, + tags: Optional[typing.List[str]] = None, + cluster_assignment: Optional[ClusterAssignment] = None, ): """ :param flytekit.models.core.identifier.Identifier launch_plan: Launch plan unique identifier to execute @@ -194,6 +197,7 @@ def __init__( :param security_context: Optional security context to use for this execution. :param overwrite_cache: Optional flag to overwrite the cache for this execution. :param envs: flytekit.models.common.Envs environment variables to set for this execution. + :param tags: Optional list of tags to apply to the execution. """ self._launch_plan = launch_plan self._metadata = metadata @@ -207,6 +211,8 @@ def __init__( self._security_context = security_context self._overwrite_cache = overwrite_cache self._envs = envs + self._tags = tags + self._cluster_assignment = cluster_assignment @property def launch_plan(self): @@ -281,6 +287,14 @@ def overwrite_cache(self) -> Optional[bool]: def envs(self) -> Optional[_common_models.Envs]: return self._envs + @property + def tags(self) -> Optional[typing.List[str]]: + return self._tags + + @property + def cluster_assignment(self) -> Optional[ClusterAssignment]: + return self._cluster_assignment + def to_flyte_idl(self): """ :rtype: flyteidl.admin.execution_pb2.ExecutionSpec @@ -300,6 +314,8 @@ def to_flyte_idl(self): security_context=self.security_context.to_flyte_idl() if self.security_context else None, overwrite_cache=self.overwrite_cache, envs=self.envs.to_flyte_idl() if self.envs else None, + tags=self.tags, + cluster_assignment=self._cluster_assignment.to_flyte_idl() if self._cluster_assignment else None, ) @classmethod @@ -325,8 +341,43 @@ def from_flyte_idl(cls, p): else None, overwrite_cache=p.overwrite_cache, envs=_common_models.Envs.from_flyte_idl(p.envs) if p.HasField("envs") else None, + tags=p.tags, + cluster_assignment=ClusterAssignment.from_flyte_idl(p.cluster_assignment) + if p.HasField("cluster_assignment") + else None, + ) + + +class ClusterAssignment(_common_models.FlyteIdlEntity): + def __init__(self, cluster_pool=None): + """ + :param Text cluster_pool: + """ + self._cluster_pool = cluster_pool + + @property + def cluster_pool(self): + """ + :rtype: Text + """ + return self._cluster_pool + + def to_flyte_idl(self): + """ + :rtype: flyteidl.admin._cluster_assignment_pb2.ClusterAssignment + """ + return _cluster_assignment_pb2.ClusterAssignment( + cluster_pool_name=self._cluster_pool, ) + @classmethod + def from_flyte_idl(cls, p): + """ + :param flyteidl.admin._cluster_assignment_pb2.ClusterAssignment p: + :rtype: flyteidl.admin.ClusterAssignment + """ + return cls(cluster_pool=p.cluster_pool_name) + class 
LiteralMapBlob(_common_models.FlyteIdlEntity): def __init__(self, values=None, uri=None): diff --git a/flytekit/remote/__init__.py b/flytekit/remote/__init__.py index 4d6f172586..dd92a813f2 100644 --- a/flytekit/remote/__init__.py +++ b/flytekit/remote/__init__.py @@ -20,6 +20,7 @@ FlyteRemote(private_key=your_private_key_bytes, root_certificates=..., certificate_chain=...) # fetch a workflow from the flyte backend + remote = FlyteRemote(...) flyte_workflow = remote.fetch_workflow(name="my_workflow", version="v1") # execute the workflow, wait=True will return the execution object after it's completed diff --git a/flytekit/remote/backfill.py b/flytekit/remote/backfill.py index 2f31889060..b36fc7919d 100644 --- a/flytekit/remote/backfill.py +++ b/flytekit/remote/backfill.py @@ -5,7 +5,7 @@ from croniter import croniter from flytekit import LaunchPlan -from flytekit.core.workflow import ImperativeWorkflow, WorkflowBase +from flytekit.core.workflow import ImperativeWorkflow, WorkflowBase, WorkflowFailurePolicy from flytekit.remote.entities import FlyteLaunchPlan @@ -16,6 +16,7 @@ def create_backfill_workflow( parallel: bool = False, per_node_timeout: timedelta = None, per_node_retries: int = 0, + failure_policy: typing.Optional[WorkflowFailurePolicy] = None, ) -> typing.Tuple[WorkflowBase, datetime, datetime]: """ Generates a new imperative workflow for the launchplan that can be used to backfill the given launchplan. @@ -46,6 +47,7 @@ def create_backfill_workflow( :param parallel: if the backfill should be run in parallel. False (default) will run each bacfill sequentially :param per_node_timeout: timedelta Timeout to use per node :param per_node_retries: int Retries to user per node + :param failure_policy: WorkflowFailurePolicy Failure policy to use for the backfill workflow :return: WorkflowBase, datetime datetime -> New generated workflow, datetime for first instance of backfill, datetime for last instance of backfill """ if not for_lp: @@ -66,8 +68,11 @@ def create_backfill_workflow( else: raise NotImplementedError("Currently backfilling only supports cron schedules.") - logging.info(f"Generating backfill from {start_date} -> {end_date}. Parallel?[{parallel}]") - wf = ImperativeWorkflow(name=f"backfill-{for_lp.name}") + logging.info( + f"Generating backfill from {start_date} -> {end_date}. " + f"Parallel?[{parallel}] FailurePolicy[{str(failure_policy)}]" + ) + wf = ImperativeWorkflow(name=f"backfill-{for_lp.name}", failure_policy=failure_policy) input_name = schedule.kickoff_time_input_arg date_iter = croniter(cron_schedule.schedule, start_time=start_date, ret_type=datetime) diff --git a/flytekit/remote/entities.py b/flytekit/remote/entities.py index fcca6cf151..dae1b4bc35 100644 --- a/flytekit/remote/entities.py +++ b/flytekit/remote/entities.py @@ -1,5 +1,7 @@ -"""This module contains shadow entities for all Flyte entities as represented in Flyte Admin / Control Plane. -The goal is to enable easy access, manipulation of these entities. """ +""" +This module contains shadow entities for all Flyte entities as represented in Flyte Admin / Control Plane. +The goal is to enable easy access, manipulation of these entities. +""" from __future__ import annotations from typing import Dict, List, Optional, Tuple, Union @@ -67,6 +69,7 @@ def __init__( def id(self): """ This is generated by the system and uniquely identifies the task. 
+ :rtype: flytekit.models.core.identifier.Identifier """ return self.template.id @@ -75,6 +78,7 @@ def id(self): def type(self): """ This is used to identify additional extensions for use by Propeller or SDK. + :rtype: Text """ return self.template.type @@ -84,6 +88,7 @@ def metadata(self): """ This contains information needed at runtime to determine behavior such as whether or not outputs are discoverable, timeouts, and retries. + :rtype: TaskMetadata """ return self.template.metadata @@ -92,6 +97,7 @@ def metadata(self): def interface(self): """ The interface definition for this task. + :rtype: flytekit.models.interface.TypedInterface """ return self.template.interface @@ -100,6 +106,7 @@ def interface(self): def custom(self): """ Arbitrary dictionary containing metadata for custom plugins. + :rtype: dict[Text, T] """ return self.template.custom @@ -112,6 +119,7 @@ def task_type_version(self): def container(self): """ If not None, the target of execution should be a container. + :rtype: Container """ return self.template.container @@ -120,6 +128,7 @@ def container(self): def config(self): """ Arbitrary dictionary containing metadata for parsing and handling custom plugins. + :rtype: dict[Text, T] """ return self.template.config @@ -171,9 +180,7 @@ def promote_from_model(cls, base_model: _task_model.TaskTemplate) -> FlyteTask: class FlyteTaskNode(_workflow_model.TaskNode): - """ - A class encapsulating a task that a Flyte node needs to execute. - """ + """A class encapsulating a task that a Flyte node needs to execute.""" def __init__(self, flyte_task: FlyteTask): super(FlyteTaskNode, self).__init__(None) @@ -181,9 +188,7 @@ def __init__(self, flyte_task: FlyteTask): @property def reference_id(self) -> id_models.Identifier: - """ - A globally unique identifier for the task. - """ + """A globally unique identifier for the task.""" return self._flyte_task.id @property @@ -193,8 +198,8 @@ def flyte_task(self) -> FlyteTask: @classmethod def promote_from_model(cls, task: FlyteTask) -> FlyteTaskNode: """ - Takes the idl wrapper for a TaskNode and returns the hydrated Flytekit object for it by fetching it with the - FlyteTask control plane. + Takes the idl wrapper for a TaskNode, + and returns the hydrated Flytekit object for it by fetching it with the FlyteTask control plane. 
""" return cls(flyte_task=task) @@ -310,7 +315,6 @@ def promote_from_model( tasks: Dict[id_models.Identifier, FlyteTask], converted_sub_workflows: Dict[id_models.Identifier, FlyteWorkflow], ) -> Tuple[FlyteBranchNode, Dict[id_models.Identifier, FlyteWorkflow]]: - block = base_model.if_else block.case._then_node, converted_sub_workflows = FlyteNode.promote_from_model( block.case.then_node, @@ -342,6 +346,12 @@ def promote_from_model(cls, model: _workflow_model.GateNode): return cls(model.signal, model.sleep, model.approve) +class FlyteArrayNode(_workflow_model.ArrayNode): + @classmethod + def promote_from_model(cls, model: _workflow_model.ArrayNode): + return cls(model._parallelism, model._node, model._min_success_ratio, model._min_successes) + + class FlyteNode(_hash_mixin.HashOnReferenceMixin, _workflow_model.Node): """A class encapsulating a remote Flyte node.""" @@ -355,10 +365,11 @@ def __init__( workflow_node: Optional[FlyteWorkflowNode] = None, branch_node: Optional[FlyteBranchNode] = None, gate_node: Optional[FlyteGateNode] = None, + array_node: Optional[FlyteArrayNode] = None, ): - if not task_node and not workflow_node and not branch_node and not gate_node: + if not task_node and not workflow_node and not branch_node and not gate_node and not array_node: raise _user_exceptions.FlyteAssertion( - "An Flyte node must have one of task|workflow|branch|gate entity specified at once" + "An Flyte node must have one of task|workflow|branch|gate|array entity specified at once" ) # TODO: Revisit flyte_branch_node and flyte_gate_node, should they be another type like Condition instead # of a node? @@ -368,7 +379,7 @@ def __init__( elif workflow_node: self._flyte_entity = workflow_node.flyte_workflow or workflow_node.flyte_launch_plan else: - self._flyte_entity = branch_node or gate_node + self._flyte_entity = branch_node or gate_node or array_node super(FlyteNode, self).__init__( id=id, @@ -380,6 +391,7 @@ def __init__( workflow_node=workflow_node, branch_node=branch_node, gate_node=gate_node, + array_node=array_node, ) self._upstream = upstream_nodes @@ -427,7 +439,13 @@ def promote_from_model( remote_logger.warning(f"Should not call promote from model on a start node or end node {model}") return None, converted_sub_workflows - flyte_task_node, flyte_workflow_node, flyte_branch_node, flyte_gate_node = None, None, None, None + flyte_task_node, flyte_workflow_node, flyte_branch_node, flyte_gate_node, flyte_array_node = ( + None, + None, + None, + None, + None, + ) if model.task_node is not None: if model.task_node.reference_id not in tasks: raise RuntimeError( @@ -452,6 +470,9 @@ def promote_from_model( ) elif model.gate_node is not None: flyte_gate_node = FlyteGateNode.promote_from_model(model.gate_node) + elif model.array_node is not None: + flyte_array_node = FlyteArrayNode.promote_from_model(model.array_node) + # TODO: validate task in tasks else: raise _system_exceptions.FlyteSystemException( f"Bad Node model, neither task nor workflow detected, node: {model}" @@ -477,6 +498,7 @@ def promote_from_model( workflow_node=flyte_workflow_node, branch_node=flyte_branch_node, gate_node=flyte_gate_node, + array_node=flyte_array_node, ), converted_sub_workflows, ) @@ -654,7 +676,6 @@ def promote_from_model( tasks: Optional[Dict[Identifier, FlyteTask]] = None, node_launch_plans: Optional[Dict[Identifier, launch_plan_models.LaunchPlanSpec]] = None, ) -> FlyteWorkflow: - base_model_non_system_nodes = cls.get_non_system_nodes(base_model.nodes) node_map = {} @@ -712,8 +733,7 @@ def promote_from_closure( 
:param closure: This is the closure returned by Admin :param node_launch_plans: The reason this exists is because the compiled closure doesn't have launch plans. - It only has subworkflows and tasks. Why this is unclear. If supplied, this map of launch plans will be - :return: + It only has subworkflows and tasks. Why this is unclear. If supplied, this map of launch plans will be """ sub_workflows = {sw.template.id: sw.template for sw in closure.sub_workflows} tasks = {} diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 59cdc8c212..78fa3271e7 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -33,7 +33,7 @@ from flytekit.core.python_auto_container import PythonAutoContainerTask from flytekit.core.reference_entity import ReferenceSpec from flytekit.core.type_engine import LiteralsResolver, TypeEngine -from flytekit.core.workflow import WorkflowBase +from flytekit.core.workflow import WorkflowBase, WorkflowFailurePolicy from flytekit.exceptions import user as user_exceptions from flytekit.exceptions.user import ( FlyteEntityAlreadyExistsException, @@ -54,12 +54,14 @@ from flytekit.models.core.identifier import Identifier, ResourceType, SignalIdentifier, WorkflowExecutionIdentifier from flytekit.models.core.workflow import NodeMetadata from flytekit.models.execution import ( + ClusterAssignment, ExecutionMetadata, ExecutionSpec, NodeExecutionGetDataResponse, NotificationList, WorkflowExecutionGetDataResponse, ) +from flytekit.models.launch_plan import LaunchPlanState from flytekit.models.literals import Literal, LiteralMap from flytekit.remote.backfill import create_backfill_workflow from flytekit.remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteTaskNode, FlyteWorkflow @@ -790,20 +792,24 @@ def upload_file( filename=to_upload.name, ) + extra_headers = self.get_extra_headers_for_protocol(upload_location.native_url) encoded_md5 = b64encode(md5_bytes) with open(str(to_upload), "+rb") as local_file: content = local_file.read() content_length = len(content) + headers = {"Content-Length": str(content_length), "Content-MD5": encoded_md5} + headers.update(extra_headers) rsp = requests.put( upload_location.signed_url, data=content, - headers={"Content-Length": str(content_length), "Content-MD5": encoded_md5}, + headers=headers, verify=False if self._config.platform.insecure_skip_verify is True else self._config.platform.ca_cert_file_path, ) - if rsp.status_code != requests.codes["OK"]: + # Check both HTTP 201 and 200, because some storage backends (e.g. Azure) return 201 instead of 200. + if rsp.status_code not in (requests.codes["OK"], requests.codes["created"]): raise FlyteValueException( rsp.status_code, f"Request to send data {upload_location.signed_url} failed.", @@ -949,12 +955,15 @@ def _execute( inputs: typing.Dict[str, typing.Any], project: str = None, domain: str = None, - execution_name: str = None, + execution_name: typing.Optional[str] = None, + execution_name_prefix: typing.Optional[str] = None, options: typing.Optional[Options] = None, wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, + tags: typing.Optional[typing.List[str]] = None, + cluster_pool: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """Common method for execution across all entities. @@ -970,9 +979,14 @@ def _execute( for a single execution. 
If enabled, all calculations are performed even if cached results would be available, overwriting the stored data once execution finishes successfully. :param envs: Environment variables to set for the execution. + :param tags: Tags to set for the execution. + :param cluster_pool: Specify cluster pool on which newly created execution should be placed. :returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution` """ - execution_name = execution_name or "f" + uuid.uuid4().hex[:19] + if execution_name is not None and execution_name_prefix is not None: + raise ValueError("Only one of execution_name and execution_name_prefix can be set, but got both set") + execution_name_prefix = execution_name_prefix + "-" if execution_name_prefix is not None else None + execution_name = execution_name or (execution_name_prefix or "f") + uuid.uuid4().hex[:19] if not options: options = Options() if options.disable_notifications is not None: @@ -1035,6 +1049,8 @@ def _execute( max_parallelism=options.max_parallelism, security_context=options.security_context, envs=common_models.Envs(envs) if envs else None, + tags=tags, + cluster_assignment=ClusterAssignment(cluster_pool=cluster_pool) if cluster_pool else None, ), literal_inputs, ) @@ -1085,13 +1101,16 @@ def execute( domain: str = None, name: str = None, version: str = None, - execution_name: str = None, + execution_name: typing.Optional[str] = None, + execution_name_prefix: typing.Optional[str] = None, image_config: typing.Optional[ImageConfig] = None, options: typing.Optional[Options] = None, wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, + tags: typing.Optional[typing.List[str]] = None, + cluster_pool: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """ Execute a task, workflow, or launchplan, either something that's been declared locally, or a fetched entity. @@ -1129,6 +1148,8 @@ def execute( for a single execution. If enabled, all calculations are performed even if cached results would be available, overwriting the stored data once execution finishes successfully. :param envs: Environment variables to be set for the execution. + :param tags: Tags to be set for the execution. + :param cluster_pool: Specify cluster pool on which newly created execution should be placed. .. 
note: @@ -1144,11 +1165,14 @@ def execute( project=project, domain=domain, execution_name=execution_name, + execution_name_prefix=execution_name_prefix, options=options, wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, envs=envs, + tags=tags, + cluster_pool=cluster_pool, ) if isinstance(entity, FlyteWorkflow): return self.execute_remote_wf( @@ -1157,11 +1181,14 @@ def execute( project=project, domain=domain, execution_name=execution_name, + execution_name_prefix=execution_name_prefix, options=options, wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, envs=envs, + tags=tags, + cluster_pool=cluster_pool, ) if isinstance(entity, PythonTask): return self.execute_local_task( @@ -1172,10 +1199,13 @@ def execute( name=name, version=version, execution_name=execution_name, + execution_name_prefix=execution_name_prefix, image_config=image_config, wait=wait, overwrite_cache=overwrite_cache, envs=envs, + tags=tags, + cluster_pool=cluster_pool, ) if isinstance(entity, WorkflowBase): return self.execute_local_workflow( @@ -1186,11 +1216,14 @@ def execute( name=name, version=version, execution_name=execution_name, + execution_name_prefix=execution_name_prefix, image_config=image_config, options=options, wait=wait, overwrite_cache=overwrite_cache, envs=envs, + tags=tags, + cluster_pool=cluster_pool, ) if isinstance(entity, LaunchPlan): return self.execute_local_launch_plan( @@ -1200,10 +1233,13 @@ def execute( project=project, domain=domain, execution_name=execution_name, + execution_name_prefix=execution_name_prefix, options=options, wait=wait, overwrite_cache=overwrite_cache, envs=envs, + tags=tags, + cluster_pool=cluster_pool, ) raise NotImplementedError(f"entity type {type(entity)} not recognized for execution") @@ -1216,12 +1252,15 @@ def execute_remote_task_lp( inputs: typing.Dict[str, typing.Any], project: str = None, domain: str = None, - execution_name: str = None, + execution_name: typing.Optional[str] = None, + execution_name_prefix: typing.Optional[str] = None, options: typing.Optional[Options] = None, wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, + tags: typing.Optional[typing.List[str]] = None, + cluster_pool: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """Execute a FlyteTask, or FlyteLaunchplan. @@ -1233,11 +1272,14 @@ def execute_remote_task_lp( project=project, domain=domain, execution_name=execution_name, + execution_name_prefix=execution_name_prefix, wait=wait, options=options, type_hints=type_hints, overwrite_cache=overwrite_cache, envs=envs, + tags=tags, + cluster_pool=cluster_pool, ) def execute_remote_wf( @@ -1246,12 +1288,15 @@ def execute_remote_wf( inputs: typing.Dict[str, typing.Any], project: str = None, domain: str = None, - execution_name: str = None, + execution_name: typing.Optional[str] = None, + execution_name_prefix: typing.Optional[str] = None, options: typing.Optional[Options] = None, wait: bool = False, type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None, overwrite_cache: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, + tags: typing.Optional[typing.List[str]] = None, + cluster_pool: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """Execute a FlyteWorkflow. 
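
As a usage illustration (not part of this diff), a sketch of how the new `execution_name_prefix`, `tags`, and `cluster_pool` arguments might be passed through `FlyteRemote.execute`. The project, domain, and workflow names are hypothetical.

from flytekit.configuration import Config
from flytekit.remote import FlyteRemote

remote = FlyteRemote(Config.auto(), default_project="flytesnacks", default_domain="development")
wf = remote.fetch_workflow(name="workflows.example.wf", version="v1")

execution = remote.execute(
    wf,
    inputs={"n": 3},
    execution_name_prefix="nightly",   # a unique suffix is appended, so names stay distinct
    tags=["team-data", "backfill"],    # forwarded to ExecutionSpec.tags
    cluster_pool="gpu-pool",           # forwarded as ClusterAssignment(cluster_pool="gpu-pool")
    wait=True,
)
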
@@ -1264,11 +1309,14 @@ def execute_remote_wf( project=project, domain=domain, execution_name=execution_name, + execution_name_prefix=execution_name_prefix, options=options, wait=wait, type_hints=type_hints, overwrite_cache=overwrite_cache, envs=envs, + tags=tags, + cluster_pool=cluster_pool, ) # Flytekit Entities @@ -1282,11 +1330,14 @@ def execute_local_task( domain: str = None, name: str = None, version: str = None, - execution_name: str = None, + execution_name: typing.Optional[str] = None, + execution_name_prefix: typing.Optional[str] = None, image_config: typing.Optional[ImageConfig] = None, wait: bool = False, overwrite_cache: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, + tags: typing.Optional[typing.List[str]] = None, + cluster_pool: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """ Execute a @task-decorated function or TaskTemplate task. @@ -1302,6 +1353,8 @@ def execute_local_task( :param wait: If True, will wait for the execution to complete before returning. :param overwrite_cache: If True, will overwrite the cache. :param envs: Environment variables to set for the execution. + :param tags: Tags to set for the execution. + :param cluster_pool: Specify cluster pool on which newly created execution should be placed. :return: FlyteWorkflowExecution object. """ resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version) @@ -1326,10 +1379,13 @@ def execute_local_task( project=resolved_identifiers.project, domain=resolved_identifiers.domain, execution_name=execution_name, + execution_name_prefix=execution_name_prefix, wait=wait, type_hints=entity.python_interface.inputs, overwrite_cache=overwrite_cache, envs=envs, + tags=tags, + cluster_pool=cluster_pool, ) def execute_local_workflow( @@ -1340,12 +1396,15 @@ def execute_local_workflow( domain: str = None, name: str = None, version: str = None, - execution_name: str = None, + execution_name: typing.Optional[str] = None, + execution_name_prefix: typing.Optional[str] = None, image_config: typing.Optional[ImageConfig] = None, options: typing.Optional[Options] = None, wait: bool = False, overwrite_cache: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, + tags: typing.Optional[typing.List[str]] = None, + cluster_pool: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """ Execute an @workflow decorated function. 
@@ -1361,6 +1420,8 @@ def execute_local_workflow( :param wait: :param overwrite_cache: :param envs: + :param tags: + :param cluster_pool: :return: """ resolved_identifiers = self._resolve_identifier_kwargs(entity, project, domain, name, version) @@ -1402,11 +1463,14 @@ def execute_local_workflow( project=project, domain=domain, execution_name=execution_name, + execution_name_prefix=execution_name_prefix, wait=wait, options=options, type_hints=entity.python_interface.inputs, overwrite_cache=overwrite_cache, envs=envs, + tags=tags, + cluster_pool=cluster_pool, ) def execute_local_launch_plan( @@ -1417,10 +1481,13 @@ def execute_local_launch_plan( project: typing.Optional[str] = None, domain: typing.Optional[str] = None, execution_name: typing.Optional[str] = None, + execution_name_prefix: typing.Optional[str] = None, options: typing.Optional[Options] = None, wait: bool = False, overwrite_cache: typing.Optional[bool] = None, envs: typing.Optional[typing.Dict[str, str]] = None, + tags: typing.Optional[typing.List[str]] = None, + cluster_pool: typing.Optional[str] = None, ) -> FlyteWorkflowExecution: """ @@ -1434,6 +1501,8 @@ def execute_local_launch_plan( :param wait: If True, will wait for the execution to complete before returning. :param overwrite_cache: If True, will overwrite the cache. :param envs: Environment variables to be passed into the execution. + :param tags: Tags to be passed into the execution. + :param cluster_pool: Specify cluster pool on which newly created execution should be placed. :return: FlyteWorkflowExecution object """ try: @@ -1456,11 +1525,14 @@ def execute_local_launch_plan( project=project, domain=domain, execution_name=execution_name, + execution_name_prefix=execution_name_prefix, options=options, wait=wait, type_hints=entity.python_interface.inputs, overwrite_cache=overwrite_cache, envs=envs, + tags=tags, + cluster_pool=cluster_pool, ) ################################### @@ -1593,16 +1665,16 @@ def sync_node_execution( Get data backing a node execution. These FlyteNodeExecution objects should've come from Admin with the model fields already populated correctly. For purposes of the remote experience, we'd like to supplement the object with some additional fields: - - inputs/outputs - - task/workflow executions, and/or underlying node executions in the case of parent nodes - - TypedInterface (remote wrapper type) + - inputs/outputs + - task/workflow executions, and/or underlying node executions in the case of parent nodes + - TypedInterface (remote wrapper type) A node can have several different types of executions behind it. That is, the node could've run (perhaps multiple times because of retries): - - A task - - A static subworkflow - - A dynamic subworkflow (which in turn may have run additional tasks, subwfs, and/or launch plans) - - A launch plan + - A task + - A static subworkflow + - A dynamic subworkflow (which in turn may have run additional tasks, subwfs, and/or launch plans) + - A launch plan The data model is complicated, so ascertaining which of these happened is a bit tricky. That logic is encapsulated in this function. @@ -1851,15 +1923,17 @@ def launch_backfill( dry_run: bool = False, execute: bool = True, parallel: bool = False, + failure_policy: typing.Optional[WorkflowFailurePolicy] = None, ) -> typing.Optional[FlyteWorkflowExecution, FlyteWorkflow, WorkflowBase]: """ Creates and launches a backfill workflow for the given launchplan. If launchplan version is not specified, then the latest launchplan is retrieved. 
- The from_date is exclusive and end_date is inclusive and backfill run for all instances in between. + The from_date is exclusive and end_date is inclusive and backfill run for all instances in between. :: -> (start_date - exclusive, end_date inclusive) - If dry_run is specified, the workflow is created and returned - if execute==False is specified then the workflow is created and registered - in the last case, the workflow is created, registered and executed. + + If dry_run is specified, the workflow is created and returned. + If execute==False is specified then the workflow is created and registered. + In the last case, the workflow is created, registered and executed. The `parallel` flag can be used to generate a workflow where all launchplans can be run in parallel. Default is that execute backfill is run sequentially @@ -1874,12 +1948,16 @@ def launch_backfill( :param version: str (optional) version to be used for the newly created workflow. :param dry_run: bool do not register or execute the workflow :param execute: bool Register and execute the wwkflow. - :param parallel: if the backfill should be run in parallel. False (default) will run each bacfill sequentially + :param parallel: if the backfill should be run in parallel. False (default) will run each bacfill sequentially. + :param failure_policy: WorkflowFailurePolicy (optional) to be used for the newly created workflow. This can + control failure behavior - whether to continue on failure or stop immediately on failure :return: In case of dry-run, return WorkflowBase, else if no_execute return FlyteWorkflow else in the default - case return a FlyteWorkflowExecution + case return a FlyteWorkflowExecution """ lp = self.fetch_launch_plan(project=project, domain=domain, name=launchplan, version=launchplan_version) - wf, start, end = create_backfill_workflow(start_date=from_date, end_date=to_date, for_lp=lp, parallel=parallel) + wf, start, end = create_backfill_workflow( + start_date=from_date, end_date=to_date, for_lp=lp, parallel=parallel, failure_policy=failure_policy + ) if dry_run: remote_logger.warning("Dry Run enabled. Workflow will not be registered and or executed.") return wf @@ -1902,3 +1980,15 @@ def launch_backfill( return remote_wf return self.execute(remote_wf, inputs={}, project=project, domain=domain, execution_name=execution_name) + + @staticmethod + def get_extra_headers_for_protocol(native_url): + if native_url.startswith("abfs://"): + return {"x-ms-blob-type": "BlockBlob"} + return {} + + def activate_launchplan(self, ident: Identifier): + """ + Given a launchplan, activate it, all previous versions are deactivated. 
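
A sketch (not part of this diff) of how the new `failure_policy` argument could be used when launching a backfill. The dates and launch plan name are hypothetical, and `remote` is assumed to be an already configured `FlyteRemote` instance.

from datetime import datetime

from flytekit.core.workflow import WorkflowFailurePolicy

execution = remote.launch_backfill(
    project="flytesnacks",
    domain="development",
    from_date=datetime(2023, 1, 1),
    to_date=datetime(2023, 1, 31),
    launchplan="daily_report_lp",
    parallel=True,
    # Keep running the remaining backfill dates even if one instance fails.
    failure_policy=WorkflowFailurePolicy.FAIL_AFTER_EXECUTABLE_NODES_COMPLETE,
)
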
+ """ + self.client.update_launch_plan(id=ident, state=LaunchPlanState.ACTIVE) diff --git a/flytekit/sensor/__init__.py b/flytekit/sensor/__init__.py new file mode 100644 index 0000000000..796088e74d --- /dev/null +++ b/flytekit/sensor/__init__.py @@ -0,0 +1,3 @@ +from .base_sensor import BaseSensor +from .file_sensor import FileSensor +from .sensor_engine import SensorEngine diff --git a/flytekit/sensor/base_sensor.py b/flytekit/sensor/base_sensor.py new file mode 100644 index 0000000000..60beb6aa2b --- /dev/null +++ b/flytekit/sensor/base_sensor.py @@ -0,0 +1,66 @@ +import collections +import inspect +from abc import abstractmethod +from typing import Any, Dict, Optional, TypeVar + +import jsonpickle +from typing_extensions import get_type_hints + +from flytekit.configuration import SerializationSettings +from flytekit.core.base_task import PythonTask +from flytekit.core.interface import Interface +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin + +T = TypeVar("T") +SENSOR_MODULE = "sensor_module" +SENSOR_NAME = "sensor_name" +SENSOR_CONFIG_PKL = "sensor_config_pkl" +INPUTS = "inputs" + + +class BaseSensor(AsyncAgentExecutorMixin, PythonTask): + """ + Base class for all sensors. Sensors are tasks that are designed to run forever, and periodically check for some + condition to be met. When the condition is met, the sensor will complete. Sensors are designed to be run by the + sensor agent, and not by the Flyte engine. + """ + + def __init__( + self, + name: str, + sensor_config: Optional[T] = None, + task_type: str = "sensor", + **kwargs, + ): + type_hints = get_type_hints(self.poke, include_extras=True) + signature = inspect.signature(self.poke) + inputs = collections.OrderedDict() + for k, v in signature.parameters.items(): # type: ignore + annotation = type_hints.get(k, None) + inputs[k] = annotation + + super().__init__( + task_type=task_type, + name=name, + task_config=None, + interface=Interface(inputs=inputs), + **kwargs, + ) + self._sensor_config = sensor_config + + @abstractmethod + async def poke(self, **kwargs) -> bool: + """ + This method should be overridden by the user to implement the actual sensor logic. This method should return + ``True`` if the sensor condition is met, else ``False``. 
+ """ + raise NotImplementedError + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + cfg = { + SENSOR_MODULE: type(self).__module__, + SENSOR_NAME: type(self).__name__, + } + if self._sensor_config is not None: + cfg[SENSOR_CONFIG_PKL] = jsonpickle.encode(self._sensor_config) + return cfg diff --git a/flytekit/sensor/file_sensor.py b/flytekit/sensor/file_sensor.py new file mode 100644 index 0000000000..2fb3d64ec1 --- /dev/null +++ b/flytekit/sensor/file_sensor.py @@ -0,0 +1,18 @@ +from typing import Optional, TypeVar + +from flytekit import FlyteContextManager +from flytekit.sensor.base_sensor import BaseSensor + +T = TypeVar("T") + + +class FileSensor(BaseSensor): + def __init__(self, name: str, config: Optional[T] = None, **kwargs): + super().__init__(name=name, sensor_config=config, **kwargs) + + async def poke(self, path: str) -> bool: + file_access = FlyteContextManager.current_context().file_access + fs = file_access.get_filesystem_for_path(path, asynchronous=True) + if file_access.is_remote(path): + return await fs._exists(path) + return fs.exists(path) diff --git a/flytekit/sensor/sensor_engine.py b/flytekit/sensor/sensor_engine.py new file mode 100644 index 0000000000..79d2e0f4b4 --- /dev/null +++ b/flytekit/sensor/sensor_engine.py @@ -0,0 +1,62 @@ +import importlib +import typing +from typing import Optional + +import cloudpickle +import grpc +import jsonpickle +from flyteidl.admin.agent_pb2 import ( + RUNNING, + SUCCEEDED, + CreateTaskResponse, + DeleteTaskResponse, + GetTaskResponse, + Resource, +) + +from flytekit import FlyteContextManager +from flytekit.core.type_engine import TypeEngine +from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate +from flytekit.sensor.base_sensor import INPUTS, SENSOR_CONFIG_PKL, SENSOR_MODULE, SENSOR_NAME + +T = typing.TypeVar("T") + + +class SensorEngine(AgentBase): + def __init__(self): + super().__init__(task_type="sensor", asynchronous=True) + + async def async_create( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + ) -> CreateTaskResponse: + python_interface_inputs = { + name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() + } + ctx = FlyteContextManager.current_context() + if inputs: + native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) + task_template.custom[INPUTS] = native_inputs + return CreateTaskResponse(resource_meta=cloudpickle.dumps(task_template.custom)) + + async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: + meta = cloudpickle.loads(resource_meta) + + sensor_module = importlib.import_module(name=meta[SENSOR_MODULE]) + sensor_def = getattr(sensor_module, meta[SENSOR_NAME]) + sensor_config = jsonpickle.decode(meta[SENSOR_CONFIG_PKL]) if meta.get(SENSOR_CONFIG_PKL) else None + + inputs = meta.get(INPUTS, {}) + cur_state = SUCCEEDED if await sensor_def("sensor", config=sensor_config).poke(**inputs) else RUNNING + return GetTaskResponse(resource=Resource(state=cur_state, outputs=None)) + + async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + return DeleteTaskResponse() + + +AgentRegistry.register(SensorEngine()) diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index e55350d3ae..5473b4cef8 100644 --- 
a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -183,7 +183,7 @@ def load_packages_and_modules( return registrable_entities -def secho(i: Identifier, state: str = "success", reason: str = None): +def secho(i: Identifier, state: str = "success", reason: str = None, op: str = "Registration"): state_ind = "[ ]" fg = "white" nl = False @@ -198,7 +198,7 @@ def secho(i: Identifier, state: str = "success", reason: str = None): nl = True reason = "skipped!" click.secho( - click.style(f"{state_ind}", fg=fg) + f" Registration {i.name} type {i.resource_type_name()} {reason}", + click.style(f"{state_ind}", fg=fg) + f" {op} {i.name} type {i.resource_type_name()} {reason}", dim=True, nl=nl, ) @@ -218,6 +218,7 @@ def register( package_or_module: typing.Tuple[str], remote: FlyteRemote, dry_run: bool = False, + activate_launchplans: bool = False, ): detected_root = find_common_root(package_or_module) click.secho(f"Detected Root {detected_root}, using this to create deployable package...", fg="yellow") @@ -262,7 +263,12 @@ def register( return for cp_entity in registrable_entities: - og_id = cp_entity.id if isinstance(cp_entity, launch_plan.LaunchPlan) else cp_entity.template.id + is_lp = False + if isinstance(cp_entity, launch_plan.LaunchPlan): + og_id = cp_entity.id + is_lp = True + else: + og_id = cp_entity.template.id secho(og_id, "") try: if not dry_run: @@ -270,6 +276,10 @@ def register( cp_entity, serialization_settings, version=version, create_default_launchplan=False ) secho(i) + if is_lp and activate_launchplans: + secho(og_id, "", op="Activation") + remote.activate_launchplan(i) + secho(i, reason="activated", op="Activation") else: secho(og_id, reason="Dry run Mode!") except RegistrationSkipped: diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index b2835dca10..87ccd2f534 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -7,6 +7,7 @@ from flytekit import PythonFunctionTask, SourceCode from flytekit.configuration import SerializationSettings from flytekit.core import constants as _common_constants +from flytekit.core.array_node_map_task import ArrayNodeMapTask from flytekit.core.base_task import PythonTask from flytekit.core.condition import BranchNode from flytekit.core.container_task import ContainerTask @@ -30,6 +31,7 @@ from flytekit.models.core import workflow as _core_wf from flytekit.models.core import workflow as workflow_model from flytekit.models.core.workflow import ApproveCondition +from flytekit.models.core.workflow import ArrayNode as ArrayNodeModel from flytekit.models.core.workflow import BranchNode as BranchNodeModel from flytekit.models.core.workflow import GateNode, SignalCondition, SleepCondition, TaskNodeOverrides from flytekit.models.task import TaskSpec, TaskTemplate @@ -51,6 +53,7 @@ admin_workflow_models.WorkflowSpec, workflow_model.Node, BranchNodeModel, + ArrayNodeModel, ] @@ -67,11 +70,11 @@ class Options(object): annotations: Custom annotations to be applied to the execution resource security_context: Indicates security context for permissions triggered with this launch plan raw_output_data_config: Optional location of offloaded data for things like S3, etc. - remote prefix for storage location of the form ``s3:///key...`` or - ``gcs://...`` or ``file://...``. If not specified will use the platform configured default. This is where - the data for offloaded types is stored. + remote prefix for storage location of the form ``s3:///key...`` or + ``gcs://...`` or ``file://...``. 
If not specified will use the platform configured default. This is where + the data for offloaded types is stored. max_parallelism: Controls the maximum number of tasknodes that can be run in parallel for the entire workflow. - notifications: List of notifications for this execution + notifications: List of notifications for this execution. disable_notifications: This should be set to true if all notifications are intended to be disabled for this execution. """ @@ -180,7 +183,7 @@ def get_serializable_task( if settings.should_fast_serialize(): # This handles container tasks. - if container and isinstance(entity, (PythonAutoContainerTask, MapPythonTask)): + if container and isinstance(entity, (PythonAutoContainerTask, MapPythonTask, ArrayNodeMapTask)): # For fast registration, we'll need to muck with the command, but on # ly for certain kinds of tasks. Specifically, # tasks that rely on user code defined in the container. This should be encapsulated by the auto container @@ -191,7 +194,7 @@ def get_serializable_task( # The reason we have to call get_k8s_pod again, instead of just modifying the command in this file, is because # the pod spec is a K8s library object, and we shouldn't be messing around with it in this file. elif pod and not isinstance(entity, ContainerTask): - if isinstance(entity, MapPythonTask): + if isinstance(entity, (MapPythonTask, ArrayNodeMapTask)): entity.set_command_prefix(get_command_prefix_for_fast_execute(settings)) pod = entity.get_k8s_pod(settings) else: @@ -416,7 +419,19 @@ def get_serializable_node( from flytekit.remote import FlyteLaunchPlan, FlyteTask, FlyteWorkflow - if isinstance(entity.flyte_entity, PythonTask): + if isinstance(entity.flyte_entity, ArrayNodeMapTask): + node_model = workflow_model.Node( + id=_dnsify(entity.id), + metadata=entity.metadata, + inputs=entity.bindings, + upstream_node_ids=[n.id for n in upstream_nodes], + output_aliases=[], + array_node=get_serializable_array_node(entity_mapping, settings, entity, options=options), + ) + # TODO: do I need this? 
+ # if entity._aliases: + # node_model._output_aliases = entity._aliases + elif isinstance(entity.flyte_entity, PythonTask): task_spec = get_serializable(entity_mapping, settings, entity.flyte_entity, options=options) node_model = workflow_model.Node( id=_dnsify(entity.id), @@ -539,6 +554,35 @@ def get_serializable_node( return node_model +def get_serializable_array_node( + entity_mapping: OrderedDict, + settings: SerializationSettings, + node: Node, + options: Optional[Options] = None, +) -> ArrayNodeModel: + # TODO Add support for other flyte entities + entity = node.flyte_entity + task_spec = get_serializable(entity_mapping, settings, entity, options) + task_node = workflow_model.TaskNode( + reference_id=task_spec.template.id, + overrides=TaskNodeOverrides(resources=node._resources), + ) + node = workflow_model.Node( + id=entity.name, + metadata=entity.construct_node_metadata(), + inputs=node.bindings, + upstream_node_ids=[], + output_aliases=[], + task_node=task_node, + ) + return ArrayNodeModel( + node=node, + parallelism=entity.concurrency, + min_successes=entity.min_successes, + min_success_ratio=entity.min_success_ratio, + ) + + def get_serializable_branch_node( entity_mapping: OrderedDict, settings: SerializationSettings, diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index f4f23eb72f..fe0e7cfa7c 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -10,12 +10,12 @@ from uuid import UUID import fsspec -from dataclasses_json import config, dataclass_json +from dataclasses_json import DataClassJsonMixin, config from fsspec.utils import get_protocol from marshmallow import fields from flytekit.core.context_manager import FlyteContext, FlyteContextManager -from flytekit.core.type_engine import TypeEngine, TypeTransformer +from flytekit.core.type_engine import TypeEngine, TypeTransformer, get_batch_size from flytekit.models import types as _type_models from flytekit.models.core import types as _core_types from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar @@ -30,9 +30,8 @@ def noop(): ... -@dataclass_json @dataclass -class FlyteDirectory(os.PathLike, typing.Generic[T]): +class FlyteDirectory(DataClassJsonMixin, os.PathLike, typing.Generic[T]): path: PathType = field(default=None, metadata=config(mm_field=fields.String())) # type: ignore """ .. warning:: @@ -322,6 +321,8 @@ def to_literal( remote_directory = None should_upload = True + batch_size = get_batch_size(python_type) + meta = BlobMetadata(type=self._blob_type(format=self.get_format(python_type))) # There are two kinds of literals we handle, either an actual FlyteDirectory, or a string path to a directory. @@ -358,7 +359,7 @@ def to_literal( if should_upload: if remote_directory is None: remote_directory = ctx.file_access.get_random_remote_directory() - ctx.file_access.put_data(source_path, remote_directory, is_multipart=True) + ctx.file_access.put_data(source_path, remote_directory, is_multipart=True, batch_size=batch_size) return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_directory))) # If not uploading, then we can only take the original source path as the uri. 
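
A sketch (not part of this diff) of where the new `batch_size` plumbing comes from: `get_batch_size` reads a `BatchSize` annotation off the declared type so directory transfers can be chunked. It assumes `BatchSize` is importable from `flytekit.core.type_engine`; the returned directory path is hypothetical.

from typing_extensions import Annotated

from flytekit import task
from flytekit.core.type_engine import BatchSize  # assumed location of the annotation class
from flytekit.types.directory import FlyteDirectory


@task
def produce_dir() -> Annotated[FlyteDirectory, BatchSize(100)]:
    # With this annotation the directory contents would be uploaded to blob storage
    # in batches of 100 files rather than all at once.
    return FlyteDirectory("/tmp/generated_reports")  # hypothetical local directory
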
@@ -379,8 +380,10 @@ def to_python_value( # For the remote case, return an FlyteDirectory object that can download local_folder = ctx.file_access.get_random_local_directory() + batch_size = get_batch_size(expected_python_type) + def _downloader(): - return ctx.file_access.get_data(uri, local_folder, is_multipart=True) + return ctx.file_access.get_data(uri, local_folder, is_multipart=True, batch_size=batch_size) expected_format = self.get_format(expected_python_type) diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index d78ec152d7..5928c7a377 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -6,8 +6,9 @@ from contextlib import contextmanager from dataclasses import dataclass, field -from dataclasses_json import config, dataclass_json +from dataclasses_json import config from marshmallow import fields +from mashumaro.mixins.json import DataClassJSONMixin from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError, get_underlying_type @@ -25,9 +26,8 @@ def noop(): T = typing.TypeVar("T") -@dataclass_json @dataclass -class FlyteFile(os.PathLike, typing.Generic[T]): +class FlyteFile(os.PathLike, typing.Generic[T], DataClassJSONMixin): path: typing.Union[str, os.PathLike] = field( default=None, metadata=config(mm_field=fields.String()) ) # type: ignore @@ -76,8 +76,6 @@ class FlyteFile(os.PathLike, typing.Generic[T]): Flyte blob store. So no remote paths are uploaded. Flytekit considers a path remote if it starts with ``s3://``, ``gs://``, ``http(s)://``, or even ``file://``. - ----------- - **Converting from a Flyte literal value to a Python instance of FlyteFile** +-------------+---------------+---------------------------------------------+--------------------------------------+ @@ -107,8 +105,6 @@ class FlyteFile(os.PathLike, typing.Generic[T]): | | | * remote_source: None | | +-------------+---------------+---------------------------------------------+--------------------------------------+ - ----------- - **Converting from a Python value (FlyteFile, str, or pathlib.Path) to a Flyte literal** +-------------+---------------+---------------------------------------------+--------------------------------------+ @@ -189,7 +185,9 @@ def __init__( remote_path: typing.Optional[os.PathLike] = None, ): """ - :param path: The source path that users are expected to call open() on + FlyteFile's init method. + + :param path: The source path that users are expected to call open() on. :param downloader: Optional function that can be passed that used to delay downloading of the actual fil until a user actually calls open(). :param remote_path: If the user wants to return something and also specify where it should be uploaded to. @@ -258,7 +256,7 @@ def copy_file(ff: FlyteFile) -> FlyteFile: w.write(r.read()) return new_file - Alternatively + Alternatively, .. code-block:: python @@ -273,10 +271,10 @@ def copy_file(ff: FlyteFile) -> FlyteFile: :param mode: str Open mode like 'rb', 'rt', 'wb', ... :param cache_type: optional str Specify if caching is to be used. Cache protocol can be ones supported by - fsspec https://filesystem-spec.readthedocs.io/en/latest/api.html#readbuffering, - especially useful for large file reads + fsspec https://filesystem-spec.readthedocs.io/en/latest/api.html#readbuffering, + especially useful for large file reads :param cache_options: optional Dict[str, Any] Refer to fsspec caching options. 
This is strongly coupled to the - cache_protocol + cache_protocol """ ctx = FlyteContextManager.current_context() final_path = self.path @@ -406,6 +404,7 @@ def to_python_value( uri = lv.scalar.blob.uri except AttributeError: raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") + # In this condition, we still return a FlyteFile instance, but it's a simple one that has no downloading tricks # Using is instead of issubclass because FlyteFile does actually subclass it if expected_python_type is os.PathLike: diff --git a/flytekit/types/schema/types.py b/flytekit/types/schema/types.py index 7ac98d27c6..bba099a57e 100644 --- a/flytekit/types/schema/types.py +++ b/flytekit/types/schema/types.py @@ -11,8 +11,9 @@ import numpy as _np import pandas -from dataclasses_json import config, dataclass_json +from dataclasses_json import config from marshmallow import fields +from mashumaro.mixins.json import DataClassJSONMixin from flytekit.core.context_manager import FlyteContext, FlyteContextManager from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError @@ -179,9 +180,8 @@ def get_handler(cls, t: Type) -> SchemaHandler: return cls._SCHEMA_HANDLERS[t] -@dataclass_json @dataclass -class FlyteSchema(object): +class FlyteSchema(DataClassJSONMixin): remote_path: typing.Optional[str] = field(default=None, metadata=config(mm_field=fields.String())) """ This is the main schema class that users should use. diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py index 98a12ae44d..2161c5b58a 100644 --- a/flytekit/types/structured/basic_dfs.py +++ b/flytekit/types/structured/basic_dfs.py @@ -12,7 +12,7 @@ from flytekit import FlyteContext, logger from flytekit.configuration import DataConfig -from flytekit.core.data_persistence import s3_setup_args +from flytekit.core.data_persistence import get_fsspec_storage_options from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType @@ -27,12 +27,13 @@ T = TypeVar("T") -def get_storage_options(cfg: DataConfig, uri: str, anon: bool = False) -> typing.Optional[typing.Dict]: - protocol = get_protocol(uri) - if protocol == "s3": - kwargs = s3_setup_args(cfg.s3, anon) - if kwargs: - return kwargs +def get_pandas_storage_options( + uri: str, data_config: DataConfig, anonymous: bool = False +) -> typing.Optional[typing.Dict]: + if pd.io.common.is_fsspec_url(uri): + return get_fsspec_storage_options(protocol=get_protocol(uri), data_config=data_config, anonymous=anonymous) + + # Pandas does not allow storage_options for non-fsspec paths e.g. local. 
return None @@ -54,7 +55,7 @@ def encode( df.to_csv( path, index=False, - storage_options=get_storage_options(ctx.file_access.data_config, path), + storage_options=get_pandas_storage_options(uri=path, data_config=ctx.file_access.data_config), ) structured_dataset_type.format = CSV return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) @@ -72,7 +73,7 @@ def decode( ) -> pd.DataFrame: uri = flyte_value.uri columns = None - kwargs = get_storage_options(ctx.file_access.data_config, uri) + kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config) path = os.path.join(uri, ".csv") if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] @@ -80,7 +81,7 @@ def decode( return pd.read_csv(path, usecols=columns, storage_options=kwargs) except NoCredentialsError: logger.debug("S3 source detected, attempting anonymous S3 access") - kwargs = get_storage_options(ctx.file_access.data_config, uri, anon=True) + kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config, anonymous=True) return pd.read_csv(path, usecols=columns, storage_options=kwargs) @@ -103,7 +104,7 @@ def encode( path, coerce_timestamps="us", allow_truncated_timestamps=False, - storage_options=get_storage_options(ctx.file_access.data_config, path), + storage_options=get_pandas_storage_options(uri=path, data_config=ctx.file_access.data_config), ) structured_dataset_type.format = PARQUET return literals.StructuredDataset(uri=uri, metadata=StructuredDatasetMetadata(structured_dataset_type)) @@ -121,14 +122,14 @@ def decode( ) -> pd.DataFrame: uri = flyte_value.uri columns = None - kwargs = get_storage_options(ctx.file_access.data_config, uri) + kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config) if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] try: return pd.read_parquet(uri, columns=columns, storage_options=kwargs) except NoCredentialsError: logger.debug("S3 source detected, attempting anonymous S3 access") - kwargs = get_storage_options(ctx.file_access.data_config, uri, anon=True) + kwargs = get_pandas_storage_options(uri=uri, data_config=ctx.file_access.data_config, anonymous=True) return pd.read_parquet(uri, columns=columns, storage_options=kwargs) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 5a4ef43d1a..99a0e0832b 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -8,9 +8,10 @@ from typing import Dict, Generator, Optional, Type, Union import _datetime -from dataclasses_json import config, dataclass_json +from dataclasses_json import config from fsspec.utils import get_protocol from marshmallow import fields +from mashumaro.mixins.json import DataClassJSONMixin from typing_extensions import Annotated, TypeAlias, get_args, get_origin from flytekit import lazy_module @@ -43,9 +44,8 @@ GENERIC_PROTOCOL: str = "generic protocol" -@dataclass_json @dataclass -class StructuredDataset(object): +class StructuredDataset(DataClassJSONMixin): """ This is the user facing StructuredDataset class. 
Please don't confuse it with the literals.StructuredDataset class (that is just a model, a Python class representation of the protobuf). diff --git a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py index 0e67b2e50b..e0326f112b 100644 --- a/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py +++ b/plugins/flytekit-aws-batch/flytekitplugins/awsbatch/task.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin from google.protobuf import json_format from google.protobuf.struct_pb2 import Struct @@ -10,9 +10,8 @@ from flytekit.extend import TaskPlugins -@dataclass_json @dataclass -class AWSBatchConfig(object): +class AWSBatchConfig(DataClassJsonMixin): """ Use this to configure SubmitJobInput for a AWS batch job. Task's marked with this will automatically execute natively onto AWS batch service. @@ -27,7 +26,7 @@ class AWSBatchConfig(object): def to_dict(self): s = Struct() - s.update(self.to_dict()) + s.update(super().to_dict()) return json_format.MessageToDict(s) diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py index 345c4ff8ff..6dce099dfb 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/__init__.py @@ -24,7 +24,6 @@ InputMode IntegerParameterRange ParameterRangeOneOf - SagemakerBuiltinAlgorithmsTask SagemakerCustomTrainingTask SagemakerHPOTask SagemakerTrainingJobConfig diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/parameter_ranges.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/parameter_ranges.py index 0df8f42dba..738f1820a2 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/parameter_ranges.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/models/parameter_ranges.py @@ -233,8 +233,9 @@ class ParameterRangeOneOf(_common.FlyteIdlEntity): def __init__(self, param: Union[IntegerParameterRange, ContinuousParameterRange, CategoricalParameterRange]): """ Initializes a new ParameterRangeOneOf. + :param Union[IntegerParameterRange, ContinuousParameterRange, CategoricalParameterRange] param: One of the - supported parameter ranges. + supported parameter ranges. """ self._integer_parameter_range = param if isinstance(param, IntegerParameterRange) else None self._continuous_parameter_range = param if isinstance(param, ContinuousParameterRange) else None diff --git a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/training.py b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/training.py index 9951758a42..7f456d19a0 100644 --- a/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/training.py +++ b/plugins/flytekit-aws-sagemaker/flytekitplugins/awssagemaker/training.py @@ -27,9 +27,9 @@ class SagemakerTrainingJobConfig(object): training_job_resource_config: Configuration for Resources to use during the training algorithm_specification: Specification of the algorithm to use should_persist_output: This method will be invoked and will decide if the generated model should be persisted - as the output. 
``NOTE: Useful only for distributed training`` - ``default: single node training - always persist output`` - ``default: distributed training - always persist output on node with rank-0`` + as the output. ``NOTE: Useful only for distributed training`` + ``default: single node training - always persist output`` + ``default: distributed training - always persist output on node with rank-0`` """ training_job_resource_config: _training_job_models.TrainingJobResourceConfig diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py index 14ff977ee6..4ddb26cdfd 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py @@ -43,7 +43,7 @@ class Metadata: class BigQueryAgent(AgentBase): def __init__(self): - super().__init__(task_type="bigquery_query_job_task") + super().__init__(task_type="bigquery_query_job_task", asynchronous=False) def create( self, diff --git a/plugins/flytekit-bigquery/tests/test_agent.py b/plugins/flytekit-bigquery/tests/test_agent.py index 16b5b7af4d..af53f4031d 100644 --- a/plugins/flytekit-bigquery/tests/test_agent.py +++ b/plugins/flytekit-bigquery/tests/test_agent.py @@ -44,7 +44,7 @@ def __init__(self): mock_instance.cancel_job.return_value = MockJob() ctx = MagicMock(spec=grpc.ServicerContext) - agent = AgentRegistry.get_agent(ctx, "bigquery_query_job_task") + agent = AgentRegistry.get_agent("bigquery_query_job_task") task_id = Identifier( resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" diff --git a/plugins/flytekit-dbt/flytekitplugins/dbt/schema.py b/plugins/flytekit-dbt/flytekitplugins/dbt/schema.py index 3634118b38..6163e440b1 100644 --- a/plugins/flytekit-dbt/flytekitplugins/dbt/schema.py +++ b/plugins/flytekit-dbt/flytekitplugins/dbt/schema.py @@ -2,12 +2,11 @@ from dataclasses import dataclass from typing import List, Optional -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin -@dataclass_json @dataclass -class BaseDBTInput: +class BaseDBTInput(DataClassJsonMixin): """ Base class for DBT Task Input. @@ -76,9 +75,8 @@ def to_args(self) -> List[str]: return args -@dataclass_json @dataclass -class BaseDBTOutput: +class BaseDBTOutput(DataClassJsonMixin): """ Base class for output of DBT task. 
@@ -94,7 +92,6 @@ class BaseDBTOutput: exit_code: int -@dataclass_json @dataclass class DBTRunInput(BaseDBTInput): """ @@ -131,7 +128,6 @@ def to_args(self) -> List[str]: return args -@dataclass_json @dataclass class DBTRunOutput(BaseDBTOutput): """ @@ -149,7 +145,6 @@ class DBTRunOutput(BaseDBTOutput): raw_manifest: str -@dataclass_json @dataclass class DBTTestInput(BaseDBTInput): """ @@ -187,7 +182,6 @@ def to_args(self) -> List[str]: return args -@dataclass_json @dataclass class DBTTestOutput(BaseDBTOutput): """ @@ -205,7 +199,6 @@ class DBTTestOutput(BaseDBTOutput): raw_manifest: str -@dataclass_json @dataclass class DBTFreshnessInput(BaseDBTInput): """ @@ -243,7 +236,6 @@ def to_args(self) -> List[str]: return args -@dataclass_json @dataclass class DBTFreshnessOutput(BaseDBTOutput): """ diff --git a/plugins/flytekit-deck-standard/flytekitplugins/deck/__init__.py b/plugins/flytekit-deck-standard/flytekitplugins/deck/__init__.py index 6752b9b84d..279adb08dd 100644 --- a/plugins/flytekit-deck-standard/flytekitplugins/deck/__init__.py +++ b/plugins/flytekit-deck-standard/flytekitplugins/deck/__init__.py @@ -10,6 +10,8 @@ BoxRenderer FrameProfilingRenderer MarkdownRenderer + ImageRenderer + TableRenderer """ -from .renderer import BoxRenderer, FrameProfilingRenderer, MarkdownRenderer +from .renderer import BoxRenderer, FrameProfilingRenderer, ImageRenderer, MarkdownRenderer, TableRenderer diff --git a/plugins/flytekit-dolt/flytekitplugins/dolt/schema.py b/plugins/flytekit-dolt/flytekitplugins/dolt/schema.py index 8f6867b47f..b5832557ba 100644 --- a/plugins/flytekit-dolt/flytekitplugins/dolt/schema.py +++ b/plugins/flytekit-dolt/flytekitplugins/dolt/schema.py @@ -6,7 +6,7 @@ import dolt_integrations.core as dolt_int import doltcli as dolt import pandas -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin from google.protobuf.json_format import MessageToDict from google.protobuf.struct_pb2 import Struct @@ -17,9 +17,8 @@ from flytekit.models.types import LiteralType -@dataclass_json @dataclass -class DoltConfig: +class DoltConfig(DataClassJsonMixin): db_path: str tablename: typing.Optional[str] = None sql: typing.Optional[str] = None @@ -29,9 +28,8 @@ class DoltConfig: remote_conf: typing.Optional[dolt_int.Remote] = None -@dataclass_json @dataclass -class DoltTable: +class DoltTable(DataClassJsonMixin): config: DoltConfig data: typing.Optional[pandas.DataFrame] = None diff --git a/plugins/flytekit-greatexpectations/flytekitplugins/great_expectations/schema.py b/plugins/flytekit-greatexpectations/flytekitplugins/great_expectations/schema.py index 39d2758417..c57af980a0 100644 --- a/plugins/flytekit-greatexpectations/flytekitplugins/great_expectations/schema.py +++ b/plugins/flytekit-greatexpectations/flytekitplugins/great_expectations/schema.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union import great_expectations as ge -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin from great_expectations.checkpoint import SimpleCheckpoint from great_expectations.core.run_identifier import RunIdentifier from great_expectations.core.util import convert_to_json_serializable @@ -23,9 +23,8 @@ from .task import BatchRequestConfig -@dataclass_json @dataclass -class GreatExpectationsFlyteConfig(object): +class GreatExpectationsFlyteConfig(DataClassJsonMixin): """ Use this configuration to configure GreatExpectations Plugin. 
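The recurring change in these plugin diffs (dolt, great_expectations, dbt, AWS Batch) is the swap from the `@dataclass_json` class decorator to inheriting `DataClassJsonMixin` (or mashumaro's `DataClassJSONMixin` for the core flytekit types), which keeps the same `to_dict`/`from_dict`/`to_json`/`from_json` surface while making it explicit on the class itself. A minimal sketch of the pattern, using an illustrative config class rather than any of the real ones:

```python
from dataclasses import dataclass
from typing import Optional

from dataclasses_json import DataClassJsonMixin


@dataclass
class ExampleConfig(DataClassJsonMixin):  # hypothetical stand-in for the plugin configs above
    db_path: str
    tablename: Optional[str] = None


cfg = ExampleConfig(db_path="/tmp/example.db", tablename="users")
# The mixin provides the same round-trip helpers the decorator used to inject dynamically.
assert ExampleConfig.from_json(cfg.to_json()) == cfg
assert ExampleConfig.from_dict(cfg.to_dict()) == cfg
```

It also motivates the `super().to_dict()` fix in the AWS Batch config above: once `to_dict` overrides the mixin method, calling `self.to_dict()` inside it would recurse.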
diff --git a/plugins/flytekit-greatexpectations/flytekitplugins/great_expectations/task.py b/plugins/flytekit-greatexpectations/flytekitplugins/great_expectations/task.py index 185ec20daa..8fe53e1e95 100644 --- a/plugins/flytekit-greatexpectations/flytekitplugins/great_expectations/task.py +++ b/plugins/flytekit-greatexpectations/flytekitplugins/great_expectations/task.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional, Type, Union import great_expectations as ge -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin from great_expectations.checkpoint import SimpleCheckpoint from great_expectations.core.run_identifier import RunIdentifier from great_expectations.core.util import convert_to_json_serializable @@ -19,9 +19,8 @@ from flytekit.types.schema import FlyteSchema -@dataclass_json @dataclass -class BatchRequestConfig(object): +class BatchRequestConfig(DataClassJsonMixin): """ Use this configuration to configure Batch Request. A BatchRequest can either be a simple BatchRequest or a RuntimeBatchRequest. diff --git a/plugins/flytekit-hive/flytekitplugins/hive/task.py b/plugins/flytekit-hive/flytekitplugins/hive/task.py index 2781275c86..76835a8d77 100644 --- a/plugins/flytekit-hive/flytekitplugins/hive/task.py +++ b/plugins/flytekit-hive/flytekitplugins/hive/task.py @@ -124,7 +124,7 @@ def __init__( Args: select_query: Singular query that returns a Tabular dataset stage_query: optional query that should be executed before the actual ``select_query``. This can usually - be used for setting memory or the an alternate execution engine like :ref:`tez`_/ + be used for setting memory or the an alternate execution engine like `tez `__ """ query_template = HiveSelectTask._HIVE_QUERY_FORMATTER.format( stage_query_str=stage_query or "", select_query_str=select_query.strip().strip(";") diff --git a/plugins/flytekit-identity-aware-proxy/README.md b/plugins/flytekit-identity-aware-proxy/README.md new file mode 100644 index 0000000000..c6c631707c --- /dev/null +++ b/plugins/flytekit-identity-aware-proxy/README.md @@ -0,0 +1,404 @@ +# Flytekit Identity Aware Proxy + +[GCP Identity Aware Proxy (IAP)](https://cloud.google.com/iap) is a managed Google Cloud Platform (GCP) service that makes it easy to protect applications deployed on GCP by verifying user identity and using context to determine whether a user should be granted access. Because requests to applications protected with IAP first have to pass IAP before they can reach the protected backends, IAP provides a convenient way to implement a zero-trust access model. + +This flytekit plugin allows users to generate ID tokens via an external command for use with Flyte deployments protected with IAP. A step by step guide to protect a Flyte deployment with IAP is provided as well. + +**Disclaimer: Do not choose this deployment path with the goal of *a* Flyte deployment configured with authentication on GCP. The deployment is more involved than the standard Flyte GCP deployment. Follow this guide if your organization has a security policy that requires the use of GCP Identity Aware Proxy.** + +## Configuring the token generation CLI provided by this plugin + +1. Install this plugin via `pip install flytekitplugins-identity-aware-proxy`. + + Verify the installation with `flyte-iap --help`. + +2. Create OAuth 2.0 credentials for both the token generation CLI and for IAP. + 1. 
[Desktop OAuth credentials](https://cloud.google.com/iap/docs/authentication-howto#authenticating_from_a_desktop_app) for this CLI: + + In the GCP cloud console navigate to *"Apis & Services" / "Credentials"*, click *"Create Credentials"*, select "*OAuth Client ID*", and finally choose *“Desktop App”*. + + Note the client id and client secret. + + 2. Follow the instructions to [activate IAP](https://cloud.google.com/iap/docs/enabling-kubernetes-howto#enabling_iap) in your project and cluster. In the process you will create web application type OAuth credentials for IAP (similar to what was done above for the desktop application type credentials). Again, note the client id and client secret. Don't proceed with the instructions to create the Kubernetes secret for these credentials and the backend config yet; this is done in the deployment guide below. Stop when you have the client id and secret. + + Note: In case you have an existing [Flyte deployment with auth configured](https://docs.flyte.org/en/latest/deployment/configuration/auth_setup.html#apply-oidc-configuration), you likely already have web application type OAuth credentials. You can reuse those credentials for Flyte's IAP. + +3. The token generation CLI provided by this plugin requires 1) the desktop application type client id and client secret to issue an ID token for IAP as well as 2) the client id (not the secret) of the web app type credentials that will be used by IAP (as the audience of the token). + + The desktop client secret needs to be kept secret. Therefore, create a GCP secret manager secret with the desktop client secret. + + Note the name of the secret and the id of the GCP project containing the secret. + + (You will have to grant users that will use the token generation CLI access to the secret.) + +4. Test the token generation CLI: + + ```console + flyte-iap generate-user-id-token \ + --desktop_client_id < fill in desktop client id> \ + --desktop_client_secret_gcp_secret_name \ + --webapp_client_id < fill in the web app client id> \ + --project < fill in the gcp project id where the secret was saved > + ``` + + A browser window should open, asking you to log in with your GCP account. Then, a successful login should be confirmed with *"Successfully logged into accounts.google.com"*. + + Finally, the token beginning with `eyJhbG...` should be printed to the console. + + You can decode the token with: + + ```console + jq -R 'split(".") | select(length > 0) | .[0],.[1] | @base64d | fromjson' <<< "eyJhbG..." + ``` + + The token should be issued by `"https://accounts.google.com"`, should contain your email, and should have the desktop client id set as `"azp"` and the web app client id set as `"aud"` (audience). + +5. Configure proxy authorization with this CLI in `~/.flyte/config.yaml`: + + ```yaml + admin: + endpoint: dns:///.com + insecure: false + insecureSkipVerify: true + authType: Pkce + proxyCommand: ["flyte-iap", "generate-user-id-token", "--desktop_client_id", ...] # Add this line + ``` + + This configures the Flyte clients to send a `"proxy-authorization"` header containing the token generated by the CLI with every request in order to pass the GCP Identity Aware Proxy. + + 6. For registering workflows from CICD, you might have to generate ID tokens for GCP service accounts instead of user accounts. For this purpose, you have the following options: + * `flyte-iap` provides a second subcommand called `generate-service-account-id-token`.
This subcommand uses either a service account key json file to obtain an ID token or alternatively obtains one from the metadata server when being run on GCP Compute Engine, App Engine, or Cloud Run. It caches tokens and only obtains a new one when the cached token is about to expire. + * If you want to avoid a flytekit/python dependency in your CICD systems, you can use the `gcloud` sdk: + + ``` + gcloud auth print-identity-token --token-format=full --audiences=".apps.googleusercontent.com" + ``` + * Adapt [this bash script](https://cloud.google.com/iap/docs/authentication-howto#obtaining_an_oidc_token_from_a_local_service_account_key_file) from the GCP Identity Aware Proxy documentation which retrieves a token in exchange for service account credentials. (You would need to replace the `curl` command in the last line with `echo $ID_TOKEN`.) + +## Configuring your Flyte deployment to use IAP + +### Introduction + +To protect your Flyte deployment with IAP, we have to deploy it with a GCE ingress (instead of the Nginx ingress used by the default Flyte deployment). + +Flyteadmin has a gRPC endpoint. The gRPC protocol requires the use of http2. When using http2 between a GCP load balancer (created by the GCE ingress) and a backend in GKE, the use of TLS is required ([see documentation](https://cloud.google.com/kubernetes-engine/docs/how-to/ingress-http2)): + +> To ensure the load balancer can make a correct HTTP2 request to your backend, your backend must be configured with SSL. + +The following deployment guide follows [this](https://cloud.google.com/architecture/exposing-service-mesh-apps-through-gke-ingress) reference architecture for the Istio service mesh on Google Kubernetes Engine. + +We will configure an Istio ingress gateway (pod) deployed behind a GCP load balancer to use http2 and TLS (see [here](https://cloud.google.com/architecture/exposing-service-mesh-apps-through-gke-ingress#security)): + +> you can enable HTTP/2 with TLS encryption between the cluster ingress [...] and the mesh ingress (the envoy proxy instance). When you enable HTTP/2 with TLS encryption for this path, you can use a self-signed or public certificate to encrypt traffic [...] + +Flyte is then deployed behind the Istio ingress gateway and does not need to be configured to use TLS itself. + +*Note that we do not do this for security reasons but to enable http2 traffic (required by gRPC) into the cluster through a GCE Ingress (which is required by IAP).* + +### Deployment + +1. If not already done, deploy the flyte-core helm chart, [activating auth](https://docs.flyte.org/en/latest/deployment/configuration/auth_setup.html#apply-oidc-configuration). Re-use the web app client id created for IAP (see section above). Disable the default ingress in the helm values by setting `common.ingress.enabled` to `false` in the helm values file. + + +2. 
Deployment of Istio and the Istio ingress gateway ([docs](https://istio.io/latest/docs/setup/install/helm/)) + + * `helm repo add istio https://istio-release.storage.googleapis.com/charts` + * `helm repo update` + * `kubectl create namespace istio-system` + * `helm install istio-base istio/base -n istio-system` + * `helm install istiod istio/istiod -n istio-system --wait` + * `helm install istio-ingress istio/gateway -n istio-system -f istio-values.yaml --wait` + + Here, `istio-values.yaml` contains the following: + + ```yaml + service: + annotations: + beta.cloud.google.com/backend-config: '{"default": "ingress-backend-config"}' + cloud.google.com/app-protocols: '{"https": "HTTP2"}' + type: + NodePort + ``` + + It is crucial that the service type is set to `NodePort` and not the default `LoadBalancer`. Otherwise, the Istio ingress gateway won't be deployed behind the GCP load balancer we create below but would be **publicly available on the internet!** + + With the annotations we configured the service to use http2 which is required by gRPC. We also configured the service to use a so-called backend config `ingress-backend-config` which activates IAP and which we will create in the next step. + + +3. Activate IAP for the Istio ingress gateway via a backend config: + + Create a Kubernetes secret containing the web app client id and secret we created above. The creation of the secret is described [here](https://cloud.google.com/iap/docs/enabling-kubernetes-howto#kubernetes-configure). From now on the assumption is that the secret is called `iap-oauth-client-id`. + + Create a backend config for the Istio ingress gateway: + + ```yaml + apiVersion: cloud.google.com/v1 + kind: BackendConfig + metadata: + name: ingress-backend-config + namespace: istio-system + spec: + healthCheck: + port: 15021 + requestPath: /healthz/ready + type: HTTP + iap: + enabled: true + oauthclientCredentials: + secretName: iap-oauth-client-id + ``` + + Note that apart from activating IAP, we also configured a custom health check as the istio ingress gateway doesn't use the default health check path and port assumed by the GCP load balancer. + + +4. [Install Cert Manager](https://cert-manager.io/docs/installation/helm/) to [create and rotate](https://cert-manager.io/docs/configuration/selfsigned/) a self-signed certificate for the istio ingress (pod): + + * `helm repo add jetstack https://charts.jetstack.io` + * `helm repo update` + * `helm install cert-manager jetstack/cert-manager --namespace cert-manager --create-namespace --set installCRDs=true` + + Create the following objects: + + ```yaml + apiVersion: cert-manager.io/v1 + kind: Issuer + metadata: + name: selfsigned-issuer + namespace: istio-system + spec: + selfSigned: {} + ``` + + ```yaml + apiVersion: cert-manager.io/v1 + kind: Certificate + metadata: + name: istio-ingress-cert + namespace: istio-system + spec: + commonName: istio-ingress + dnsNames: + - istio-ingress + - istio-ingress.istio-system.svc + - istio-ingress.istio-system.svc.cluster.local + issuerRef: + kind: Issuer + name: selfsigned-issuer + secretName: istio-ingress-cert + ``` + + This self-signed TLS certificate is only used between the GCP load balancer and the istio ingress gateway. It is not used by the istio ingress gateway to terminate TLS connections from the outside world (as we created it using a `NodePort` type service). Therefore, it is not unsafe to use a self-signed certificate here. 
Many applications deployed on GKE don't use any additional encryption between the load balancer and the backend. GCP, however, [encrypts these connections by default](https://cloud.google.com/load-balancing/docs/backend-service#encryption_between_the_load_balancer_and_backends): + + > The next hop, which is between the Google Front End (GFE) and the mesh ingress proxy, is encrypted by default. Network-level encryption between the GFEs and their backends is applied automatically. However, if your security requirements dictate that the platform owner retain ownership of the encryption keys, then you can enable HTTP/2 with TLS encryption between the cluster ingress (the GFE) and the mesh ingress (the envoy proxy instance). + + This additional self-managed encryption is also required to use http2 and in extension gRPC. To repeat, we mainly add this self-signed certificate in order to be able to expose a gRPC service (flyteadmin) via a GCP load balancer, less for the additional encryption. + + +5. Configure the istio ingress gateway to use the self-signed certificate: + + + ```yaml + apiVersion: networking.istio.io/v1beta1 + kind: Gateway + metadata: + name: default-gateway + namespace: istio-system + spec: + selector: + app: istio-ingress + istio: ingress + servers: + - hosts: + - '*' + port: + name: https + number: 443 + protocol: HTTPS + tls: + credentialName: istio-ingress-cert + mode: SIMPLE + ``` + + (Note that the `credentialName` matches the `secretName` in the `Certificate` we created.) + + This `Gateway` object configures the Istio ingress gateway (pod) to use the self-signed certificate we created above for every incoming TLS connection. + + +6. Deploy the GCE ingress that will route traffic to the istio ingress gateway: + + + * Create a global (not regional) static IP address in GCP as is described [here](https://cloud.google.com/kubernetes-engine/docs/how-to/managed-certs#prerequisites). + * Create a DNS record for your Flyte domain to route traffic to this static IP address. + * Create a GCP managed certificate (please fill in your domain): + + ```yaml + apiVersion: networking.gke.io/v1 + kind: ManagedCertificate + metadata: + name: flyte-managed-certificate + namespace: istio-system + spec: + domains: + - < fill in your domain > + ``` + * Create the ingress (please fill in the name of the static IP): + + ```yaml + apiVersion: networking.k8s.io/v1 + kind: Ingress + metadata: + annotations: + kubernetes.io/ingress.allow-http: "true" + kubernetes.io/ingress.global-static-ip-name: "< fill in >" + networking.gke.io/managed-certificates: flyte-managed-certificate + networking.gke.io/v1beta1.FrontendConfig: ingress-frontend-config + name: flyte-ingress + namespace: istio-system + spec: + rules: + - http: + paths: + - backend: + service: + name: istio-ingress + port: + number: 443 + path: / + pathType: Prefix + --- + apiVersion: networking.gke.io/v1beta1 + kind: FrontendConfig + metadata: + name: ingress-frontend-config + namespace: istio-system + spec: + redirectToHttps: + enabled: true + responseCodeName: MOVED_PERMANENTLY_DEFAULT + ``` + + This ingress routes all traffic to the istio ingress gateway via http2 and TLS. + + For clarity: The GCP load balancer TLS terminates connections coming from the outside world using a GCP managed certificate. + The self-signed certificate created above is only used between the GCP load balancer and the istio ingress gateway running in the cluster. 
+ To repeat, because of this it is important for security that the istio ingress gateway uses a `NodePort` type service and not a `LoadBalancer`. + + * In the GCP cloud console under *Kubernetes Engine/Services & Ingress/Ingress* (selecting the respective cluster and the `istio-system` namespace), you can observe the status of the ingress, its managed certificate, and its backends. Only proceed if all statuses are green. The creation of the GCP load balancer configured by the ingress and of the managed certificate can take up to 30 minutes during the first deployment. + + +7. Connect flyteadmin and flyteconsole to the istio ingress gateway: + + So far, we created a GCE ingress (which creates a GCP load balancer). The load balancer is configured to forward all requests to the istio ingress gateway at the edge of the service mesh via http2 and TLS. + + Next, we configure the Istio service mesh to route requests from the Istio ingress gateway to flyteadmin and flyteconsole. + + In Istio, this is configured using a so-called `VirtualService` object. + + Please fill in your flyte domain in the following manifest and apply it to the cluster: + + ```yaml + apiVersion: networking.istio.io/v1beta1 + kind: VirtualService + metadata: + name: flyte-virtualservice + namespace: flyte + spec: + gateways: + - istio-system/default-gateway + hosts: + - + http: + - match: + - uri: + prefix: /console + name: console-routes + route: + - destination: + host: flyteconsole + port: + number: 80 + - match: + - uri: + prefix: /api + - uri: + prefix: /healthcheck + - uri: + prefix: /v1/* + - uri: + prefix: /.well-known + - uri: + prefix: /login + - uri: + prefix: /logout + - uri: + prefix: /callback + - uri: + prefix: /me + - uri: + prefix: /config + - uri: + prefix: /oauth2 + name: admin-routes + route: + - destination: + host: flyteadmin + port: + number: 80 + - match: + - uri: + prefix: /flyteidl.service.SignalService + - uri: + prefix: /flyteidl.service.AdminService + - uri: + prefix: /flyteidl.service.DataProxyService + - uri: + prefix: /flyteidl.service.AuthMetadataService + - uri: + prefix: /flyteidl.service.IdentityService + - uri: + prefix: /grpc.health.v1.Health + name: admin-grpc-routes + route: + - destination: + host: flyteadmin + port: + number: 81 + ``` + + This `VirtualService` defines the routing rules for flyteadmin and flyteconsole that, in Flyte's default deployment, are configured in the Nginx ingress. + + Note that the virtual service references the `Gateway` object we created above, which configures the istio ingress gateway to use TLS for these connections. + +8. Test your Flyte deployment with IAP by, e.g., executing this Python script: + + ```python + from flytekit.remote import FlyteRemote + + from flytekit.configuration import Config + + + remote = FlyteRemote( + config=Config.auto(), + default_project="flytesnacks", + default_domain="development", + ) + + + print(remote.recent_executions()) + ``` + + A browser window should open and ask you to log in with your Google account. You should then see confirmation that you *"Successfully logged into accounts.google.com"* (this was for the IAP), finally followed by confirmation that you *"Successfully logged into 'your flyte domain'"* (this was for Flyte itself). + + + +9. At this point your Flyte deployment should be successfully protected by a GCP identity aware proxy using a zero trust model.
+ + You should check in the GCP cloud console's *IAP* page that IAP is actually activated and configured correctly for the Istio ingress gateway (follow up on any yellow or red status symbols next to the respective backend). + + You could also open the flyte console in an incognito browser window and verify that you are asked to log in with your Google account. + + Finally, you could also comment out the `proxyCommand` line in your `~/.flyte/config.yaml` and verify that you are no longer able to access your Flyte deployment behind IAP. + +10. The double login observed above is due to the fact that the Flyte clients send `"proxy-authorization"` headers generated by the CLI provided by this plugin with every request in order to make it past IAP. They still also send the regular `"authorization"` header issued by flyteadmin itself. + + Since the refresh token for Flyte and the one for IAP by default don't have the same lifespan, you likely won't notice this double login again. However, since your deployment is already protected by IAP, the ID token (issued by flyteadmin) in the `"authorization"` header mostly serves to identify users. Therefore, you can consider increasing the lifespan of the refresh token issued by flyteadmin to e.g. 7 days by setting `configmap.adminServer.auth.appAuth.selfAuthServer.refreshTokenLifespan` to e.g. `168h0m0s` in your Flyte helm values file. This way, your users should barely notice the double login. diff --git a/plugins/flytekit-identity-aware-proxy/flytekitplugins/identity_aware_proxy/__init__.py b/plugins/flytekit-identity-aware-proxy/flytekitplugins/identity_aware_proxy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-identity-aware-proxy/flytekitplugins/identity_aware_proxy/cli.py b/plugins/flytekit-identity-aware-proxy/flytekitplugins/identity_aware_proxy/cli.py new file mode 100644 index 0000000000..3c70429848 --- /dev/null +++ b/plugins/flytekit-identity-aware-proxy/flytekitplugins/identity_aware_proxy/cli.py @@ -0,0 +1,248 @@ +import logging +import os +import typing + +import click +import jwt +from google.api_core.exceptions import NotFound +from google.auth import default +from google.auth.transport.requests import Request +from google.cloud import secretmanager +from google.oauth2 import id_token + +from flytekit.clients.auth.auth_client import AuthorizationClient +from flytekit.clients.auth.authenticator import Authenticator +from flytekit.clients.auth.exceptions import AccessTokenNotFoundError +from flytekit.clients.auth.keyring import Credentials, KeyringStore + +WEBAPP_CLIENT_ID_HELP = ( + "Webapp type OAuth 2.0 client ID used by the IAP. " + "Typically in the form of `.apps.googleusercontent.com`. " + "Created when activating IAP for the Flyte deployment. " + "https://cloud.google.com/iap/docs/enabling-kubernetes-howto#oauth-credentials" +) + + +class GCPIdentityAwareProxyAuthenticator(Authenticator): + """ + This Authenticator encapsulates the entire OAuth 2.0 flow with GCP Identity Aware Proxy. + + The auth flow is described in https://cloud.google.com/iap/docs/authentication-howto#signing_in_to_the_application + + Automatically opens a browser window for login. + """ + + def __init__( + self, + audience: str, + client_id: str, + client_secret: str, + verify: typing.Optional[typing.Union[bool, str]] = None, + ): + """ + Initialize with default creds from the KeyringStore using the audience name.
+ """ + super().__init__(audience, "proxy-authorization", KeyringStore.retrieve(audience), verify=verify) + self._auth_client = None + + self.audience = audience + self.client_id = client_id + self.client_secret = client_secret + self.redirect_uri = "http://localhost:4444" + + def _initialize_auth_client(self): + if not self._auth_client: + self._auth_client = AuthorizationClient( + endpoint=self.audience, + # See step 3 in https://cloud.google.com/iap/docs/authentication-howto#signing_in_to_the_application + auth_endpoint="https://accounts.google.com/o/oauth2/v2/auth", + token_endpoint="https://oauth2.googleapis.com/token", + # See step 3 in https://cloud.google.com/iap/docs/authentication-howto#signing_in_to_the_application + scopes=["openid", "email"], + client_id=self.client_id, + redirect_uri=self.redirect_uri, + verify=self._verify, + # See step 3 in https://cloud.google.com/iap/docs/authentication-howto#signing_in_to_the_application + request_auth_code_params={ + "cred_ref": "true", + "access_type": "offline", + }, + # See step 4 in https://cloud.google.com/iap/docs/authentication-howto#signing_in_to_the_application + request_access_token_params={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "audience": self.audience, + "redirect_uri": self.redirect_uri, + }, + # See https://cloud.google.com/iap/docs/authentication-howto#refresh_token + refresh_access_token_params={ + "client_secret": self.client_secret, + "audience": self.audience, + }, + ) + + def refresh_credentials(self): + """Refresh the IAP credentials. If no credentials are found, it will kick off a full OAuth 2.0 authorization flow.""" + self._initialize_auth_client() + if self._creds: + """We have an id token so lets try to refresh it""" + try: + self._creds = self._auth_client.refresh_access_token(self._creds) + if self._creds: + KeyringStore.store(self._creds) + return + except AccessTokenNotFoundError: + logging.warning("Failed to refresh token. Kicking off a full authorization flow.") + KeyringStore.delete(self._endpoint) + + self._creds = self._auth_client.get_creds_from_remote() + KeyringStore.store(self._creds) + + +def get_gcp_secret_manager_secret(project_id: str, secret_id: str, version: typing.Optional[str] = "latest"): + """Retrieve secret from GCP secret manager.""" + client = secretmanager.SecretManagerServiceClient() + name = f"projects/{project_id}/secrets/{secret_id}/versions/{version}" + try: + response = client.access_secret_version(name=name) + except NotFound as e: + raise click.BadParameter(e.message) + payload = response.payload.data.decode("UTF-8") + return payload + + +@click.group() +def cli(): + """Generate ID tokens for GCP Identity Aware Proxy (IAP).""" + pass + + +@cli.command() +@click.option( + "--desktop_client_id", + type=str, + default=None, + required=True, + help=( + "Desktop type OAuth 2.0 client ID. Typically in the form of `.apps.googleusercontent.com`. " + "Create by following https://cloud.google.com/iap/docs/authentication-howto#setting_up_the_client_id" + ), +) +@click.option( + "--desktop_client_secret_gcp_secret_name", + type=str, + default=None, + required=True, + help=( + "Name of a GCP secret manager secret containing the desktop type OAuth 2.0 client secret " + "obtained together with desktop type OAuth 2.0 client ID." 
+ ), +) +@click.option( + "--webapp_client_id", + type=str, + default=None, + required=True, + help=WEBAPP_CLIENT_ID_HELP, +) +@click.option( + "--project", + type=str, + default=None, + required=True, + help="GCP project ID (in which `desktop_client_secret_gcp_secret_name` is saved).", +) +def generate_user_id_token( + desktop_client_id: str, desktop_client_secret_gcp_secret_name: str, webapp_client_id: str, project: str +): + """Generate a user account ID token for proxy-authorization with GCP Identity Aware Proxy.""" + desktop_client_secret = get_gcp_secret_manager_secret(project, desktop_client_secret_gcp_secret_name) + + iap_authenticator = GCPIdentityAwareProxyAuthenticator( + audience=webapp_client_id, + client_id=desktop_client_id, + client_secret=desktop_client_secret, + ) + try: + iap_authenticator.refresh_credentials() + except Exception as e: + raise click.ClickException(f"Failed to obtain credentials for GCP Identity Aware Proxy (IAP): {e}") + + click.echo(iap_authenticator.get_credentials().id_token) + + +def get_service_account_id_token(audience: str, service_account_email: str) -> str: + """Fetch an ID Token for the service account used by the current environment. + + Uses flytekit's KeyringStore to cache the ID token. + + This function acquires ID token from the environment in the following order. + See https://google.aip.dev/auth/4110. + + 1. If the environment variable ``GOOGLE_APPLICATION_CREDENTIALS`` is set + to the path of a valid service account JSON file, then ID token is + acquired using this service account credentials. + 2. If the application is running in Compute Engine, App Engine or Cloud Run, + then the ID token are obtained from the metadata server. + + Args: + audience (str): The audience that this ID token is intended for. + service_account_email (str): The email address of the service account. + """ + # Flytekit's KeyringStore, by default, uses the endpoint as the key to store the credentials + # We use the audience and the service account email as the key + audience_and_account_key = audience + "-" + service_account_email + creds = KeyringStore.retrieve(audience_and_account_key) + if creds: + is_expired = False + try: + exp_margin = -300 # Generate a new token if it expires in less than 5 minutes + jwt.decode( + creds.id_token.encode("utf-8"), + options={"verify_signature": False, "verify_exp": True}, + leeway=exp_margin, + ) + except jwt.ExpiredSignatureError: + is_expired = True + + if not is_expired: + return creds.id_token + + token = id_token.fetch_id_token(Request(), audience) + + KeyringStore.store(Credentials(for_endpoint=audience_and_account_key, access_token="", id_token=token)) + return token + + +@cli.command() +@click.option( + "--webapp_client_id", + type=str, + default=None, + required=True, + help=WEBAPP_CLIENT_ID_HELP, +) +@click.option( + "--service_account_key", + type=click.Path(exists=True, dir_okay=False), + default=None, + required=False, + help=( + "Path to a service account key file. Alternatively set the environment variable " + "`GOOGLE_APPLICATION_CREDENTIALS` to the path of the service account key file. " + "If not provided and in Compute Engine, App Engine, or Cloud Run, will retrieve " + "the ID token from the metadata server." 
+ ), +) +def generate_service_account_id_token(webapp_client_id: str, service_account_key: str): + """Generate a service account ID token for proxy-authorization with GCP Identity Aware Proxy.""" + if service_account_key: + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = service_account_key + + application_default_credentials, _ = default() + token = get_service_account_id_token(webapp_client_id, application_default_credentials.service_account_email) + click.echo(token) + + +if __name__ == "__main__": + cli() diff --git a/plugins/flytekit-identity-aware-proxy/setup.py b/plugins/flytekit-identity-aware-proxy/setup.py new file mode 100644 index 0000000000..33f8af248d --- /dev/null +++ b/plugins/flytekit-identity-aware-proxy/setup.py @@ -0,0 +1,43 @@ +from setuptools import setup + +PLUGIN_NAME = "identity_aware_proxy" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["click", "google-cloud-secret-manager", "google-auth", "flytekit"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="External command plugin to generate ID tokens for GCP Identity Aware Proxy", + url="https://github.com/flyteorg/flytekit/tree/master/plugins/flytekit-identity-aware-proxy", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + entry_points={ + "console_scripts": [ + "flyte-iap=flytekitplugins.identity_aware_proxy.cli:cli", + ], + }, + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.8", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/plugins/flytekit-identity-aware-proxy/tests/__init__.py b/plugins/flytekit-identity-aware-proxy/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-identity-aware-proxy/tests/test_flytekitplugins_iap.py b/plugins/flytekit-identity-aware-proxy/tests/test_flytekitplugins_iap.py new file mode 100644 index 0000000000..766ff646ab --- /dev/null +++ b/plugins/flytekit-identity-aware-proxy/tests/test_flytekitplugins_iap.py @@ -0,0 +1,146 @@ +import uuid +from datetime import datetime, timedelta +from unittest.mock import MagicMock, patch + +import click +import jwt +import pytest +from click.testing import CliRunner +from flytekitplugins.identity_aware_proxy.cli import cli, get_gcp_secret_manager_secret, get_service_account_id_token +from google.api_core.exceptions import NotFound + + +def test_help() -> None: + """Smoke test external command IAP ID token generator cli by printing help message.""" + runner = CliRunner() + result = runner.invoke(cli, "--help") + assert "Generate ID tokens" in result.output + assert result.exit_code == 0 + + result = runner.invoke(cli, ["generate-user-id-token", "--help"]) + assert "Generate a user account ID token" in result.output + assert result.exit_code == 0 + + result = runner.invoke(cli, ["generate-service-account-id-token", 
"--help"]) + assert "Generate a service account ID token" in result.output + assert result.exit_code == 0 + + +def test_get_gcp_secret_manager_secret(): + """Test retrieval of GCP secret manager secret.""" + project_id = "test_project" + secret_id = "test_secret" + version = "latest" + expected_payload = "test_payload" + + mock_client = MagicMock() + mock_client.access_secret_version.return_value.payload.data.decode.return_value = expected_payload + with patch("google.cloud.secretmanager.SecretManagerServiceClient", return_value=mock_client): + payload = get_gcp_secret_manager_secret(project_id, secret_id, version) + assert payload == expected_payload + + name = f"projects/{project_id}/secrets/{secret_id}/versions/{version}" + mock_client.access_secret_version.assert_called_once_with(name=name) + + +def test_get_gcp_secret_manager_secret_not_found(): + """Test retrieving non-existing secret from GCP secret manager.""" + project_id = "test_project" + secret_id = "test_secret" + version = "latest" + + mock_client = MagicMock() + mock_client.access_secret_version.side_effect = NotFound("Secret not found") + with patch("google.cloud.secretmanager.SecretManagerServiceClient", return_value=mock_client): + with pytest.raises(click.BadParameter): + get_gcp_secret_manager_secret(project_id, secret_id, version) + + +def create_mock_token(aud: str, expires_in: timedelta = None): + """Create a mock JWT token with a certain audience, expiration time, and random JTI.""" + exp = datetime.utcnow() + expires_in + jti = "test_token" + str(uuid.uuid4()) + payload = {"exp": exp, "aud": aud, "jti": jti} + + secret = "your-secret-key" + algorithm = "HS256" + + return jwt.encode(payload, secret, algorithm=algorithm) + + +@patch("flytekitplugins.identity_aware_proxy.cli.id_token.fetch_id_token") +@patch("keyring.get_password") +@patch("keyring.set_password") +def test_sa_id_token_no_token_in_keyring(kr_set_password, kr_get_password, mock_fetch_id_token): + """Test retrieval and caching of service account ID token when no token is stored in keyring yet.""" + test_audience = "test_audience" + service_account_email = "default" + + # Start with a clean KeyringStore + tmp_test_keyring_store = {} + kr_get_password.side_effect = lambda service, user: tmp_test_keyring_store.get(service, {}).get(user, None) + kr_set_password.side_effect = lambda service, user, pwd: tmp_test_keyring_store.update({service: {user: pwd}}) + + mock_fetch_id_token.side_effect = lambda _, aud: create_mock_token(aud, expires_in=timedelta(hours=1)) + + token = get_service_account_id_token(test_audience, service_account_email) + + assert jwt.decode(token.encode("utf-8"), options={"verify_signature": False})["aud"] == test_audience + assert jwt.decode(token.encode("utf-8"), options={"verify_signature": False})["jti"].startswith("test_token") + + # Check that the token is cached in the KeyringStore + second_token = get_service_account_id_token(test_audience, service_account_email) + + assert token == second_token + + +@patch("flytekitplugins.identity_aware_proxy.cli.id_token.fetch_id_token") +@patch("keyring.get_password") +@patch("keyring.set_password") +def test_sa_id_token_expired_token_in_keyring(kr_set_password, kr_get_password, mock_fetch_id_token): + """Test that expired service account ID token in keyring is replaced with a new one.""" + test_audience = "test_audience" + service_account_email = "default" + + # Start with an expired token in the KeyringStore + expired_id_token = create_mock_token(test_audience, expires_in=timedelta(hours=-1)) 
+ tmp_test_keyring_store = {test_audience + "-" + service_account_email: {"id_token": expired_id_token}} + kr_get_password.side_effect = lambda service, user: tmp_test_keyring_store.get(service, {}).get(user, None) + kr_set_password.side_effect = lambda service, user, pwd: tmp_test_keyring_store.update({service: {user: pwd}}) + + mock_fetch_id_token.side_effect = lambda _, aud: create_mock_token(aud, expires_in=timedelta(hours=1)) + + token = get_service_account_id_token(test_audience, service_account_email) + + assert token != expired_id_token + assert jwt.decode(token.encode("utf-8"), options={"verify_signature": False})["aud"] == test_audience + assert jwt.decode(token.encode("utf-8"), options={"verify_signature": False})["jti"].startswith("test_token") + + +@patch("flytekitplugins.identity_aware_proxy.cli.id_token.fetch_id_token") +@patch("keyring.get_password") +@patch("keyring.set_password") +def test_sa_id_token_switch_accounts(kr_set_password, kr_get_password, mock_fetch_id_token): + """Test that caching works when switching service accounts.""" + test_audience = "test_audience" + service_account_email = "default" + service_account_other_email = "other" + + # Start with a clean KeyringStore + tmp_test_keyring_store = {} + kr_get_password.side_effect = lambda service, user: tmp_test_keyring_store.get(service, {}).get(user, None) + kr_set_password.side_effect = lambda service, user, pwd: tmp_test_keyring_store.update({service: {user: pwd}}) + + mock_fetch_id_token.side_effect = lambda _, aud: create_mock_token(aud, expires_in=timedelta(hours=1)) + + default_token = get_service_account_id_token(test_audience, service_account_email) + other_token = get_service_account_id_token(test_audience, service_account_other_email) + + assert default_token != other_token + + # Check that the tokens are cached in the KeyringStore + new_default_token = get_service_account_id_token(test_audience, service_account_email) + new_other_token = get_service_account_id_token(test_audience, service_account_other_email) + + assert default_token == new_default_token + assert other_token == new_other_token diff --git a/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py b/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py index 9e8e5ef937..363998af45 100644 --- a/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py +++ b/plugins/flytekit-k8s-pod/flytekitplugins/pod/task.py @@ -26,10 +26,11 @@ class Pod(object): Pod is a platform-wide configuration that uses pod templates. By default, every task is launched as a container in a pod. This plugin helps expose a fully modifiable Kubernetes pod spec to customize the task execution runtime. To use pod tasks: (1) Define a pod spec, and (2) Specify the primary container name. + :param V1PodSpec pod_spec: Kubernetes pod spec. https://kubernetes.io/docs/concepts/workloads/pods :param str primary_container_name: the primary container name. If provided the pod-spec can contain a container whose name matches the primary_container_name. This will force Flyte to give up control of the primary - container and will expect users to control setting up the container. If you expect your python function to run as is, simply create containers that do not match the default primary-container-name and Flyte will auto-inject a - container for the python function based on the default image provided during serialization. + container and will expect users to control setting up the container. 
If you expect your python function to run as is, simply create containers that do not match the default primary-container-name and Flyte will auto-inject a + container for the python function based on the default image provided during serialization. :param Optional[Dict[str, str]] labels: Labels are key/value pairs that are attached to pod spec :param Optional[Dict[str, str]] annotations: Annotations are key/value pairs that are attached to arbitrary non-identifying metadata to pod spec. """ diff --git a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py index 55867a22ec..386f493854 100644 --- a/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py +++ b/plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py @@ -5,7 +5,7 @@ import os from dataclasses import dataclass, field from enum import Enum -from typing import Any, Callable, Dict, Optional, Union +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union import cloudpickle from flyteidl.plugins.kubeflow import common_pb2 as kubeflow_common @@ -203,7 +203,22 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: TaskPlugins.register_pythontask_plugin(PyTorch, PyTorchFunctionTask) -def spawn_helper(fn: bytes, raw_output_prefix: str, checkpoint_dest: str, checkpoint_src: str, kwargs) -> Any: +class ElasticWorkerResult(NamedTuple): + """ + A named tuple representing the result of a torch elastic worker process. + + Attributes: + return_value (Any): The value returned by the task function in the worker process. + decks (list[flytekit.Deck]): A list of flytekit Deck objects created in the worker process. + """ + + return_value: Any + decks: List[flytekit.Deck] + + +def spawn_helper( + fn: bytes, raw_output_prefix: str, checkpoint_dest: str, checkpoint_src: str, kwargs +) -> ElasticWorkerResult: """Help to spawn worker processes. The purpose of this function is to 1) be pickleable so that it can be used with @@ -220,7 +235,8 @@ def spawn_helper(fn: bytes, raw_output_prefix: str, checkpoint_dest: str, checkp checkpoint_src (str): Location where the new checkpoint should be copied to. Returns: - The return value of the received target function. + ElasticWorkerResult: A named tuple containing the return value of the task function and a list of + flytekit Deck objects created in the worker process. 
""" from flytekit.bin.entrypoint import setup_execution @@ -231,7 +247,8 @@ def spawn_helper(fn: bytes, raw_output_prefix: str, checkpoint_dest: str, checkp ): fn = cloudpickle.loads(fn) return_val = fn(**kwargs) - return return_val + + return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks) class PytorchElasticFunctionTask(PythonFunctionTask[Elastic]): @@ -336,7 +353,8 @@ def _execute(self, **kwargs) -> Any: def fn_partial(): """Closure of the task function with kwargs already bound.""" - return self._task_function(**kwargs) + return_val = self._task_function(**kwargs) + return ElasticWorkerResult(return_value=return_val, decks=flytekit.current_context().decks) launcher_target_func = fn_partial launcher_args = () @@ -365,7 +383,13 @@ def fn_partial(): # `out` is a dictionary of rank (not local rank) -> result # Rank 0 returns the result of the task function if 0 in out: - return out[0] + # For rank 0, we transfer the decks created in the worker process to the parent process + ctx = flytekit.current_context() + for deck in out[0].decks: + if not isinstance(deck, flytekit.deck.deck.TimeLineDeck): + ctx.decks.append(deck) + + return out[0].return_value else: raise IgnoreOutputs() diff --git a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py index 0d134a5e18..bced35a6df 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py @@ -6,16 +6,15 @@ import pytest import torch import torch.distributed as dist -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin from flytekitplugins.kfpytorch.task import Elastic import flytekit from flytekit import task, workflow -@dataclass_json @dataclass -class Config: +class Config(DataClassJsonMixin): lr: float = 1e-5 bs: int = 64 name: str = "foo" @@ -112,3 +111,41 @@ def test_task(): with mock.patch("torch.distributed.launcher.api.LaunchConfig", side_effect=LaunchConfig) as mock_launch_config: test_task() assert mock_launch_config.call_args[1]["rdzv_configs"] == rdzv_configs + + +@pytest.mark.parametrize("start_method", ["spawn", "fork"]) +def test_deck(start_method: str) -> None: + """Test that decks created in the main worker process are transferred to the parent process.""" + world_size = 2 + + @task( + task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method), + disable_deck=False, + ) + def train(): + import os + + ctx = flytekit.current_context() + deck = flytekit.Deck("test-deck", f"Hello Flyte Deck viewer from worker process {os.environ.get('RANK')}") + ctx.decks.append(deck) + default_deck = ctx.default_deck + default_deck.append("Hello from default deck") + + @workflow + def wf(): + train() + + wf() + + ctx = flytekit.current_context() + + expected_deck_names = {"timeline", "default", "test-deck"} + found_deck_names = set(d.name for d in ctx.decks) + + assert expected_deck_names.issubset(found_deck_names) + + default_deck = [d for d in ctx.decks if d.name == "default"][0] + assert "Hello from default deck" == default_deck.html.strip() + + test_deck = [d for d in ctx.decks if d.name == "test-deck"][0] + assert "Hello Flyte Deck viewer from worker process 0" in test_deck.html diff --git a/plugins/flytekit-mlflow/setup.py b/plugins/flytekit-mlflow/setup.py index 06addeb060..32bf295aec 100644 --- a/plugins/flytekit-mlflow/setup.py +++ b/plugins/flytekit-mlflow/setup.py @@ -18,7 +18,7 @@ 
packages=[f"flytekitplugins.{PLUGIN_NAME}"], install_requires=plugin_requires, license="apache2", - python_requires=">=3.8,<3.11", + python_requires=">=3.8", classifiers=[ "Intended Audience :: Science/Research", "Intended Audience :: Developers", diff --git a/plugins/flytekit-mmcloud/README.md b/plugins/flytekit-mmcloud/README.md new file mode 100644 index 0000000000..664a648f02 --- /dev/null +++ b/plugins/flytekit-mmcloud/README.md @@ -0,0 +1,104 @@ +# Flytekit Memory Machine Cloud Plugin + +Flyte Agent plugin to allow executing Flyte tasks using MemVerge Memory Machine Cloud. + +To install the plugin, run the following command: + +```bash +pip install flytekitplugins-mmcloud +``` + +To get started with MMCloud, refer to the [MMCloud User Guide](https://docs.memverge.com/mmce/current/userguide/olh/index.html). + +## Getting Started + +This plugin allows executing `PythonFunctionTask` using MMCloud without changing any function code. + +[Resource](https://docs.flyte.org/projects/cookbook/en/latest/auto_examples/productionizing/customizing_resources.html) (cpu and mem) requests and limits, [container](https://docs.flyte.org/projects/cookbook/en/latest/auto_examples/customizing_dependencies/multi_images.html) images, and [environment](https://docs.flyte.org/projects/flytekit/en/latest/generated/flytekit.task.html) variable specifications are supported. + +[ImageSpec](https://docs.flyte.org/projects/cookbook/en/latest/auto_examples/customizing_dependencies/image_spec.html) may be used to define images to run tasks. + +### Credentials + +The following [secrets](https://docs.flyte.org/projects/cookbook/en/latest/auto_examples/productionizing/use_secrets.html) are required to be defined for the agent server: +* `mmc_address`: MMCloud OpCenter address +* `mmc_username`: MMCloud OpCenter username +* `mmc_password`: MMCloud OpCenter password + +### Defaults + +Compute resources: +* If only requests are specified, there are no limits. +* If only limits are specified, the requests are equal to the limits. +* If neither resource requests nor limits are specified, the default requests used for job submission are `cpu="1"` and `mem="1Gi"`, and there are no limits. 
+ +### Example + +`example.py` workflow example: +```python +import pandas as pd +from flytekit import ImageSpec, Resources, task, workflow +from sklearn.datasets import load_wine +from sklearn.linear_model import LogisticRegression + +from flytekitplugins.mmcloud import MMCloudConfig + +image_spec = ImageSpec(packages=["scikit-learn"], registry="docker.io/memverge") + + +@task +def get_data() -> pd.DataFrame: + """Get the wine dataset.""" + return load_wine(as_frame=True).frame + + +@task(task_config=MMCloudConfig(), container_image=image_spec) # Task will be submitted as MMCloud job +def process_data(data: pd.DataFrame) -> pd.DataFrame: + """Simplify the task from a 3-class to a binary classification problem.""" + return data.assign(target=lambda x: x["target"].where(x["target"] == 0, 1)) + + +@task( + task_config=MMCloudConfig(submit_extra="--migratePolicy [enable=true]"), + requests=Resources(cpu="1", mem="1Gi"), + limits=Resources(cpu="2", mem="4Gi"), + container_image=image_spec, + environment={"KEY": "value"}, +) +def train_model(data: pd.DataFrame, hyperparameters: dict) -> LogisticRegression: + """Train a model on the wine dataset.""" + features = data.drop("target", axis="columns") + target = data["target"] + return LogisticRegression(max_iter=3000, **hyperparameters).fit(features, target) + + +@workflow +def training_workflow(hyperparameters: dict) -> LogisticRegression: + """Put all of the steps together into a single workflow.""" + data = get_data() + processed_data = process_data(data=data) + return train_model( + data=processed_data, + hyperparameters=hyperparameters, + ) +``` + +### Agent Image + +Install `flytekitplugins-mmcloud` in the agent image. + +A `float` binary (obtainable via the OpCenter) is required. Copy it to the agent image `PATH`. + +Sample `Dockerfile` for building an agent image: +```dockerfile +FROM python:3.11-slim-bookworm + +WORKDIR /root +ENV PYTHONPATH /root + +# flytekit will autoload the agent if package is installed. +RUN pip install flytekitplugins-mmcloud +COPY float /usr/local/bin/float + +CMD pyflyte serve --port 8000 +``` diff --git a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/__init__.py b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/__init__.py new file mode 100644 index 0000000000..e3de1897fc --- /dev/null +++ b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/__init__.py @@ -0,0 +1,16 @@ +""" +.. currentmodule:: flytekitplugins.mmcloud + +This package contains things that are useful when extending Flytekit. + +.. 
autosummary:: + :template: custom.rst + :toctree: generated/ + + MMCloudConfig + MMCloudTask + MMCloudAgent +""" + +from .agent import MMCloudAgent +from .task import MMCloudConfig, MMCloudTask diff --git a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py new file mode 100644 index 0000000000..b44906e144 --- /dev/null +++ b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/agent.py @@ -0,0 +1,212 @@ +import json +import shlex +import subprocess +from dataclasses import asdict, dataclass +from tempfile import NamedTemporaryFile +from typing import Optional + +import grpc +from flyteidl.admin.agent_pb2 import CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource +from flytekitplugins.mmcloud.utils import async_check_output, mmcloud_status_to_flyte_state + +from flytekit import current_context +from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry +from flytekit.loggers import logger +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + + +@dataclass +class Metadata: + job_id: str + + +class MMCloudAgent(AgentBase): + def __init__(self): + super().__init__(task_type="mmcloud_task") + self._response_format = ["--format", "json"] + + async def async_login(self): + """ + Log in to Memory Machine Cloud OpCenter. + """ + try: + # If already logged in, this will reset the session timer + login_info_command = ["float", "login", "--info"] + await async_check_output(*login_info_command) + except subprocess.CalledProcessError: + logger.info("Attempting to log in to OpCenter") + try: + secrets = current_context().secrets + login_command = [ + "float", + "login", + "--address", + secrets.get("mmc_address"), + "--username", + secrets.get("mmc_username"), + "--password", + secrets.get("mmc_password"), + ] + await async_check_output(*login_command) + except subprocess.CalledProcessError: + logger.exception("Failed to log in to OpCenter") + raise + + logger.info("Logged in to OpCenter") + + async def async_create( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + ) -> CreateTaskResponse: + """ + Submit Flyte task as MMCloud job to the OpCenter, and return the job UID for the task. 
+ """ + submit_command = [ + "float", + "submit", + "--force", + *self._response_format, + ] + + # We do not use container.resources because FlytePropeller will impose limits that should not apply to MMCloud + min_cpu, min_mem, max_cpu, max_mem = task_template.custom["resources"] + submit_command.extend(["--cpu", f"{min_cpu}:{max_cpu}"] if max_cpu else ["--cpu", f"{min_cpu}"]) + submit_command.extend(["--mem", f"{min_mem}:{max_mem}"] if max_mem else ["--mem", f"{min_mem}"]) + + container = task_template.container + + image = container.image + submit_command.extend(["--image", image]) + + env = container.env + for key, value in env.items(): + submit_command.extend(["--env", f"{key}={value}"]) + + submit_extra = task_template.custom["submit_extra"] + submit_command.extend(shlex.split(submit_extra)) + + args = task_template.container.args + script_lines = ["#!/bin/bash\n", f"{shlex.join(args)}\n"] + + task_id = task_template.id + try: + # float binary takes a job file as input, so one must be created + # Use a uniquely named temporary file to avoid race conditions and clutter + with NamedTemporaryFile(mode="w") as job_file: + job_file.writelines(script_lines) + # Flush immediately so that the job file is usable + job_file.flush() + logger.debug("Wrote job script") + + submit_command.extend(["--job", job_file.name]) + + logger.info(f"Attempting to submit Flyte task {task_id} as MMCloud job") + logger.debug(f"With command: {submit_command}") + try: + await self.async_login() + submit_response = await async_check_output(*submit_command) + submit_response = json.loads(submit_response.decode()) + job_id = submit_response["id"] + except subprocess.CalledProcessError as e: + logger.exception( + f"Failed to submit Flyte task {task_id} as MMCloud job\n" + f"[stdout] {e.stdout.decode()}\n" + f"[stderr] {e.stderr.decode()}\n" + ) + raise + except (UnicodeError, json.JSONDecodeError): + logger.exception(f"Failed to decode submit response for Flyte task: {task_id}") + raise + except KeyError: + logger.exception(f"Failed to obtain MMCloud job id for Flyte task: {task_id}") + raise + + logger.info(f"Submitted Flyte task {task_id} as MMCloud job {job_id}") + logger.debug(f"OpCenter response: {submit_response}") + except OSError: + logger.exception("Cannot open job script for writing") + raise + + metadata = Metadata(job_id=job_id) + + return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8")) + + async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: + """ + Return the status of the task, and return the outputs on success. 
+ """ + metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) + job_id = metadata.job_id + + show_command = [ + "float", + "show", + *self._response_format, + "--job", + job_id, + ] + + logger.info(f"Attempting to obtain status for MMCloud job {job_id}") + logger.debug(f"With command: {show_command}") + try: + await self.async_login() + show_response = await async_check_output(*show_command) + show_response = json.loads(show_response.decode()) + job_status = show_response["status"] + except subprocess.CalledProcessError as e: + logger.exception( + f"Failed to get show response for MMCloud job: {job_id}\n" + f"[stdout] {e.stdout.decode()}\n" + f"[stderr] {e.stderr.decode()}\n" + ) + raise + except (UnicodeError, json.JSONDecodeError): + logger.exception(f"Failed to decode show response for MMCloud job: {job_id}") + raise + except KeyError: + logger.exception(f"Failed to obtain status for MMCloud job: {job_id}") + raise + + task_state = mmcloud_status_to_flyte_state(job_status) + + logger.info(f"Obtained status for MMCloud job {job_id}: {job_status}") + logger.debug(f"OpCenter response: {show_response}") + + return GetTaskResponse(resource=Resource(state=task_state)) + + async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + """ + Delete the task. This call should be idempotent. + """ + metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) + job_id = metadata.job_id + + cancel_command = [ + "float", + "cancel", + "--force", + "--job", + job_id, + ] + + logger.info(f"Attempting to cancel MMCloud job {job_id}") + logger.debug(f"With command: {cancel_command}") + try: + await self.async_login() + await async_check_output(*cancel_command) + except subprocess.CalledProcessError as e: + logger.exception( + f"Failed to cancel MMCloud job: {job_id}\n[stdout] {e.stdout.decode()}\n[stderr] {e.stderr.decode()}\n" + ) + raise + + logger.info(f"Submitted cancel request for MMCloud job: {job_id}") + + return DeleteTaskResponse() + + +AgentRegistry.register(MMCloudAgent()) diff --git a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/task.py b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/task.py new file mode 100644 index 0000000000..3a61d590d7 --- /dev/null +++ b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/task.py @@ -0,0 +1,64 @@ +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + +from flytekitplugins.mmcloud.utils import flyte_to_mmcloud_resources +from google.protobuf import json_format +from google.protobuf.struct_pb2 import Struct + +from flytekit.configuration import SerializationSettings +from flytekit.core.python_function_task import PythonFunctionTask +from flytekit.core.resources import Resources +from flytekit.extend import TaskPlugins +from flytekit.image_spec.image_spec import ImageSpec + + +@dataclass +class MMCloudConfig(object): + """ + Configures MMCloudTask. Tasks specified with MMCloudConfig will be executed using Memory Machine Cloud. 
+ """ + + # This allows the user to specify additional arguments for the float submit command + submit_extra: str = "" + + +class MMCloudTask(PythonFunctionTask): + _TASK_TYPE = "mmcloud_task" + + def __init__( + self, + task_config: Optional[MMCloudConfig], + task_function: Callable, + container_image: Optional[Union[str, ImageSpec]] = None, + requests: Optional[Resources] = None, + limits: Optional[Resources] = None, + **kwargs, + ): + super().__init__( + task_config=task_config or MMCloudConfig(), + task_type=self._TASK_TYPE, + task_function=task_function, + container_image=container_image, + **kwargs, + ) + + self._mmcloud_resources = flyte_to_mmcloud_resources(requests=requests, limits=limits) + + def execute(self, **kwargs) -> Any: + return PythonFunctionTask.execute(self, **kwargs) + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + """ + Return plugin-specific data as a serializable dictionary. + """ + config = { + "submit_extra": self.task_config.submit_extra, + "resources": [str(resource) if resource else None for resource in self._mmcloud_resources], + } + s = Struct() + s.update(config) + return json_format.MessageToDict(s) + + +TaskPlugins.register_pythontask_plugin(MMCloudConfig, MMCloudTask) diff --git a/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/utils.py b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/utils.py new file mode 100644 index 0000000000..03696d6c45 --- /dev/null +++ b/plugins/flytekit-mmcloud/flytekitplugins/mmcloud/utils.py @@ -0,0 +1,89 @@ +import asyncio +import shlex +import subprocess +from asyncio.subprocess import PIPE +from decimal import ROUND_CEILING, Decimal +from typing import Optional, Tuple + +from flyteidl.admin.agent_pb2 import PERMANENT_FAILURE, RETRYABLE_FAILURE, RUNNING, SUCCEEDED, State +from kubernetes.utils.quantity import parse_quantity + +from flytekit.core.resources import Resources + +MMCLOUD_STATUS_TO_FLYTE_STATE = { + "Submitted": RUNNING, + "Initializing": RUNNING, + "Starting": RUNNING, + "Executing": RUNNING, + "Capturing": RUNNING, + "Floating": RUNNING, + "Suspended": RUNNING, + "Suspending": RUNNING, + "Resuming": RUNNING, + "Completed": SUCCEEDED, + "Cancelled": PERMANENT_FAILURE, + "Cancelling": PERMANENT_FAILURE, + "FailToComplete": RETRYABLE_FAILURE, + "FailToExecute": RETRYABLE_FAILURE, + "CheckpointFailed": RETRYABLE_FAILURE, + "Timedout": RETRYABLE_FAILURE, + "NoAvailableHost": RETRYABLE_FAILURE, + "Unknown": RETRYABLE_FAILURE, + "WaitingForLicense": PERMANENT_FAILURE, +} + + +def mmcloud_status_to_flyte_state(status: str) -> State: + """ + Map MMCloud status to Flyte state. + """ + return MMCLOUD_STATUS_TO_FLYTE_STATE[status] + + +def flyte_to_mmcloud_resources( + requests: Optional[Resources] = None, + limits: Optional[Resources] = None, +) -> Tuple[int, int, int, int]: + """ + Map Flyte (K8s) resources to MMCloud resources. 
+ """ + B_IN_GIB = 1073741824 + + # MMCloud does not support cpu under 1 or mem under 1Gi + req_cpu = max(Decimal(1), parse_quantity(requests.cpu)) if requests and requests.cpu else None + req_mem = max(Decimal(B_IN_GIB), parse_quantity(requests.mem)) if requests and requests.mem else None + lim_cpu = max(Decimal(1), parse_quantity(limits.cpu)) if limits and limits.cpu else None + lim_mem = max(Decimal(B_IN_GIB), parse_quantity(limits.mem)) if limits and limits.mem else None + + # Convert Decimal to int + # Round up so that resource demands are met + max_cpu = int(lim_cpu.to_integral_value(rounding=ROUND_CEILING)) if lim_cpu else None + max_mem = int(lim_mem.to_integral_value(rounding=ROUND_CEILING)) if lim_mem else None + + # Use the maximum as the minimum if no minimum is specified + # Use min_cpu 1 and min_mem 1Gi if neither minimum nor maximum are specified + min_cpu = int(req_cpu.to_integral_value(rounding=ROUND_CEILING)) if req_cpu else max_cpu or 1 + min_mem = int(req_mem.to_integral_value(rounding=ROUND_CEILING)) if req_mem else max_mem or B_IN_GIB + + if min_cpu and max_cpu and min_cpu > max_cpu: + raise ValueError("cpu request cannot be greater than cpu limit") + if min_mem and max_mem and min_mem > max_mem: + raise ValueError("mem request cannot be greater than mem limit") + + # Convert B to GiB + min_mem = (min_mem + B_IN_GIB - 1) // B_IN_GIB if min_mem else None + max_mem = (max_mem + B_IN_GIB - 1) // B_IN_GIB if max_mem else None + + return min_cpu, min_mem, max_cpu, max_mem + + +async def async_check_output(*args, **kwargs): + """ + This behaves similarly to subprocess.check_output(). + """ + process = await asyncio.create_subprocess_exec(*args, stdout=PIPE, stderr=PIPE, **kwargs) + stdout, stderr = await process.communicate() + returncode = process.returncode + if returncode != 0: + raise subprocess.CalledProcessError(returncode, shlex.join(args), output=stdout, stderr=stderr) + return stdout diff --git a/plugins/flytekit-mmcloud/setup.py b/plugins/flytekit-mmcloud/setup.py new file mode 100644 index 0000000000..6aa45a93ee --- /dev/null +++ b/plugins/flytekit-mmcloud/setup.py @@ -0,0 +1,39 @@ +from setuptools import setup + +PLUGIN_NAME = "mmcloud" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.9.1,<2.0.0"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="Helen Zhang", + author_email="helen.zhang@memverge.com", + description="MemVerge Flyte plugin", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.8", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, +) diff --git a/plugins/flytekit-mmcloud/tests/__init__.py b/plugins/flytekit-mmcloud/tests/__init__.py new file mode 
100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-mmcloud/tests/test_mmcloud.py b/plugins/flytekit-mmcloud/tests/test_mmcloud.py new file mode 100644 index 0000000000..e5a1eb56d5 --- /dev/null +++ b/plugins/flytekit-mmcloud/tests/test_mmcloud.py @@ -0,0 +1,233 @@ +import asyncio +import subprocess +from collections import OrderedDict +from shutil import which +from unittest.mock import MagicMock + +import grpc +import pytest +from flyteidl.admin.agent_pb2 import PERMANENT_FAILURE, RUNNING, SUCCEEDED +from flytekitplugins.mmcloud import MMCloudAgent, MMCloudConfig, MMCloudTask +from flytekitplugins.mmcloud.utils import async_check_output, flyte_to_mmcloud_resources + +from flytekit import Resources, task +from flytekit.configuration import DefaultImages, ImageConfig, SerializationSettings +from flytekit.extend import get_serializable +from flytekit.extend.backend.base_agent import AgentRegistry + +float_missing = which("float") is None + + +def test_mmcloud_task(): + task_config = MMCloudConfig(submit_extra="--migratePolicy [enable=true]") + requests = Resources(cpu="2", mem="4Gi") + limits = Resources(cpu="4") + container_image = DefaultImages.default_image() + environment = {"KEY": "value"} + + @task( + task_config=task_config, + requests=requests, + limits=limits, + container_image=container_image, + environment=environment, + ) + def say_hello(name: str) -> str: + return f"Hello, {name}." + + assert say_hello.task_config == task_config + assert say_hello.task_type == "mmcloud_task" + assert isinstance(say_hello, MMCloudTask) + + serialization_settings = SerializationSettings(image_config=ImageConfig()) + task_spec = get_serializable(OrderedDict(), serialization_settings, say_hello) + template = task_spec.template + container = template.container + + assert template.custom == {"submit_extra": "--migratePolicy [enable=true]", "resources": ["2", "4", "4", None]} + assert container.image == container_image + assert container.env == environment + + +def test_async_check_output(): + message = "Hello, World!" 
+ stdout = asyncio.run(async_check_output("echo", message)) + assert stdout.decode() == f"{message}\n" + + with pytest.raises(FileNotFoundError): + asyncio.run(async_check_output("nonexistent_command")) + + with pytest.raises(subprocess.CalledProcessError): + asyncio.run(async_check_output("false")) + + +def test_flyte_to_mmcloud_resources(): + B_IN_GIB = 1073741824 + success_cases = { + ("0", "0", "0", "0"): (1, 1, 1, 1), + ("0", "0", None, None): (1, 1, None, None), + (None, None, "0", "0"): (1, 1, 1, 1), + ("1", "2Gi", "3", "4Gi"): (1, 2, 3, 4), + ("1", "2Gi", None, None): (1, 2, None, None), + (None, None, "3", "4Gi"): (3, 4, 3, 4), + (None, None, None, None): (1, 1, None, None), + ("1.1", str(B_IN_GIB + 1), "2.1", str(2 * B_IN_GIB + 1)): (2, 2, 3, 3), + } + + for (req_cpu, req_mem, lim_cpu, lim_mem), (min_cpu, min_mem, max_cpu, max_mem) in success_cases.items(): + resources = flyte_to_mmcloud_resources( + requests=Resources(cpu=req_cpu, mem=req_mem), + limits=Resources(cpu=lim_cpu, mem=lim_mem), + ) + assert resources == (min_cpu, min_mem, max_cpu, max_mem) + + error_cases = { + ("1", "2Gi", "2", "1Gi"), + ("2", "2Gi", "1", "1Gi"), + ("2", "1Gi", "1", "2Gi"), + } + for (req_cpu, req_mem, lim_cpu, lim_mem) in error_cases: + with pytest.raises(ValueError): + flyte_to_mmcloud_resources( + requests=Resources(cpu=req_cpu, mem=req_mem), + limits=Resources(cpu=lim_cpu, mem=lim_mem), + ) + + +@pytest.mark.skipif(float_missing, reason="float binary is required") +def test_async_agent(): + serialization_settings = SerializationSettings(image_config=ImageConfig()) + context = MagicMock(spec=grpc.ServicerContext) + container_image = DefaultImages.default_image() + + @task( + task_config=MMCloudConfig(submit_extra="--migratePolicy [enable=true]"), + requests=Resources(cpu="2", mem="1Gi"), + limits=Resources(cpu="4", mem="16Gi"), + container_image=container_image, + ) + def say_hello0(name: str) -> str: + return f"Hello, {name}." + + task_spec = get_serializable(OrderedDict(), serialization_settings, say_hello0) + agent = AgentRegistry.get_agent(task_spec.template.type) + + assert isinstance(agent, MMCloudAgent) + + create_task_response = asyncio.run( + agent.async_create( + context=context, + output_prefix="", + task_template=task_spec.template, + inputs=None, + ) + ) + resource_meta = create_task_response.resource_meta + + get_task_response = asyncio.run(agent.async_get(context=context, resource_meta=resource_meta)) + state = get_task_response.resource.state + assert state in (RUNNING, SUCCEEDED) + + asyncio.run(agent.async_delete(context=context, resource_meta=resource_meta)) + + get_task_response = asyncio.run(agent.async_get(context=context, resource_meta=resource_meta)) + state = get_task_response.resource.state + assert state == PERMANENT_FAILURE + + @task( + task_config=MMCloudConfig(submit_extra="--nonexistent"), + requests=Resources(cpu="2", mem="4Gi"), + limits=Resources(cpu="4", mem="16Gi"), + container_image=container_image, + ) + def say_hello1(name: str) -> str: + return f"Hello, {name}." + + task_spec = get_serializable(OrderedDict(), serialization_settings, say_hello1) + with pytest.raises(subprocess.CalledProcessError): + create_task_response = asyncio.run( + agent.async_create( + context=context, + output_prefix="", + task_template=task_spec.template, + inputs=None, + ) + ) + + @task( + task_config=MMCloudConfig(), + limits=Resources(cpu="3", mem="1Gi"), + container_image=container_image, + ) + def say_hello2(name: str) -> str: + return f"Hello, {name}." 
+ + task_spec = get_serializable(OrderedDict(), serialization_settings, say_hello2) + with pytest.raises(subprocess.CalledProcessError): + create_task_response = asyncio.run( + agent.async_create( + context=context, + output_prefix="", + task_template=task_spec.template, + inputs=None, + ) + ) + + @task( + task_config=MMCloudConfig(), + limits=Resources(cpu="2", mem="1Gi"), + container_image=container_image, + ) + def say_hello3(name: str) -> str: + return f"Hello, {name}." + + task_spec = get_serializable(OrderedDict(), serialization_settings, say_hello3) + create_task_response = asyncio.run( + agent.async_create( + context=context, + output_prefix="", + task_template=task_spec.template, + inputs=None, + ) + ) + resource_meta = create_task_response.resource_meta + asyncio.run(agent.async_delete(context=context, resource_meta=resource_meta)) + + @task( + task_config=MMCloudConfig(), + requests=Resources(cpu="2", mem="1Gi"), + container_image=container_image, + ) + def say_hello4(name: str) -> str: + return f"Hello, {name}." + + task_spec = get_serializable(OrderedDict(), serialization_settings, say_hello4) + create_task_response = asyncio.run( + agent.async_create( + context=context, + output_prefix="", + task_template=task_spec.template, + inputs=None, + ) + ) + resource_meta = create_task_response.resource_meta + asyncio.run(agent.async_delete(context=context, resource_meta=resource_meta)) + + @task( + task_config=MMCloudConfig(), + container_image=container_image, + ) + def say_hello5(name: str) -> str: + return f"Hello, {name}." + + task_spec = get_serializable(OrderedDict(), serialization_settings, say_hello5) + create_task_response = asyncio.run( + agent.async_create( + context=context, + output_prefix="", + task_template=task_spec.template, + inputs=None, + ) + ) + resource_meta = create_task_response.resource_meta + asyncio.run(agent.async_delete(context=context, resource_meta=resource_meta)) diff --git a/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/schema.py b/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/schema.py index a0676fdd90..1dcbc066d2 100644 --- a/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/schema.py +++ b/plugins/flytekit-onnx-pytorch/flytekitplugins/onnxpytorch/schema.py @@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Tuple, Type, Union import torch -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin from torch.onnx import OperatorExportTypes, TrainingMode from typing_extensions import Annotated, get_args, get_origin @@ -16,9 +16,8 @@ from flytekit.types.file import ONNXFile -@dataclass_json @dataclass -class PyTorch2ONNXConfig: +class PyTorch2ONNXConfig(DataClassJsonMixin): """ PyTorch2ONNXConfig is the config used during the pytorch to ONNX conversion. 
@@ -53,9 +52,8 @@ class PyTorch2ONNXConfig: export_modules_as_functions: Union[bool, set[Type]] = False -@dataclass_json @dataclass -class PyTorch2ONNX: +class PyTorch2ONNX(DataClassJsonMixin): model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction] = field(default=None) diff --git a/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/schema.py b/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/schema.py index 3305396e20..979e9bdcab 100644 --- a/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/schema.py +++ b/plugins/flytekit-onnx-scikitlearn/flytekitplugins/onnxscikitlearn/schema.py @@ -5,7 +5,7 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union import skl2onnx.common.data_types -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin from skl2onnx import convert_sklearn from sklearn.base import BaseEstimator from typing_extensions import Annotated, get_args, get_origin @@ -18,9 +18,8 @@ from flytekit.types.file import ONNXFile -@dataclass_json @dataclass -class ScikitLearn2ONNXConfig: +class ScikitLearn2ONNXConfig(DataClassJsonMixin): """ ScikitLearn2ONNXConfig is the config used during the scikitlearn to ONNX conversion. @@ -71,9 +70,8 @@ def __post_init__(self): raise ValueError("All types in final_types must be in skl2onnx.common.data_types") -@dataclass_json @dataclass -class ScikitLearn2ONNX: +class ScikitLearn2ONNX(DataClassJsonMixin): model: BaseEstimator = field(default=None) diff --git a/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/schema.py b/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/schema.py index af8095e0fb..28ed2c5c62 100644 --- a/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/schema.py +++ b/plugins/flytekit-onnx-tensorflow/flytekitplugins/onnxtensorflow/schema.py @@ -4,7 +4,7 @@ import numpy as np import tensorflow as tf import tf2onnx -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin from typing_extensions import Annotated, get_args, get_origin from flytekit import FlyteContext @@ -15,9 +15,8 @@ from flytekit.types.file import ONNXFile -@dataclass_json @dataclass -class TensorFlow2ONNXConfig: +class TensorFlow2ONNXConfig(DataClassJsonMixin): """ TensorFlow2ONNXConfig is the config used during the tensorflow to ONNX conversion. 
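This diff repeatedly swaps the `@dataclass_json` class decorator for the `DataClassJsonMixin` base class; the JSON round-trip behaviour stays the same, but the mixin makes the added `to_json`/`from_json` methods visible to static type checkers. A minimal sketch of the pattern (the `ExampleConfig` class is made up for illustration):

```python
from dataclasses import dataclass

from dataclasses_json import DataClassJsonMixin


@dataclass
class ExampleConfig(DataClassJsonMixin):  # previously: @dataclass_json on a plain @dataclass
    lr: float = 1e-3
    bs: int = 64


# Serialization works as it did with the decorator.
restored = ExampleConfig.from_json(ExampleConfig(lr=3e-4).to_json())
assert restored.lr == 3e-4
```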
@@ -46,9 +45,8 @@ class TensorFlow2ONNXConfig: large_model: bool = False -@dataclass_json @dataclass -class TensorFlow2ONNX: +class TensorFlow2ONNX(DataClassJsonMixin): model: tf.keras.Model = field(default=None) diff --git a/plugins/flytekit-openai-chatgpt/tests/test_agent.py b/plugins/flytekit-openai-chatgpt/tests/test_agent.py new file mode 100644 index 0000000000..00546e6b9c --- /dev/null +++ b/plugins/flytekit-openai-chatgpt/tests/test_agent.py @@ -0,0 +1,136 @@ +import pickle +from datetime import timedelta +from unittest import mock +from unittest.mock import MagicMock + +import grpc +import pytest +from aioresponses import aioresponses +from flyteidl.admin.agent_pb2 import SUCCEEDED +from flytekitplugins.spark.agent import Metadata, get_header + +from flytekit.extend.backend.base_agent import AgentRegistry +from flytekit.interfaces.cli_identifiers import Identifier +from flytekit.models import literals, task +from flytekit.models.core.identifier import ResourceType +from flytekit.models.task import Container, Resources, TaskTemplate + +# make like openai organization example +@pytest.mark.asyncio +async def test_chatgpt_agent(): + ctx = MagicMock(spec=grpc.ServicerContext) + agent = AgentRegistry.get_agent("dispatcher") + + task_id = Identifier( + resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" + ) + task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + ) + task_config = { + "sparkConf": { + "spark.driver.memory": "1000M", + "spark.executor.memory": "1000M", + "spark.executor.cores": "1", + "spark.executor.instances": "2", + "spark.driver.cores": "1", + }, + "mainApplicationFile": "dbfs:/entrypoint.py", + "databricksConf": { + "run_name": "flytekit databricks plugin example", + "new_cluster": { + "spark_version": "12.2.x-scala2.12", + "node_type_id": "n2-highmem-4", + "num_workers": 1, + }, + "timeout_seconds": 3600, + "max_retries": 1, + }, + "databricksInstance": "test-account.cloud.databricks.com", + } + container = Container( + image="flyteorg/flytekit:databricks-0.18.0-py3.7", + command=[], + args=[ + "pyflyte-fast-execute", + "--additional-distribution", + "s3://my-s3-bucket/flytesnacks/development/24UYJEF2HDZQN3SG4VAZSM4PLI======/script_mode.tar.gz", + "--dest-dir", + "/root", + "--", + "pyflyte-execute", + "--inputs", + "s3://my-s3-bucket", + "--output-prefix", + "s3://my-s3-bucket", + "--raw-output-data-prefix", + "s3://my-s3-bucket", + "--checkpoint-path", + "s3://my-s3-bucket", + "--prev-checkpoint", + "s3://my-s3-bucket", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "spark_local_example", + "task-name", + "hello_spark", + ], + resources=Resources( + requests=[], + limits=[], + ), + env={}, + config={}, + ) + + dummy_template = TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + container=container, + interface=None, + type="spark", + ) + mocked_token = "mocked_databricks_token" + mocked_context = mock.patch("flytekit.current_context", autospec=True).start() + mocked_context.return_value.secrets.get.return_value = mocked_token + + metadata_bytes = pickle.dumps( + Metadata( + databricks_instance="test-account.cloud.databricks.com", + run_id="123", + ) + ) + + mock_create_response = {"run_id": "123"} + mock_get_response = {"run_id": "123", "state": 
{"result_state": "SUCCESS"}} + mock_delete_response = {} + create_url = "https://test-account.cloud.databricks.com/api/2.0/jobs/runs/submit" + get_url = "https://test-account.cloud.databricks.com/api/2.0/jobs/runs/get?run_id=123" + delete_url = "https://test-account.cloud.databricks.com/api/2.0/jobs/runs/cancel" + with aioresponses() as mocked: + mocked.post(create_url, status=200, payload=mock_create_response) + res = await agent.async_create(ctx, "/tmp", dummy_template, None) + assert res.resource_meta == metadata_bytes + + mocked.get(get_url, status=200, payload=mock_get_response) + res = await agent.async_get(ctx, metadata_bytes) + assert res.resource.state == SUCCEEDED + assert res.resource.outputs == literals.LiteralMap({}).to_flyte_idl() + + mocked.post(delete_url, status=200, payload=mock_delete_response) + await agent.async_delete(ctx, metadata_bytes) + + assert get_header() == {"Authorization": f"Bearer {mocked_token}", "content-type": "application/json"} + + mock.patch.stopall() \ No newline at end of file diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py index 6f4ed6886c..d53f2cebe4 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py @@ -41,6 +41,7 @@ class NotebookTask(PythonInstanceTask[T]): """ Simple Papermill based input output handling for a Python Jupyter notebook. This task should be used to wrap a Notebook that has 2 properties + Property 1: One of the cells (usually the first) should be marked as the parameters cell. This task will inject inputs after this cell. The task will inject the outputs observed from Flyte @@ -82,8 +83,6 @@ class NotebookTask(PythonInstanceTask[T]): Step 3: Task can be executed as usual - Outputs - ------- The Task produces 2 implicit outputs. #. It captures the executed notebook in its entirety and is available from Flyte with the name ``out_nb``. @@ -116,7 +115,6 @@ class NotebookTask(PythonInstanceTask[T]): supported - Only supported types are str, int, float, bool Most output types are supported as long as FlyteFile etc is used. 
- """ _IMPLICIT_OP_NOTEBOOK = "out_nb" @@ -261,8 +259,8 @@ def execute(self, **kwargs) -> Any: """ TODO: Figure out how to share FlyteContext ExecutionParameters with the notebook kernel (as notebook kernel is executed in a separate python process) - For Spark, the notebooks today need to use the new_session or just getOrCreate session and get a handle to the - singleton + + For Spark, the notebooks today need to use the new_session or just getOrCreate session and get a handle to the singleton """ logger.info(f"Hijacking the call for task-type {self.task_type}, to call notebook.") for k, v in kwargs.items(): diff --git a/plugins/flytekit-papermill/tests/test_task.py b/plugins/flytekit-papermill/tests/test_task.py index 47db35793d..986b6fb234 100644 --- a/plugins/flytekit-papermill/tests/test_task.py +++ b/plugins/flytekit-papermill/tests/test_task.py @@ -1,16 +1,23 @@ import datetime import os +import shutil import tempfile import typing +from unittest import mock import pandas as pd +from click.testing import CliRunner from flytekitplugins.papermill import NotebookTask from flytekitplugins.pod import Pod from kubernetes.client import V1Container, V1PodSpec import flytekit from flytekit import StructuredDataset, kwtypes, map_task, task, workflow +from flytekit.clients.friendly import SynchronousFlyteClient +from flytekit.clis.sdk_in_container import pyflyte from flytekit.configuration import Image, ImageConfig +from flytekit.core import context_manager +from flytekit.remote import FlyteRemote from flytekit.types.directory import FlyteDirectory from flytekit.types.file import FlyteFile, PythonNotebook @@ -189,3 +196,32 @@ def wf(a: float) -> typing.List[float]: return map_task(nb_sub_task)(a=[a, a]) assert wf(a=3.14) == [9.8596, 9.8596] + + +@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) +def test_register_notebook_task(mock_client, mock_remote): + mock_remote._client = mock_client + mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash" + mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url" + runner = CliRunner() + context_manager.FlyteEntities.entities.clear() + notebook_task = """ +from flytekitplugins.papermill import NotebookTask + +nb_simple = NotebookTask( + name="test", + notebook_path="./core/notebook.ipython", +) +""" + with runner.isolated_filesystem(): + os.makedirs("core", exist_ok=True) + with open(os.path.join("core", "notebook.ipython"), "w") as f: + f.write("notebook.ipython") + f.close() + with open(os.path.join("core", "notebook_task.py"), "w") as f: + f.write(notebook_task) + f.close() + result = runner.invoke(pyflyte.main, ["register", "core"]) + assert "Successfully registered 2 entities" in result.output + shutil.rmtree("core") diff --git a/plugins/flytekit-papermill/tests/testdata/datatype.py b/plugins/flytekit-papermill/tests/testdata/datatype.py index 8e07ef052b..86e4aa0aa0 100644 --- a/plugins/flytekit-papermill/tests/testdata/datatype.py +++ b/plugins/flytekit-papermill/tests/testdata/datatype.py @@ -1,9 +1,8 @@ from dataclasses import dataclass -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin -@dataclass_json @dataclass -class X: +class X(DataClassJsonMixin): x: int diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py index 
4290c88ae4..ea644dc078 100644 --- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py +++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py @@ -2,12 +2,13 @@ import pandas as pd import polars as pl +from fsspec.utils import get_protocol from flytekit import FlyteContext +from flytekit.core.data_persistence import get_fsspec_storage_options from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata from flytekit.models.types import StructuredDatasetType -from flytekit.types.structured.basic_dfs import get_storage_options from flytekit.types.structured.structured_dataset import ( PARQUET, StructuredDataset, @@ -64,7 +65,7 @@ def decode( current_task_metadata: StructuredDatasetMetadata, ) -> pl.DataFrame: uri = flyte_value.uri - kwargs = get_storage_options(ctx.file_access.data_config, uri) + kwargs = get_fsspec_storage_options(protocol=get_protocol(uri), data_config=ctx.file_access.data_config) if current_task_metadata.structured_dataset_type and current_task_metadata.structured_dataset_type.columns: columns = [c.name for c in current_task_metadata.structured_dataset_type.columns] return pl.read_parquet(uri, columns=columns, use_pyarrow=True, storage_options=kwargs) diff --git a/plugins/flytekit-pydantic/README.md b/plugins/flytekit-pydantic/README.md new file mode 100644 index 0000000000..8eb7267100 --- /dev/null +++ b/plugins/flytekit-pydantic/README.md @@ -0,0 +1,28 @@ +# Flytekit Pydantic Plugin + +Pydantic is a data validation and settings management library that uses Python type annotations to enforce type hints at runtime and provide user-friendly errors when data is invalid. Pydantic models are classes that inherit from `pydantic.BaseModel` and are used to define the structure and validation of data using Python type annotations. + +The plugin adds type support for pydantic models. + +To install the plugin, run the following command: + +```bash +pip install flytekitplugins-pydantic +``` + + +## Type Example +```python +from pydantic import BaseModel + + +class TrainConfig(BaseModel): + lr: float = 1e-3 + batch_size: int = 32 + files: List[FlyteFile] + directories: List[FlyteDirectory] + +@task +def train(cfg: TrainConfig): + ... +``` diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py new file mode 100644 index 0000000000..23e7e341bd --- /dev/null +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/__init__.py @@ -0,0 +1,4 @@ +from .basemodel_transformer import BaseModelTransformer +from .deserialization import set_validators_on_supported_flyte_types as _set_validators_on_supported_flyte_types + +_set_validators_on_supported_flyte_types() # enables you to use flytekit.types in pydantic model diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py new file mode 100644 index 0000000000..325da8e500 --- /dev/null +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/basemodel_transformer.py @@ -0,0 +1,67 @@ +"""Serializes & deserializes the pydantic basemodels """ + +from typing import Dict, Type + +import pydantic +from google.protobuf import json_format +from typing_extensions import Annotated + +from flytekit import FlyteContext +from flytekit.core import type_engine +from flytekit.models import literals, types + +from . 
import deserialization, serialization + +BaseModelLiterals = Annotated[ + Dict[str, literals.Literal], + """ + BaseModel serialized to a LiteralMap consisting of: + 1) the basemodel json with placeholders for flyte types + 2) mapping from placeholders to serialized flyte type values in the object store + """, +] + + +class BaseModelTransformer(type_engine.TypeTransformer[pydantic.BaseModel]): + _TYPE_INFO = types.LiteralType(simple=types.SimpleType.STRUCT) + + def __init__(self): + """Construct pydantic.BaseModelTransformer.""" + super().__init__(name="basemodel-transform", t=pydantic.BaseModel) + + def get_literal_type(self, t: Type[pydantic.BaseModel]) -> types.LiteralType: + return types.LiteralType(simple=types.SimpleType.STRUCT) + + def to_literal( + self, + ctx: FlyteContext, + python_val: pydantic.BaseModel, + python_type: Type[pydantic.BaseModel], + expected: types.LiteralType, + ) -> literals.Literal: + """Convert a given ``pydantic.BaseModel`` to the Literal representation.""" + return serialization.serialize_basemodel(python_val) + + def to_python_value( + self, + ctx: FlyteContext, + lv: literals.Literal, + expected_python_type: Type[pydantic.BaseModel], + ) -> pydantic.BaseModel: + """Re-hydrate the pydantic BaseModel object from Flyte Literal value.""" + basemodel_literals: BaseModelLiterals = lv.map.literals + basemodel_json_w_placeholders = read_basemodel_json_from_literalmap(basemodel_literals) + with deserialization.PydanticDeserializationLiteralStore.attach( + basemodel_literals[serialization.OBJECTS_KEY].map + ): + return expected_python_type.parse_raw(basemodel_json_w_placeholders) + + +def read_basemodel_json_from_literalmap(lv: BaseModelLiterals) -> serialization.SerializedBaseModel: + basemodel_literal: literals.Literal = lv[serialization.BASEMODEL_JSON_KEY] + basemodel_json_w_placeholders = json_format.MessageToJson(basemodel_literal.scalar.generic) + assert isinstance(basemodel_json_w_placeholders, str) + return basemodel_json_w_placeholders + + +type_engine.TypeEngine.register(BaseModelTransformer()) diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py new file mode 100644 index 0000000000..238e78c84d --- /dev/null +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/commons.py @@ -0,0 +1,31 @@ +import builtins +import datetime +import typing +from typing import Set + +import numpy +import pyarrow +from typing_extensions import Annotated + +from flytekit.core import type_engine + +MODULES_TO_EXCLUDE_FROM_FLYTE_TYPES: Set[str] = {m.__name__ for m in [builtins, typing, datetime, pyarrow, numpy]} + + +def include_in_flyte_types(t: type) -> bool: + if t is None: + return False + object_module = t.__module__ + if any(object_module.startswith(module) for module in MODULES_TO_EXCLUDE_FROM_FLYTE_TYPES): + return False + return True + + +type_engine.TypeEngine.lazy_import_transformers() # loads all transformers +PYDANTIC_SUPPORTED_FLYTE_TYPES = tuple( + filter(include_in_flyte_types, type_engine.TypeEngine.get_available_transformers()) +) + +# this is the UUID placeholder that is set in the serialized basemodel JSON, connecting that field to +# the literal map that holds the actual object that needs to be deserialized (w/ protobuf) +LiteralObjID = Annotated[str, "Key for unique object in literal map."] diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py new file mode 100644 index 
0000000000..24fe5afc1e --- /dev/null +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/deserialization.py @@ -0,0 +1,145 @@ +import contextlib +from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Type, TypeVar, Union, cast + +import pydantic +from flytekitplugins.pydantic import commons, serialization + +from flytekit.core import context_manager, type_engine +from flytekit.models import literals +from flytekit.types import directory, file + +# this field is used by pydantic to get the validator method +PYDANTIC_VALIDATOR_METHOD_NAME = pydantic.BaseModel.__get_validators__.__name__ +PythonType = TypeVar("PythonType") # target type of the deserialization + + +class PydanticDeserializationLiteralStore: + """ + The purpose of this class is to provide a context manager that can be used to deserialize a basemodel from a + literal map. + + Because pydantic validators are fixed when subclassing a BaseModel, this object is a singleton that + serves as a namespace that can be set with the attach_to_literalmap context manager for the time that + a basemodel is being deserialized. The validators are then accessing this namespace for the flyteobj + placeholders that it is trying to deserialize. + """ + + literal_store: Optional[serialization.LiteralStore] = None # attachement point for the literal map + + def __init__(self) -> None: + raise Exception("This class should not be instantiated") + + def __init_subclass__(cls) -> None: + raise Exception("This class should not be subclassed") + + @classmethod + @contextlib.contextmanager + def attach(cls, literal_map: literals.LiteralMap) -> Generator[None, None, None]: + """ + Read a literal map and populate the object store from it. + + This can be used as a context manager to attach to a literal map for the duration of a deserialization + Note that this is not threadsafe, and designed to manage a single deserialization at a time. + """ + assert not cls.is_attached(), "can only be attached to one literal map at a time." + try: + cls.literal_store = literal_map.literals + yield + finally: + cls.literal_store = None + + @classmethod + def contains(cls, item: commons.LiteralObjID) -> bool: + assert cls.is_attached(), "can only check for existence of a literal when attached to a literal map" + assert cls.literal_store is not None + return item in cls.literal_store + + @classmethod + def is_attached(cls) -> bool: + return cls.literal_store is not None + + @classmethod + def get_python_object( + cls, identifier: commons.LiteralObjID, expected_type: Type[PythonType] + ) -> Optional[PythonType]: + """Deserialize a flyte literal and return the python object.""" + if not cls.is_attached(): + raise Exception("Must attach to a literal map before deserializing") + literal = cls.literal_store[identifier] # type: ignore + python_object = deserialize_flyte_literal(literal, expected_type) + return python_object + + +def set_validators_on_supported_flyte_types() -> None: + """ + Set pydantic validator for the flyte types supported by this plugin. + """ + for flyte_type in commons.PYDANTIC_SUPPORTED_FLYTE_TYPES: + setattr(flyte_type, PYDANTIC_VALIDATOR_METHOD_NAME, add_flyte_validators_for_type(flyte_type)) + + +def add_flyte_validators_for_type( + flyte_obj_type: Type[type_engine.T], +) -> Callable[[Any], Iterator[Callable[[Any], type_engine.T]]]: + """ + Add flyte deserialisation validators to a type. 
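For context, a minimal sketch of the pydantic (v1) hook this function patches onto the supported Flyte types: `__get_validators__` is a classmethod that yields callables which pydantic runs on every candidate value for a field of that type, and that is what lets the plugin swap placeholder strings back into deserialized Flyte objects. The `Upper` type below is made up purely for illustration:

```python
import pydantic


class Upper(str):
    @classmethod
    def __get_validators__(cls):
        # pydantic v1 collects these callables and applies them to field values
        yield cls.validate

    @classmethod
    def validate(cls, value):
        return cls(str(value).upper())


class Model(pydantic.BaseModel):
    name: Upper


assert Model(name="flyte").name == "FLYTE"
```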
+ """ + + previous_validators = cast( + Iterator[Callable[[Any], type_engine.T]], + getattr(flyte_obj_type, PYDANTIC_VALIDATOR_METHOD_NAME, lambda *_: [])(), + ) + + def validator(object_uid_maybe: Union[commons.LiteralObjID, Any]) -> Union[type_engine.T, Any]: + """Partial of deserialize_flyte_literal with the object_type fixed""" + if not PydanticDeserializationLiteralStore.is_attached(): + return object_uid_maybe # this validator should only trigger when we are deserializeing + if not isinstance(object_uid_maybe, str): + return object_uid_maybe # object uids are strings and we dont want to trigger on other types + if not PydanticDeserializationLiteralStore.contains(object_uid_maybe): + return object_uid_maybe # final safety check to make sure that the object uid is in the literal map + return PydanticDeserializationLiteralStore.get_python_object(object_uid_maybe, flyte_obj_type) + + def validator_generator(*args, **kwags) -> Iterator[Callable[[Any], type_engine.T]]: + """Generator that returns validators.""" + yield validator + yield from previous_validators + yield from ADDITIONAL_FLYTETYPE_VALIDATORS.get(flyte_obj_type, []) + + return validator_generator + + +def validate_flytefile(flytefile: Union[str, file.FlyteFile]) -> file.FlyteFile: + """Validate a flytefile (i.e. deserialize).""" + if isinstance(flytefile, file.FlyteFile): + return flytefile + if isinstance(flytefile, str): # when e.g. initializing from config + return file.FlyteFile(flytefile) + else: + raise ValueError(f"Invalid type for flytefile: {type(flytefile)}") + + +def validate_flytedir(flytedir: Union[str, directory.FlyteDirectory]) -> directory.FlyteDirectory: + """Validate a flytedir (i.e. deserialize).""" + if isinstance(flytedir, directory.FlyteDirectory): + return flytedir + if isinstance(flytedir, str): # when e.g. initializing from config + return directory.FlyteDirectory(flytedir) + else: + raise ValueError(f"Invalid type for flytedir: {type(flytedir)}") + + +ADDITIONAL_FLYTETYPE_VALIDATORS: Dict[Type, List[Callable[[Any], Any]]] = { + file.FlyteFile: [validate_flytefile], + directory.FlyteDirectory: [validate_flytedir], +} + + +def deserialize_flyte_literal( + flyteobj_literal: literals.Literal, python_type: Type[PythonType] +) -> Optional[PythonType]: + """Deserialize a Flyte Literal into the python object instance.""" + ctx = context_manager.FlyteContext.current_context() + transformer = type_engine.TypeEngine.get_transformer(python_type) + python_obj = transformer.to_python_value(ctx, flyteobj_literal, python_type) + return python_obj diff --git a/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py new file mode 100644 index 0000000000..cd5b149fd9 --- /dev/null +++ b/plugins/flytekit-pydantic/flytekitplugins/pydantic/serialization.py @@ -0,0 +1,115 @@ +""" +Logic for serializing a basemodel to a literalmap that can be passed between flyte tasks. + +The serialization process is as follows: + +1. Serialize the basemodel to json, replacing all flyte types with unique placeholder strings +2. Serialize the flyte types to separate literals and store them in the flyte object store (a singleton object) +3. 
Return a literal map with the json and the flyte object store represented as a literalmap {placeholder: flyte type} + +""" +import uuid +from typing import Any, Dict, Union, cast + +import pydantic +from google.protobuf import json_format, struct_pb2 +from typing_extensions import Annotated + +from flytekit.core import context_manager, type_engine +from flytekit.models import literals + +from . import commons + +BASEMODEL_JSON_KEY = "BaseModel JSON" +OBJECTS_KEY = "Serialized Flyte Objects" + +SerializedBaseModel = Annotated[str, "A pydantic BaseModel that has been serialized with placeholders for Flyte types."] + +ObjectStoreID = Annotated[str, "Key for unique literalmap of a serialized basemodel."] +LiteralObjID = Annotated[str, "Key for unique object in literal map."] +LiteralStore = Annotated[Dict[LiteralObjID, literals.Literal], "uid to literals for a serialized BaseModel"] + + +class BaseModelFlyteObjectStore: + """ + This class is an intermediate store for python objects that are being serialized/deserialized. + + On serialization of a basemodel, flyte objects are serialized and stored in this object store. + """ + + def __init__(self) -> None: + self.literal_store: LiteralStore = dict() + + def register_python_object(self, python_object: object) -> LiteralObjID: + """Serialize to literal and return a unique identifier.""" + serialized_item = serialize_to_flyte_literal(python_object) + identifier = make_identifier_for_serializeable(python_object) + assert identifier not in self.literal_store + self.literal_store[identifier] = serialized_item + return identifier + + def to_literal(self) -> literals.Literal: + """Convert the object store to a literal map.""" + return literals.Literal(map=literals.LiteralMap(literals=self.literal_store)) + + +def serialize_basemodel(basemodel: pydantic.BaseModel) -> literals.Literal: + """ + Serializes a given pydantic BaseModel instance into a LiteralMap. + The BaseModel is first serialized into a JSON format, where all Flyte types are replaced with unique placeholder strings. + The Flyte Types are serialized into separate Flyte literals + """ + store = BaseModelFlyteObjectStore() + basemodel_literal = serialize_basemodel_to_literal(basemodel, store) + basemodel_literalmap = literals.LiteralMap( + { + BASEMODEL_JSON_KEY: basemodel_literal, # json with flyte types replaced with placeholders + OBJECTS_KEY: store.to_literal(), # flyte type-engine serialized types + } + ) + literal = literals.Literal(map=basemodel_literalmap) # type: ignore + return literal + + +def serialize_basemodel_to_literal( + basemodel: pydantic.BaseModel, + flyteobject_store: BaseModelFlyteObjectStore, +) -> literals.Literal: + """ + Serialize a pydantic BaseModel to json and protobuf, separating out the Flyte types into a separate store. + On deserialization, the store is used to reconstruct the Flyte types. + """ + + def encoder(obj: Any) -> Union[str, commons.LiteralObjID]: + if isinstance(obj, commons.PYDANTIC_SUPPORTED_FLYTE_TYPES): + return flyteobject_store.register_python_object(obj) + return basemodel.__json_encoder__(obj) + + basemodel_json = basemodel.json(encoder=encoder) + return make_literal_from_json(basemodel_json) + + +def serialize_to_flyte_literal(python_obj: object) -> literals.Literal: + """ + Use the Flyte TypeEngine to serialize a python object to a Flyte Literal. 
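+
+    The literal type is inferred from ``type(python_obj)`` via the TypeEngine, so this only works for types
+    that have a registered transformer; in this module it is invoked through
+    ``BaseModelFlyteObjectStore.register_python_object`` for objects matching
+    ``commons.PYDANTIC_SUPPORTED_FLYTE_TYPES``.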
+ """ + python_type = type(python_obj) + ctx = context_manager.FlyteContextManager().current_context() + literal_type = type_engine.TypeEngine.to_literal_type(python_type) + literal_obj = type_engine.TypeEngine.to_literal(ctx, python_obj, python_type, literal_type) + return literal_obj + + +def make_literal_from_json(json: str) -> literals.Literal: + """ + Converts the json representation of a pydantic BaseModel to a Flyte Literal. + """ + return literals.Literal(scalar=literals.Scalar(generic=json_format.Parse(json, struct_pb2.Struct()))) # type: ignore + + +def make_identifier_for_serializeable(python_type: object) -> LiteralObjID: + """ + Create a unique identifier for a python object. + """ + unique_id = f"{type(python_type).__name__}_{uuid.uuid4().hex}" + return cast(LiteralObjID, unique_id) diff --git a/plugins/flytekit-pydantic/requirements.in b/plugins/flytekit-pydantic/requirements.in new file mode 100644 index 0000000000..44f25884d7 --- /dev/null +++ b/plugins/flytekit-pydantic/requirements.in @@ -0,0 +1,2 @@ +. +-e file:.#egg=flytekitplugins-pydantic diff --git a/plugins/flytekit-pydantic/requirements.txt b/plugins/flytekit-pydantic/requirements.txt new file mode 100644 index 0000000000..68acf7008a --- /dev/null +++ b/plugins/flytekit-pydantic/requirements.txt @@ -0,0 +1,347 @@ +# +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: +# +# pip-compile requirements.in +# +-e file:.#egg=flytekitplugins-pydantic + # via -r requirements.in +adal==1.2.7 + # via azure-datalake-store +adlfs==2023.4.0 + # via flytekit +aiobotocore==2.5.0 + # via s3fs +aiohttp==3.8.4 + # via + # adlfs + # aiobotocore + # gcsfs + # s3fs +aioitertools==0.11.0 + # via aiobotocore +aiosignal==1.3.1 + # via aiohttp +arrow==1.2.3 + # via jinja2-time +async-timeout==4.0.2 + # via aiohttp +attrs==23.1.0 + # via aiohttp +azure-core==1.26.4 + # via + # adlfs + # azure-identity + # azure-storage-blob +azure-datalake-store==0.0.52 + # via adlfs +azure-identity==1.12.0 + # via adlfs +azure-storage-blob==12.16.0 + # via adlfs +binaryornot==0.4.4 + # via cookiecutter +botocore==1.29.76 + # via aiobotocore +cachetools==5.3.0 + # via google-auth +certifi==2022.12.7 + # via + # kubernetes + # requests +cffi==1.15.1 + # via + # azure-datalake-store + # cryptography +chardet==5.1.0 + # via binaryornot +charset-normalizer==3.1.0 + # via + # aiohttp + # requests +click==8.1.3 + # via + # cookiecutter + # flytekit +cloudpickle==2.2.1 + # via flytekit +cookiecutter==2.1.1 + # via flytekit +croniter==1.3.14 + # via flytekit +cryptography==40.0.2 + # via + # adal + # azure-identity + # azure-storage-blob + # msal + # pyjwt + # pyopenssl +dataclasses-json==0.5.7 + # via flytekit +decorator==5.1.1 + # via gcsfs +deprecated==1.2.13 + # via flytekit +diskcache==5.6.1 + # via flytekit +docker==6.0.1 + # via flytekit +docker-image-py==0.1.12 + # via flytekit +docstring-parser==0.15 + # via flytekit +flyteidl==1.3.20 + # via flytekit +flytekit==1.5.0 + # via flytekitplugins-pydantic +frozenlist==1.3.3 + # via + # aiohttp + # aiosignal +fsspec==2023.4.0 + # via + # adlfs + # flytekit + # gcsfs + # s3fs +gcsfs==2023.4.0 + # via flytekit +gitdb==4.0.10 + # via gitpython +gitpython==3.1.31 + # via flytekit +google-api-core==2.11.0 + # via + # google-cloud-core + # google-cloud-storage +google-auth==2.17.3 + # via + # gcsfs + # google-api-core + # google-auth-oauthlib + # google-cloud-core + # google-cloud-storage + # kubernetes +google-auth-oauthlib==1.0.0 + # via gcsfs +google-cloud-core==2.3.2 + # via 
google-cloud-storage +google-cloud-storage==2.9.0 + # via gcsfs +google-crc32c==1.5.0 + # via google-resumable-media +google-resumable-media==2.5.0 + # via google-cloud-storage +googleapis-common-protos==1.59.0 + # via + # flyteidl + # flytekit + # google-api-core + # grpcio-status +grpcio==1.54.0 + # via + # flytekit + # grpcio-status +grpcio-status==1.54.0 + # via flytekit +idna==3.4 + # via + # requests + # yarl +importlib-metadata==6.6.0 + # via + # flytekit + # keyring +isodate==0.6.1 + # via azure-storage-blob +jaraco-classes==3.2.3 + # via keyring +jinja2==3.1.2 + # via + # cookiecutter + # jinja2-time +jinja2-time==0.2.0 + # via cookiecutter +jmespath==1.0.1 + # via botocore +joblib==1.2.0 + # via flytekit +keyring==23.13.1 + # via flytekit +kubernetes==26.1.0 + # via flytekit +markupsafe==2.1.2 + # via jinja2 +marshmallow==3.19.0 + # via + # dataclasses-json + # marshmallow-enum + # marshmallow-jsonschema +marshmallow-enum==1.5.1 + # via dataclasses-json +marshmallow-jsonschema==0.13.0 + # via flytekit +more-itertools==9.1.0 + # via jaraco-classes +msal==1.22.0 + # via + # azure-identity + # msal-extensions +msal-extensions==1.0.0 + # via azure-identity +multidict==6.0.4 + # via + # aiohttp + # yarl +mypy-extensions==1.0.0 + # via typing-inspect +natsort==8.3.1 + # via flytekit +numpy==1.24.3 + # via + # flytekit + # pandas + # pyarrow +oauthlib==3.2.2 + # via requests-oauthlib +packaging==23.1 + # via + # docker + # marshmallow +pandas==1.5.3 + # via flytekit +portalocker==2.7.0 + # via msal-extensions +protobuf==4.22.3 + # via + # flyteidl + # google-api-core + # googleapis-common-protos + # grpcio-status + # protoc-gen-swagger +protoc-gen-swagger==0.1.0 + # via flyteidl +pyarrow==10.0.1 + # via flytekit +pyasn1==0.5.0 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via google-auth +pycparser==2.21 + # via cffi +pydantic==1.10.7 + # via flytekitplugins-pydantic +pyjwt[crypto]==2.6.0 + # via + # adal + # msal +pyopenssl==23.1.1 + # via flytekit +python-dateutil==2.8.2 + # via + # adal + # arrow + # botocore + # croniter + # flytekit + # kubernetes + # pandas +python-json-logger==2.0.7 + # via flytekit +python-slugify==8.0.1 + # via cookiecutter +pytimeparse==1.1.8 + # via flytekit +pytz==2023.3 + # via + # flytekit + # pandas +pyyaml==6.0 + # via + # cookiecutter + # flytekit + # kubernetes + # responses +regex==2023.5.5 + # via docker-image-py +requests==2.30.0 + # via + # adal + # azure-core + # azure-datalake-store + # cookiecutter + # docker + # flytekit + # gcsfs + # google-api-core + # google-cloud-storage + # kubernetes + # msal + # requests-oauthlib + # responses +requests-oauthlib==1.3.1 + # via + # google-auth-oauthlib + # kubernetes +responses==0.23.1 + # via flytekit +rsa==4.9 + # via google-auth +s3fs==2023.4.0 + # via flytekit +six==1.16.0 + # via + # azure-core + # azure-identity + # google-auth + # isodate + # kubernetes + # python-dateutil +smmap==5.0.0 + # via gitdb +sortedcontainers==2.4.0 + # via flytekit +statsd==3.3.0 + # via flytekit +text-unidecode==1.3 + # via python-slugify +types-pyyaml==6.0.12.9 + # via responses +typing-extensions==4.5.0 + # via + # azure-core + # azure-storage-blob + # flytekit + # pydantic + # typing-inspect +typing-inspect==0.8.0 + # via dataclasses-json +urllib3==1.26.15 + # via + # botocore + # docker + # flytekit + # kubernetes + # requests + # responses +websocket-client==1.5.1 + # via + # docker + # kubernetes +wheel==0.40.0 + # via flytekit +wrapt==1.15.0 + # via + # aiobotocore + # deprecated + # flytekit 
+yarl==1.9.2 + # via aiohttp +zipp==3.15.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/plugins/flytekit-pydantic/setup.py b/plugins/flytekit-pydantic/setup.py new file mode 100644 index 0000000000..313c574dd1 --- /dev/null +++ b/plugins/flytekit-pydantic/setup.py @@ -0,0 +1,40 @@ +from setuptools import setup + +PLUGIN_NAME = "pydantic" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.7.0b0,<2.0.0", "pydantic"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="Plugin adding type support for Pydantic models", + url="https://github.com/flyteorg/flytekit/tree/master/plugins/flytekit-pydantic", + long_description=open("README.md").read(), + long_description_content_type="text/markdown", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.8", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, +) diff --git a/plugins/flytekit-pydantic/tests/folder/test_file1.txt b/plugins/flytekit-pydantic/tests/folder/test_file1.txt new file mode 100644 index 0000000000..257cc5642c --- /dev/null +++ b/plugins/flytekit-pydantic/tests/folder/test_file1.txt @@ -0,0 +1 @@ +foo diff --git a/plugins/flytekit-pydantic/tests/folder/test_file2.txt b/plugins/flytekit-pydantic/tests/folder/test_file2.txt new file mode 100644 index 0000000000..5716ca5987 --- /dev/null +++ b/plugins/flytekit-pydantic/tests/folder/test_file2.txt @@ -0,0 +1 @@ +bar diff --git a/plugins/flytekit-pydantic/tests/test_type_transformer.py b/plugins/flytekit-pydantic/tests/test_type_transformer.py new file mode 100644 index 0000000000..3c02dcb3f1 --- /dev/null +++ b/plugins/flytekit-pydantic/tests/test_type_transformer.py @@ -0,0 +1,296 @@ +import datetime as dt +import os +import pathlib +from typing import Any, Dict, List, Optional, Type, Union + +import pandas as pd +import pytest +from flyteidl.core.types_pb2 import SimpleType +from flytekitplugins.pydantic import BaseModelTransformer +from flytekitplugins.pydantic.commons import PYDANTIC_SUPPORTED_FLYTE_TYPES +from pydantic import BaseModel, Extra + +import flytekit +from flytekit.core import context_manager +from flytekit.core.type_engine import TypeEngine +from flytekit.types import directory +from flytekit.types.file import file + + +class TrainConfig(BaseModel): + """Config BaseModel for testing purposes.""" + + batch_size: int = 32 + lr: float = 1e-3 + loss: str = "cross_entropy" + + class Config: + extra = Extra.forbid + + +class Config(BaseModel): + """Config BaseModel for testing purposes with an optional type hint.""" + + model_config: Optional[Union[Dict[str, TrainConfig], TrainConfig]] = TrainConfig() + + +class 
ConfigWithDatetime(BaseModel): + """Config BaseModel for testing purposes with datetime type hint.""" + + datetime: dt.datetime = dt.datetime.now() + + +class NestedConfig(BaseModel): + """Nested config BaseModel for testing purposes.""" + + files: "ConfigWithFlyteFiles" + dirs: "ConfigWithFlyteDirs" + df: "ConfigWithPandasDataFrame" + datetime: "ConfigWithDatetime" = ConfigWithDatetime() + + def __eq__(self, __value: object) -> bool: + return isinstance(__value, NestedConfig) and all( + getattr(self, attr) == getattr(__value, attr) for attr in ["files", "dirs", "df", "datetime"] + ) + + +class ConfigRequired(BaseModel): + """Config BaseModel for testing purposes with required attribute.""" + + model_config: Union[Dict[str, TrainConfig], TrainConfig] + + +class ConfigWithFlyteFiles(BaseModel): + """Config BaseModel for testing purposes with flytekit.files.FlyteFile type hint.""" + + flytefiles: List[file.FlyteFile] + + def __eq__(self, __value: object) -> bool: + return isinstance(__value, ConfigWithFlyteFiles) and all( + pathlib.Path(self_file).read_text() == pathlib.Path(other_file).read_text() + for self_file, other_file in zip(self.flytefiles, __value.flytefiles) + ) + + +class ConfigWithFlyteDirs(BaseModel): + """Config BaseModel for testing purposes with flytekit.directory.FlyteDirectory type hint.""" + + flytedirs: List[directory.FlyteDirectory] + + def __eq__(self, __value: object) -> bool: + return isinstance(__value, ConfigWithFlyteDirs) and all( + os.listdir(self_dir) == os.listdir(other_dir) + for self_dir, other_dir in zip(self.flytedirs, __value.flytedirs) + ) + + +class ConfigWithPandasDataFrame(BaseModel): + """Config BaseModel for testing purposes with pandas.DataFrame type hint.""" + + df: pd.DataFrame + + def __eq__(self, __value: object) -> bool: + return isinstance(__value, ConfigWithPandasDataFrame) and self.df.equals(__value.df) + + +class ChildConfig(Config): + """Child class config BaseModel for testing purposes.""" + + d: List[int] = [1, 2, 3] + + +NestedConfig.update_forward_refs() + + +@pytest.mark.parametrize( + "python_type,kwargs", + [ + (Config, {}), + (ConfigRequired, {"model_config": TrainConfig()}), + (TrainConfig, {}), + (ConfigWithFlyteFiles, {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}), + (ConfigWithFlyteDirs, {"flytedirs": ["tests/folder/"]}), + (ConfigWithPandasDataFrame, {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}), + ( + NestedConfig, + { + "files": {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}, + "dirs": {"flytedirs": ["tests/folder/"]}, + "df": {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}, + }, + ), + ], +) +def test_transform_round_trip(python_type: Type, kwargs: Dict[str, Any]): + """Test that a (de-)serialization roundtrip results in the identical BaseModel.""" + + ctx = context_manager.FlyteContextManager().current_context() + + type_transformer = BaseModelTransformer() + + python_value = python_type(**kwargs) + + literal_value = type_transformer.to_literal( + ctx, + python_value, + python_type, + type_transformer.get_literal_type(python_value), + ) + + reconstructed_value = type_transformer.to_python_value(ctx, literal_value, type(python_value)) + + assert reconstructed_value == python_value + + +@pytest.mark.parametrize( + "config_type,kwargs", + [ + (Config, {"model_config": {"foo": TrainConfig(loss="mse")}}), + (ConfigRequired, {"model_config": {"foo": TrainConfig(loss="mse")}}), + (ConfigWithFlyteFiles, {"flytefiles": 
["tests/folder/test_file1.txt"]}), + (ConfigWithFlyteDirs, {"flytedirs": ["tests/folder/"]}), + (ConfigWithPandasDataFrame, {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}), + ( + NestedConfig, + { + "files": {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}, + "dirs": {"flytedirs": ["tests/folder/"]}, + "df": {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}, + }, + ), + ], +) +def test_pass_to_workflow(config_type: Type, kwargs: Dict[str, Any]): + """Test passing a BaseModel instance to a workflow works.""" + cfg = config_type(**kwargs) + + @flytekit.task + def train(cfg: config_type) -> config_type: + return cfg + + @flytekit.workflow + def wf(cfg: config_type) -> config_type: + return train(cfg=cfg) + + returned_cfg = wf(cfg=cfg) # type: ignore + + assert returned_cfg == cfg + # TODO these assertions are not valid for all types + + +@pytest.mark.parametrize( + "kwargs", + [ + {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}, + ], +) +def test_flytefiles_in_wf(kwargs: Dict[str, Any]): + """Test passing a BaseModel instance to a workflow works.""" + cfg = ConfigWithFlyteFiles(**kwargs) + + @flytekit.task + def read(cfg: ConfigWithFlyteFiles) -> str: + with open(cfg.flytefiles[0], "r") as f: + return f.read() + + @flytekit.workflow + def wf(cfg: ConfigWithFlyteFiles) -> str: + return read(cfg=cfg) # type: ignore + + string = wf(cfg=cfg) + assert string in {"foo\n", "bar\n"} # type: ignore + + +@pytest.mark.parametrize( + "kwargs", + [ + {"flytedirs": ["tests/folder/"]}, + ], +) +def test_flytedirs_in_wf(kwargs: Dict[str, Any]): + """Test passing a BaseModel instance to a workflow works.""" + cfg = ConfigWithFlyteDirs(**kwargs) + + @flytekit.task + def listdir(cfg: ConfigWithFlyteDirs) -> List[str]: + return os.listdir(cfg.flytedirs[0]) + + @flytekit.workflow + def wf(cfg: ConfigWithFlyteDirs) -> List[str]: + return listdir(cfg=cfg) # type: ignore + + dirs = wf(cfg=cfg) + assert len(dirs) == 2 # type: ignore + + +def test_double_config_in_wf(): + """Test passing a BaseModel instance to a workflow works.""" + cfg1 = TrainConfig(batch_size=13) + cfg2 = TrainConfig(batch_size=31) + + @flytekit.task + def are_different(cfg1: TrainConfig, cfg2: TrainConfig) -> bool: + return cfg1 != cfg2 + + @flytekit.workflow + def wf(cfg1: TrainConfig, cfg2: TrainConfig) -> bool: + return are_different(cfg1=cfg1, cfg2=cfg2) # type: ignore + + assert wf(cfg1=cfg1, cfg2=cfg2), wf(cfg1=cfg1, cfg2=cfg2) # type: ignore + + +@pytest.mark.parametrize( + "python_type,config_kwargs", + [ + (Config, {}), + (ConfigRequired, {"model_config": TrainConfig()}), + (TrainConfig, {}), + (ConfigWithFlyteFiles, {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}), + (ConfigWithFlyteDirs, {"flytedirs": ["tests/folder/"]}), + (ConfigWithPandasDataFrame, {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}), + ( + NestedConfig, + { + "files": {"flytefiles": ["tests/folder/test_file1.txt", "tests/folder/test_file2.txt"]}, + "dirs": {"flytedirs": ["tests/folder/"]}, + "df": {"df": pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})}, + }, + ), + ], +) +def test_dynamic(python_type: Type[BaseModel], config_kwargs: Dict[str, Any]): + config_instance = python_type(**config_kwargs) + + @flytekit.task + def train(cfg: BaseModel): + print(cfg) + + @flytekit.dynamic(cache=True, cache_version="0.3") + def sub_wf(cfg: BaseModel): + train(cfg=cfg) + + @flytekit.workflow + def wf(): + sub_wf(cfg=config_instance) + + wf() + + +def test_supported(): + 
assert len(PYDANTIC_SUPPORTED_FLYTE_TYPES) == 9 + + +def test_single_df(): + ctx = context_manager.FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(ConfigWithPandasDataFrame) + assert lt.simple == SimpleType.STRUCT + + pyd = ConfigWithPandasDataFrame(df=pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})) + lit = TypeEngine.to_literal(ctx, pyd, ConfigWithPandasDataFrame, lt) + assert lit.map is not None + offloaded_keys = list(lit.map.literals["Serialized Flyte Objects"].map.literals.keys()) + assert len(offloaded_keys) == 1 + assert ( + lit.map.literals["Serialized Flyte Objects"].map.literals[offloaded_keys[0]].scalar.structured_dataset + is not None + ) diff --git a/plugins/flytekit-snowflake/dev-requirements.in b/plugins/flytekit-snowflake/dev-requirements.in new file mode 100644 index 0000000000..2d73dba5b4 --- /dev/null +++ b/plugins/flytekit-snowflake/dev-requirements.in @@ -0,0 +1 @@ +pytest-asyncio diff --git a/plugins/flytekit-snowflake/dev-requirements.txt b/plugins/flytekit-snowflake/dev-requirements.txt new file mode 100644 index 0000000000..99d3f5e4e9 --- /dev/null +++ b/plugins/flytekit-snowflake/dev-requirements.txt @@ -0,0 +1,20 @@ +# +# This file is autogenerated by pip-compile with Python 3.9 +# by the following command: +# +# pip-compile dev-requirements.in +# +exceptiongroup==1.1.3 + # via pytest +iniconfig==2.0.0 + # via pytest +packaging==23.1 + # via pytest +pluggy==1.3.0 + # via pytest +pytest==7.4.0 + # via pytest-asyncio +pytest-asyncio==0.21.1 + # via -r dev-requirements.in +tomli==2.0.1 + # via pytest diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py index 2875e56bdf..b6e341bd5a 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/__init__.py @@ -9,6 +9,8 @@ SnowflakeConfig SnowflakeTask + SnowflakeAgent """ +from .agent import SnowflakeAgent from .task import SnowflakeConfig, SnowflakeTask diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py new file mode 100644 index 0000000000..aaf38fd4c5 --- /dev/null +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/agent.py @@ -0,0 +1,155 @@ +import json +from dataclasses import asdict, dataclass +from typing import Optional + +import grpc +import snowflake.connector +from flyteidl.admin.agent_pb2 import ( + PERMANENT_FAILURE, + SUCCEEDED, + CreateTaskResponse, + DeleteTaskResponse, + GetTaskResponse, + Resource, +) +from snowflake.connector import ProgrammingError + +from flytekit import FlyteContextManager, StructuredDataset, logger +from flytekit.core.type_engine import TypeEngine +from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_state +from flytekit.models import literals +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate +from flytekit.models.types import LiteralType, StructuredDatasetType + +TASK_TYPE = "snowflake" + + +@dataclass +class Metadata: + user: str + account: str + database: str + schema: str + warehouse: str + table: str + query_id: str + + +class SnowflakeAgent(AgentBase): + def __init__(self): + super().__init__(task_type=TASK_TYPE) + + def get_private_key(self): + from cryptography.hazmat.backends import default_backend + from cryptography.hazmat.primitives import serialization + + import flytekit + + pk_string = 
flytekit.current_context().secrets.get(TASK_TYPE, "private_key", encode_mode="rb") + p_key = serialization.load_pem_private_key(pk_string, password=None, backend=default_backend()) + + pkb = p_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + return pkb + + def get_connection(self, metadata: Metadata) -> snowflake.connector: + return snowflake.connector.connect( + user=metadata.user, + account=metadata.account, + private_key=self.get_private_key(), + database=metadata.database, + schema=metadata.schema, + warehouse=metadata.warehouse, + ) + + async def async_create( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + ) -> CreateTaskResponse: + params = None + if inputs: + ctx = FlyteContextManager.current_context() + python_interface_inputs = { + name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() + } + native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) + logger.info(f"Create Snowflake agent params with inputs: {native_inputs}") + params = native_inputs + + config = task_template.config + + conn = snowflake.connector.connect( + user=config["user"], + account=config["account"], + private_key=self.get_private_key(), + database=config["database"], + schema=config["schema"], + warehouse=config["warehouse"], + ) + + cs = conn.cursor() + cs.execute_async(task_template.sql.statement, params=params) + + metadata = Metadata( + user=config["user"], + account=config["account"], + database=config["database"], + schema=config["schema"], + warehouse=config["warehouse"], + table=config["table"], + query_id=str(cs.sfqid), + ) + + return CreateTaskResponse(resource_meta=json.dumps(asdict(metadata)).encode("utf-8")) + + async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: + metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) + conn = self.get_connection(metadata) + try: + query_status = conn.get_query_status_throw_if_error(metadata.query_id) + except ProgrammingError as err: + logger.error(err.msg) + context.set_code(grpc.StatusCode.INTERNAL) + context.set_details(err.msg) + return GetTaskResponse(resource=Resource(state=PERMANENT_FAILURE)) + cur_state = convert_to_flyte_state(str(query_status.name)) + res = None + + if cur_state == SUCCEEDED: + ctx = FlyteContextManager.current_context() + output_metadata = f"snowflake://{metadata.user}:{metadata.account}/{metadata.warehouse}/{metadata.database}/{metadata.schema}/{metadata.table}" + res = literals.LiteralMap( + { + "results": TypeEngine.to_literal( + ctx, + StructuredDataset(uri=output_metadata), + StructuredDataset, + LiteralType(structured_dataset_type=StructuredDatasetType(format="")), + ) + } + ).to_flyte_idl() + + return GetTaskResponse(resource=Resource(state=cur_state, outputs=res)) + + async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + metadata = Metadata(**json.loads(resource_meta.decode("utf-8"))) + conn = self.get_connection(metadata) + cs = conn.cursor() + try: + cs.execute(f"SELECT SYSTEM$CANCEL_QUERY('{metadata.query_id}')") + cs.fetchall() + finally: + cs.close() + conn.close() + return DeleteTaskResponse() + + +AgentRegistry.register(SnowflakeAgent()) diff --git a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py 
b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py index 534acb978e..9ac9980a88 100644 --- a/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py +++ b/plugins/flytekit-snowflake/flytekitplugins/snowflake/task.py @@ -3,13 +3,16 @@ from flytekit.configuration import SerializationSettings from flytekit.extend import SQLTask +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin from flytekit.models import task as _task_model -from flytekit.types.schema import FlyteSchema +from flytekit.types.structured import StructuredDataset +_USER_FIELD = "user" _ACCOUNT_FIELD = "account" _DATABASE_FIELD = "database" _SCHEMA_FIELD = "schema" _WAREHOUSE_FIELD = "warehouse" +_TABLE_FIELD = "table" @dataclass @@ -18,6 +21,8 @@ class SnowflakeConfig(object): SnowflakeConfig should be used to configure a Snowflake Task. """ + # The user to query against + user: Optional[str] = None # The account to query against account: Optional[str] = None # The database to query against @@ -26,9 +31,11 @@ class SnowflakeConfig(object): schema: Optional[str] = None # The optional warehouse to set for the given Snowflake query warehouse: Optional[str] = None + # The optional table to set for the given Snowflake query + table: Optional[str] = None -class SnowflakeTask(SQLTask[SnowflakeConfig]): +class SnowflakeTask(AsyncAgentExecutorMixin, SQLTask[SnowflakeConfig]): """ This is the simplest form of a Snowflake Task, that can be used even for tasks that do not produce any output. """ @@ -42,7 +49,7 @@ def __init__( query_template: str, task_config: Optional[SnowflakeConfig] = None, inputs: Optional[Dict[str, Type]] = None, - output_schema_type: Optional[Type[FlyteSchema]] = None, + output_schema_type: Optional[Type[StructuredDataset]] = None, **kwargs, ): """ @@ -76,10 +83,12 @@ def __init__( def get_config(self, settings: SerializationSettings) -> Dict[str, str]: return { + _USER_FIELD: self.task_config.user, _ACCOUNT_FIELD: self.task_config.account, _DATABASE_FIELD: self.task_config.database, _SCHEMA_FIELD: self.task_config.schema, _WAREHOUSE_FIELD: self.task_config.warehouse, + _TABLE_FIELD: self.task_config.table, } def get_sql(self, settings: SerializationSettings) -> Optional[_task_model.Sql]: diff --git a/plugins/flytekit-snowflake/setup.py b/plugins/flytekit-snowflake/setup.py index 219468b380..527daa2486 100644 --- a/plugins/flytekit-snowflake/setup.py +++ b/plugins/flytekit-snowflake/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0"] +plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "snowflake-connector-python>=3.1.0"] __version__ = "0.0.0+develop" @@ -32,4 +32,5 @@ "Topic :: Software Development :: Libraries", "Topic :: Software Development :: Libraries :: Python Modules", ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, ) diff --git a/plugins/flytekit-snowflake/tests/test_agent.py b/plugins/flytekit-snowflake/tests/test_agent.py new file mode 100644 index 0000000000..be9936e708 --- /dev/null +++ b/plugins/flytekit-snowflake/tests/test_agent.py @@ -0,0 +1,120 @@ +import json +from dataclasses import asdict +from datetime import timedelta +from unittest import mock +from unittest.mock import MagicMock + +import grpc +import pytest +from flyteidl.admin.agent_pb2 import SUCCEEDED, DeleteTaskResponse +from flytekitplugins.snowflake.agent import Metadata + +import flytekit.models.interface as interface_models +from flytekit.extend.backend.base_agent import AgentRegistry +from 
flytekit.interfaces.cli_identifiers import Identifier +from flytekit.models import literals, task, types +from flytekit.models.core.identifier import ResourceType +from flytekit.models.task import Sql, TaskTemplate + + +@mock.patch("flytekitplugins.snowflake.agent.SnowflakeAgent.get_private_key", return_value="pb") +@mock.patch("snowflake.connector.connect") +@pytest.mark.asyncio +async def test_snowflake_agent(mock_conn, mock_get_private_key): + query_status_mock = MagicMock() + query_status_mock.name = "SUCCEEDED" + + # Configure the mock connection to return the mock status object + mock_conn_instance = mock_conn.return_value + mock_conn_instance.get_query_status_throw_if_error.return_value = query_status_mock + + ctx = MagicMock(spec=grpc.ServicerContext) + agent = AgentRegistry.get_agent("snowflake") + + task_id = Identifier( + resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" + ) + + task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + ) + + task_config = { + "user": "dummy_user", + "account": "dummy_account", + "database": "dummy_database", + "schema": "dummy_schema", + "warehouse": "dummy_warehouse", + "table": "dummy_table", + } + + int_type = types.LiteralType(types.SimpleType.INTEGER) + interfaces = interface_models.TypedInterface( + { + "a": interface_models.Variable(int_type, "description1"), + "b": interface_models.Variable(int_type, "description2"), + }, + {}, + ) + task_inputs = literals.LiteralMap( + { + "a": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + "b": literals.Literal(scalar=literals.Scalar(primitive=literals.Primitive(integer=1))), + }, + ) + + dummy_template = TaskTemplate( + id=task_id, + custom=None, + config=task_config, + metadata=task_metadata, + interface=interfaces, + type="snowflake", + sql=Sql("SELECT 1"), + ) + + metadata = Metadata( + user="dummy_user", + account="dummy_account", + table="dummy_table", + database="dummy_database", + schema="dummy_schema", + warehouse="dummy_warehouse", + query_id="dummy_query_id", + ) + + res = await agent.async_create(ctx, "/tmp", dummy_template, task_inputs) + metadata.query_id = Metadata(**json.loads(res.resource_meta.decode("utf-8"))).query_id + metadata_bytes = json.dumps(asdict(metadata)).encode("utf-8") + assert res.resource_meta == metadata_bytes + + res = await agent.async_get(ctx, metadata_bytes) + assert res.resource.state == SUCCEEDED + assert ( + res.resource.outputs.literals["results"].scalar.structured_dataset.uri + == "snowflake://dummy_user:dummy_account/dummy_warehouse/dummy_database/dummy_schema/dummy_table" + ) + + delete_response = await agent.async_delete(ctx, metadata_bytes) + + # Assert the response + assert isinstance(delete_response, DeleteTaskResponse) + + # Verify that the expected methods were called on the mock cursor + mock_cursor = mock_conn_instance.cursor.return_value + mock_cursor.fetchall.assert_called_once() + + mock_cursor.execute.assert_called_once_with(f"SELECT SYSTEM$CANCEL_QUERY('{metadata.query_id}')") + mock_cursor.fetchall.assert_called_once() + + # Verify that the connection was closed + mock_cursor.close.assert_called_once() + mock_conn_instance.close.assert_called_once() diff --git a/plugins/flytekit-spark/dev-requirements.in b/plugins/flytekit-spark/dev-requirements.in new file mode 100644 index 
0000000000..78d0eca127 --- /dev/null +++ b/plugins/flytekit-spark/dev-requirements.in @@ -0,0 +1,2 @@ +aioresponses +pytest-asyncio diff --git a/plugins/flytekit-spark/dev-requirements.txt b/plugins/flytekit-spark/dev-requirements.txt new file mode 100644 index 0000000000..8d30230498 --- /dev/null +++ b/plugins/flytekit-spark/dev-requirements.txt @@ -0,0 +1,44 @@ +# +# This file is autogenerated by pip-compile with Python 3.9 +# by the following command: +# +# pip-compile dev-requirements.in +# +aiohttp==3.8.5 + # via aioresponses +aioresponses==0.7.4 + # via -r dev-requirements.in +aiosignal==1.3.1 + # via aiohttp +async-timeout==4.0.3 + # via aiohttp +attrs==23.1.0 + # via aiohttp +charset-normalizer==3.2.0 + # via aiohttp +exceptiongroup==1.1.3 + # via pytest +frozenlist==1.4.0 + # via + # aiohttp + # aiosignal +idna==3.4 + # via yarl +iniconfig==2.0.0 + # via pytest +multidict==6.0.4 + # via + # aiohttp + # yarl +packaging==23.1 + # via pytest +pluggy==1.3.0 + # via pytest +pytest==7.4.0 + # via pytest-asyncio +pytest-asyncio==0.21.1 + # via -r dev-requirements.in +tomli==2.0.1 + # via pytest +yarl==1.9.2 + # via aiohttp diff --git a/plugins/flytekit-spark/flytekitplugins/spark/__init__.py b/plugins/flytekit-spark/flytekitplugins/spark/__init__.py index e769540aea..72c9f37c9f 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/__init__.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/__init__.py @@ -17,6 +17,7 @@ from flytekit.configuration import internal as _internal +from .agent import DatabricksAgent from .pyspark_transformers import PySparkPipelineModelTransformer from .schema import SparkDataFrameSchemaReader, SparkDataFrameSchemaWriter, SparkDataFrameTransformer # noqa from .sd_transformers import ParquetToSparkDecodingHandler, SparkToParquetEncodingHandler diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py new file mode 100644 index 0000000000..ce0148ad2f --- /dev/null +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -0,0 +1,98 @@ +import json +import pickle +import typing +from dataclasses import dataclass +from typing import Optional + +import aiohttp +import grpc +from flyteidl.admin.agent_pb2 import PENDING, CreateTaskResponse, DeleteTaskResponse, GetTaskResponse, Resource + +from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, convert_to_flyte_state, get_agent_secret +from flytekit.models.literals import LiteralMap +from flytekit.models.task import TaskTemplate + + +@dataclass +class Metadata: + databricks_instance: str + run_id: str + + +class DatabricksAgent(AgentBase): + def __init__(self): + super().__init__(task_type="spark") + + async def async_create( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: Optional[LiteralMap] = None, + ) -> CreateTaskResponse: + custom = task_template.custom + container = task_template.container + databricks_job = custom["databricksConf"] + if not databricks_job["new_cluster"].get("docker_image"): + databricks_job["new_cluster"]["docker_image"] = {"url": container.image} + if not databricks_job["new_cluster"].get("spark_conf"): + databricks_job["new_cluster"]["spark_conf"] = custom["sparkConf"] + databricks_job["spark_python_task"] = { + "python_file": custom["mainApplicationFile"], + "parameters": tuple(container.args), + } + + databricks_instance = custom["databricksInstance"] + databricks_url = f"https://{databricks_instance}/api/2.0/jobs/runs/submit" + data = 
json.dumps(databricks_job) + + async with aiohttp.ClientSession() as session: + async with session.post(databricks_url, headers=get_header(), data=data) as resp: + if resp.status != 200: + raise Exception(f"Failed to create databricks job with error: {resp.reason}") + response = await resp.json() + + metadata = Metadata( + databricks_instance=databricks_instance, + run_id=str(response["run_id"]), + ) + return CreateTaskResponse(resource_meta=pickle.dumps(metadata)) + + async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: + metadata = pickle.loads(resource_meta) + databricks_instance = metadata.databricks_instance + databricks_url = f"https://{databricks_instance}/api/2.0/jobs/runs/get?run_id={metadata.run_id}" + + async with aiohttp.ClientSession() as session: + async with session.get(databricks_url, headers=get_header()) as resp: + if resp.status != 200: + raise Exception(f"Failed to get databricks job {metadata.run_id} with error: {resp.reason}") + response = await resp.json() + + cur_state = PENDING + if response.get("state") and response["state"].get("result_state"): + cur_state = convert_to_flyte_state(response["state"]["result_state"]) + + return GetTaskResponse(resource=Resource(state=cur_state)) + + async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + metadata = pickle.loads(resource_meta) + + databricks_url = f"https://{metadata.databricks_instance}/api/2.0/jobs/runs/cancel" + data = json.dumps({"run_id": metadata.run_id}) + + async with aiohttp.ClientSession() as session: + async with session.post(databricks_url, headers=get_header(), data=data) as resp: + if resp.status != 200: + raise Exception(f"Failed to cancel databricks job {metadata.run_id} with error: {resp.reason}") + await resp.json() + + return DeleteTaskResponse() + + +def get_header() -> typing.Dict[str, str]: + token = get_agent_secret("FLYTE_DATABRICKS_ACCESS_TOKEN") + return {"Authorization": f"Bearer {token}", "content-type": "application/json"} + + +AgentRegistry.register(DatabricksAgent()) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/task.py b/plugins/flytekit-spark/flytekitplugins/spark/task.py index 0d8ecd5b6e..17099350e4 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/task.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/task.py @@ -111,16 +111,18 @@ def __init__( **kwargs, ): self.sess: Optional[SparkSession] = None - self._default_executor_path: Optional[str] = task_config.executor_path - self._default_applications_path: Optional[str] = task_config.applications_path + self._default_executor_path: str = task_config.executor_path + self._default_applications_path: str = task_config.applications_path if isinstance(container_image, ImageSpec): if container_image.base_image is None: img = f"cr.flyte.org/flyteorg/flytekit:spark-{DefaultImages.get_version_suffix()}" container_image.base_image = img # default executor path and applications path in apache/spark-py:3.3.1 - self._default_executor_path = "/usr/bin/python3" - self._default_applications_path = "local:///usr/local/bin/entrypoint.py" + self._default_executor_path = self._default_executor_path or "/usr/bin/python3" + self._default_applications_path = ( + self._default_applications_path or "local:///usr/local/bin/entrypoint.py" + ) super(PysparkFunctionTask, self).__init__( task_config=task_config, task_type=self._SPARK_TASK_TYPE, diff --git a/plugins/flytekit-spark/setup.py b/plugins/flytekit-spark/setup.py index 
4207a0265c..21305263a6 100644 --- a/plugins/flytekit-spark/setup.py +++ b/plugins/flytekit-spark/setup.py @@ -4,7 +4,7 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "pyspark>=3.0.0"] +plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "pyspark>=3.0.0", "aiohttp"] __version__ = "0.0.0+develop" diff --git a/plugins/flytekit-spark/tests/test_agent.py b/plugins/flytekit-spark/tests/test_agent.py new file mode 100644 index 0000000000..1b3941d7f6 --- /dev/null +++ b/plugins/flytekit-spark/tests/test_agent.py @@ -0,0 +1,136 @@ +import pickle +from datetime import timedelta +from unittest import mock +from unittest.mock import MagicMock + +import grpc +import pytest +from aioresponses import aioresponses +from flyteidl.admin.agent_pb2 import SUCCEEDED +from flytekitplugins.spark.agent import Metadata, get_header + +from flytekit.extend.backend.base_agent import AgentRegistry +from flytekit.interfaces.cli_identifiers import Identifier +from flytekit.models import literals, task +from flytekit.models.core.identifier import ResourceType +from flytekit.models.task import Container, Resources, TaskTemplate + + +@pytest.mark.asyncio +async def test_databricks_agent(): + ctx = MagicMock(spec=grpc.ServicerContext) + agent = AgentRegistry.get_agent("spark") + + task_id = Identifier( + resource_type=ResourceType.TASK, project="project", domain="domain", name="name", version="version" + ) + task_metadata = task.TaskMetadata( + True, + task.RuntimeMetadata(task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"), + timedelta(days=1), + literals.RetryStrategy(3), + True, + "0.1.1b0", + "This is deprecated!", + True, + "A", + ) + task_config = { + "sparkConf": { + "spark.driver.memory": "1000M", + "spark.executor.memory": "1000M", + "spark.executor.cores": "1", + "spark.executor.instances": "2", + "spark.driver.cores": "1", + }, + "mainApplicationFile": "dbfs:/entrypoint.py", + "databricksConf": { + "run_name": "flytekit databricks plugin example", + "new_cluster": { + "spark_version": "12.2.x-scala2.12", + "node_type_id": "n2-highmem-4", + "num_workers": 1, + }, + "timeout_seconds": 3600, + "max_retries": 1, + }, + "databricksInstance": "test-account.cloud.databricks.com", + } + container = Container( + image="flyteorg/flytekit:databricks-0.18.0-py3.7", + command=[], + args=[ + "pyflyte-fast-execute", + "--additional-distribution", + "s3://my-s3-bucket/flytesnacks/development/24UYJEF2HDZQN3SG4VAZSM4PLI======/script_mode.tar.gz", + "--dest-dir", + "/root", + "--", + "pyflyte-execute", + "--inputs", + "s3://my-s3-bucket", + "--output-prefix", + "s3://my-s3-bucket", + "--raw-output-data-prefix", + "s3://my-s3-bucket", + "--checkpoint-path", + "s3://my-s3-bucket", + "--prev-checkpoint", + "s3://my-s3-bucket", + "--resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "--", + "task-module", + "spark_local_example", + "task-name", + "hello_spark", + ], + resources=Resources( + requests=[], + limits=[], + ), + env={}, + config={}, + ) + + dummy_template = TaskTemplate( + id=task_id, + custom=task_config, + metadata=task_metadata, + container=container, + interface=None, + type="spark", + ) + mocked_token = "mocked_databricks_token" + mocked_context = mock.patch("flytekit.current_context", autospec=True).start() + mocked_context.return_value.secrets.get.return_value = mocked_token + + metadata_bytes = pickle.dumps( + Metadata( + databricks_instance="test-account.cloud.databricks.com", + run_id="123", + ) + ) + + mock_create_response = 
{"run_id": "123"} + mock_get_response = {"run_id": "123", "state": {"result_state": "SUCCESS"}} + mock_delete_response = {} + create_url = "https://test-account.cloud.databricks.com/api/2.0/jobs/runs/submit" + get_url = "https://test-account.cloud.databricks.com/api/2.0/jobs/runs/get?run_id=123" + delete_url = "https://test-account.cloud.databricks.com/api/2.0/jobs/runs/cancel" + with aioresponses() as mocked: + mocked.post(create_url, status=200, payload=mock_create_response) + res = await agent.async_create(ctx, "/tmp", dummy_template, None) + assert res.resource_meta == metadata_bytes + + mocked.get(get_url, status=200, payload=mock_get_response) + res = await agent.async_get(ctx, metadata_bytes) + assert res.resource.state == SUCCEEDED + assert res.resource.outputs == literals.LiteralMap({}).to_flyte_idl() + + mocked.post(delete_url, status=200, payload=mock_delete_response) + await agent.async_delete(ctx, metadata_bytes) + + assert get_header() == {"Authorization": f"Bearer {mocked_token}", "content-type": "application/json"} + + mock.patch.stopall() diff --git a/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py b/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py index 8e8c464bd4..541a997f21 100644 --- a/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py +++ b/plugins/flytekit-sqlalchemy/flytekitplugins/sqlalchemy/task.py @@ -35,8 +35,8 @@ class SQLAlchemyConfig(object): sqlalchemy connector format (https://docs.sqlalchemy.org/en/14/core/engines.html#database-urls). Database can be found: - - within the container - - or from a publicly accessible source + - within the container + - or from a publicly accessible source Args: uri: default sqlalchemy connector diff --git a/plugins/flytekit-sqlalchemy/tests/test_task.py b/plugins/flytekit-sqlalchemy/tests/test_task.py index 7537a3a1de..c001f7b543 100644 --- a/plugins/flytekit-sqlalchemy/tests/test_task.py +++ b/plugins/flytekit-sqlalchemy/tests/test_task.py @@ -4,15 +4,21 @@ import sqlite3 import tempfile from typing import Iterator +from unittest import mock import pandas import pytest +from click.testing import CliRunner from flytekitplugins.sqlalchemy import SQLAlchemyConfig, SQLAlchemyTask from flytekitplugins.sqlalchemy.task import SQLAlchemyTaskExecutor from flytekit import kwtypes, task, workflow +from flytekit.clients.friendly import SynchronousFlyteClient +from flytekit.clis.sdk_in_container import pyflyte +from flytekit.core import context_manager from flytekit.core.context_manager import SecretsManager from flytekit.models.security import Secret +from flytekit.remote import FlyteRemote from flytekit.types.schema import FlyteSchema tk = SQLAlchemyTask( @@ -197,3 +203,32 @@ def test_task_serialization_deserialization_with_secret(sql_server): r = executor.execute_from_model(tt) assert r.iat[0, 0] == 1 + + +@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) +def test_register_sql_task(mock_client, mock_remote): + mock_remote._client = mock_client + mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash" + mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url" + runner = CliRunner() + context_manager.FlyteEntities.entities.clear() + sql_task = """ +from flytekitplugins.sqlalchemy import SQLAlchemyConfig, SQLAlchemyTask + +tk = SQLAlchemyTask( + "test", + query_template="select * from tracks", + 
task_config=SQLAlchemyConfig( + uri="sqlite://", + ), +) +""" + with runner.isolated_filesystem(): + os.makedirs("core", exist_ok=True) + with open(os.path.join("core", "sql_task.py"), "w") as f: + f.write(sql_task) + f.close() + result = runner.invoke(pyflyte.main, ["register", "core"]) + assert "Successfully registered 1 entities" in result.output + shutil.rmtree("core") diff --git a/setup.py b/setup.py index 2a97af03eb..6828cb8661 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ }, install_requires=[ "googleapis-common-protos>=1.57", - "flyteidl>=1.5.10", + "flyteidl>=1.5.16", "wheel>=0.30.0,<1.0.0", "pandas>=1.0.0,<2.0.0", "pyarrow>=4.0.0,<11.0.0", @@ -38,10 +38,8 @@ "deprecated>=1.0,<2.0", "docker>=4.0.0,<7.0.0", "python-dateutil>=2.1", - # Restrict grpcio and grpcio-status. Version 1.50.0 pulls in a version of protobuf that is not compatible - # with the old protobuf library (as described in https://developers.google.com/protocol-buffers/docs/news/2022-05-06) - "grpcio>=1.50.0,!=1.55.0,<1.53.1,<2.0", - "grpcio-status>=1.50.0,!=1.55.0,<1.53.1,<2.0", + "grpcio", + "grpcio-status", "importlib-metadata", "fsspec>=2023.3.0", "adlfs", @@ -62,6 +60,7 @@ # TODO: remove upper-bound after fixing change in contract "dataclasses-json>=0.5.2,<0.5.12", "marshmallow-jsonschema>=0.12.0", + "mashumaro>=3.9.1", "marshmallow-enum", "natsort>=7.0.1", "docker-image-py>=0.1.10", @@ -75,6 +74,7 @@ "kubernetes>=12.0.1", "rich", "rich_click", + "jsonpickle", ], extras_require=extras_require, scripts=[ diff --git a/tests/flytekit/integration/__init__.py b/tests/flytekit/integration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/integration/experimental/__init__.py b/tests/flytekit/integration/experimental/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/integration/experimental/eager_workflows.py b/tests/flytekit/integration/experimental/eager_workflows.py new file mode 100644 index 0000000000..2dbc28a640 --- /dev/null +++ b/tests/flytekit/integration/experimental/eager_workflows.py @@ -0,0 +1,154 @@ +import asyncio +import os +import typing +from functools import partial +from pathlib import Path + +import pandas as pd + +from flytekit import task, workflow +from flytekit.configuration import Config +from flytekit.experimental import EagerException, eager +from flytekit.remote import FlyteRemote +from flytekit.types.directory import FlyteDirectory +from flytekit.types.file import FlyteFile +from flytekit.types.structured import StructuredDataset + +remote = FlyteRemote( + config=Config.for_sandbox(), + default_project="flytesnacks", + default_domain="development", +) + + +eager_partial = partial(eager, remote=remote) + + +@task +def add_one(x: int) -> int: + return x + 1 + + +@task +def double(x: int) -> int: + return x * 2 + + +@task +def gt_0(x: int) -> bool: + return x > 0 + + +@task +def raises_exc(x: int) -> int: + if x == 0: + raise TypeError + return x + + +@task +def create_structured_dataset() -> StructuredDataset: + df = pd.DataFrame({"a": [1, 2, 3]}) + return StructuredDataset(dataframe=df) + + +@task +def create_file() -> FlyteFile: + fname = "/tmp/flytekit_test_file" + with open(fname, "w") as fh: + fh.write("some data\n") + return FlyteFile(path=fname) + + +@task +def create_directory() -> FlyteDirectory: + dirname = "/tmp/flytekit_test_dir" + Path(dirname).mkdir(exist_ok=True, parents=True) + with open(os.path.join(dirname, "file"), "w") as tmp: + tmp.write("some data\n") + return 
FlyteDirectory(path=dirname) + + +@eager_partial +async def simple_eager_wf(x: int) -> int: + out = await add_one(x=x) + return await double(x=out) + + +@eager_partial +async def conditional_eager_wf(x: int) -> int: + if await gt_0(x=x): + return -1 + return 1 + + +@eager_partial +async def try_except_eager_wf(x: int) -> int: + try: + return await raises_exc(x=x) + except EagerException: + return -1 + + +@eager_partial +async def gather_eager_wf(x: int) -> typing.List[int]: + results = await asyncio.gather(*[add_one(x=x) for _ in range(10)]) + return results + + +@eager_partial +async def nested_eager_wf(x: int) -> int: + out = await simple_eager_wf(x=x) + return await double(x=out) + + +@workflow +def wf_with_eager_wf(x: int) -> int: + out = simple_eager_wf(x=x) + return double(x=out) + + +@workflow +def subworkflow(x: int) -> int: + return add_one(x=x) + + +@eager_partial +async def eager_wf_with_subworkflow(x: int) -> int: + out = await subworkflow(x=x) + return await double(x=out) + + +@eager_partial +async def eager_wf_structured_dataset() -> int: + dataset = await create_structured_dataset() + df = dataset.open(pd.DataFrame).all() + return int(df["a"].sum()) + + +@eager_partial +async def eager_wf_flyte_file() -> str: + file = await create_file() + file.download() + with open(file.path) as f: + data = f.read().strip() + return data + + +@eager_partial +async def eager_wf_flyte_directory() -> str: + directory = await create_directory() + directory.download() + with open(os.path.join(directory.path, "file")) as f: + data = f.read().strip() + return data + + +@eager(remote=remote, local_entrypoint=True) +async def eager_wf_local_entrypoint(x: int) -> int: + out = await add_one(x=x) + return await double(x=out) + + +if __name__ == "__main__": + print(asyncio.run(simple_eager_wf(x=1))) diff --git a/tests/flytekit/integration/experimental/test_eager_workflows.py b/tests/flytekit/integration/experimental/test_eager_workflows.py new file mode 100644 index 0000000000..ad1bc44112 --- /dev/null +++ b/tests/flytekit/integration/experimental/test_eager_workflows.py @@ -0,0 +1,116 @@ +"""Eager workflow integration tests. + +These tests are currently not run in CI. In order to run this locally you'll need to start a +local flyte cluster, and build and push a flytekit development image: + +``` + +# if you already have a local cluster running, tear it down and start fresh +flytectl demo teardown -v + +# start a local flyte cluster +flytectl demo start + +# build and push the image +docker build . 
-f Dockerfile.dev -t localhost:30000/flytekit:dev --build-arg PYTHON_VERSION=3.9 +docker push localhost:30000/flytekit:dev + +# run the tests +pytest tests/flytekit/integration/experimental/test_eager_workflows.py +``` +""" + +import asyncio +import os +import subprocess +import time +from pathlib import Path + +import pytest + +from flytekit.configuration import Config +from flytekit.remote import FlyteRemote + +from .eager_workflows import eager_wf_local_entrypoint + +MODULE = "eager_workflows" +MODULE_PATH = Path(__file__).parent / f"{MODULE}.py" +CONFIG = os.environ.get("FLYTECTL_CONFIG", str(Path.home() / ".flyte" / "config-sandbox.yaml")) +IMAGE = os.environ.get("FLYTEKIT_IMAGE", "localhost:30000/flytekit:dev") + + +@pytest.fixture(scope="session") +def register(): + subprocess.run( + [ + "pyflyte", + "-c", + CONFIG, + "register", + "--image", + IMAGE, + "--project", + "flytesnacks", + "--domain", + "development", + MODULE_PATH, + ] + ) + + +@pytest.mark.skipif( + os.environ.get("FLYTEKIT_CI", False), reason="Running workflows with sandbox cluster fails due to memory pressure" +) +@pytest.mark.parametrize( + "entity_type, entity_name, input, output", + [ + ("eager", "simple_eager_wf", 1, 4), + ("eager", "conditional_eager_wf", 1, -1), + ("eager", "conditional_eager_wf", -10, 1), + ("eager", "try_except_eager_wf", 1, 1), + ("eager", "try_except_eager_wf", 0, -1), + ("eager", "gather_eager_wf", 1, [2] * 10), + ("eager", "nested_eager_wf", 1, 8), + ("eager", "eager_wf_with_subworkflow", 1, 4), + ("eager", "eager_wf_structured_dataset", None, 6), + ("eager", "eager_wf_flyte_file", None, "some data"), + ("eager", "eager_wf_flyte_directory", None, "some data"), + ("workflow", "wf_with_eager_wf", 1, 8), + ], +) +def test_eager_workflows(register, entity_type, entity_name, input, output): + remote = FlyteRemote( + config=Config.auto(config_file=CONFIG), + default_project="flytesnacks", + default_domain="development", + ) + + fetch_method = { + "eager": remote.fetch_task, + "workflow": remote.fetch_workflow, + }[entity_type] + + entity = None + for i in range(100): + try: + entity = fetch_method(name=f"{MODULE}.{entity_name}") + break + except Exception: + print(f"retry {i}") + time.sleep(6) + continue + + if entity is None: + raise RuntimeError("failed to fetch entity") + + inputs = {} if input is None else {"x": input} + execution = remote.execute(entity, inputs=inputs, wait=True) + assert execution.outputs["o0"] == output + + +@pytest.mark.skipif( + os.environ.get("FLYTEKIT_CI", False), reason="Running workflows with sandbox cluster fails due to memory pressure" +) +def test_eager_workflow_local_entrypoint(register): + result = asyncio.run(eager_wf_local_entrypoint(x=1)) + assert result == 4 diff --git a/tests/flytekit/integration/remote/__init__.py b/tests/flytekit/integration/remote/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in index ec349a638a..d8e445120d 100644 --- a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in +++ b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in @@ -2,4 +2,3 @@ flytekit>=0.24.0 joblib wheel matplotlib -opencv-python diff --git a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt index 7bd27f438b..e697db1861 100644 --- 
a/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt +++ b/tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt @@ -1,195 +1,335 @@ # -# This file is autogenerated by pip-compile with Python 3.7 +# This file is autogenerated by pip-compile with Python 3.9 # by the following command: # # make tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.txt # +adlfs==2023.4.0 + # via flytekit +aiobotocore==2.5.2 + # via s3fs +aiohttp==3.8.5 + # via + # adlfs + # aiobotocore + # gcsfs + # s3fs +aioitertools==0.11.0 + # via aiobotocore +aiosignal==1.3.1 + # via aiohttp arrow==1.2.3 - # via jinja2-time + # via cookiecutter +async-timeout==4.0.2 + # via aiohttp +attrs==23.1.0 + # via aiohttp +azure-core==1.28.0 + # via + # adlfs + # azure-identity + # azure-storage-blob +azure-datalake-store==0.0.53 + # via adlfs +azure-identity==1.13.0 + # via adlfs +azure-storage-blob==12.17.0 + # via adlfs binaryornot==0.4.4 # via cookiecutter -certifi==2022.12.7 - # via requests +botocore==1.29.161 + # via aiobotocore +cachetools==5.3.1 + # via google-auth +certifi==2023.7.22 + # via + # kubernetes + # requests cffi==1.15.1 - # via cryptography -chardet==5.1.0 + # via + # azure-datalake-store + # cryptography +chardet==5.2.0 # via binaryornot -charset-normalizer==3.0.1 - # via requests -click==8.1.3 +charset-normalizer==3.2.0 + # via + # aiohttp + # requests +click==8.1.6 # via # cookiecutter # flytekit + # rich-click cloudpickle==2.2.1 # via flytekit -cookiecutter==2.1.1 +contourpy==1.1.0 + # via matplotlib +cookiecutter==2.2.3 # via flytekit -croniter==1.3.8 +croniter==1.4.1 # via flytekit -cryptography==39.0.0 +cryptography==41.0.4 # via + # azure-identity + # azure-storage-blob + # msal + # pyjwt # pyopenssl # secretstorage cycler==0.11.0 # via matplotlib -dataclasses-json==0.5.7 +dataclasses-json==0.5.9 # via flytekit decorator==5.1.1 - # via retry -deprecated==1.2.13 + # via gcsfs +deprecated==1.2.14 # via flytekit -diskcache==5.4.0 +diskcache==5.6.1 # via flytekit -docker==6.0.1 +docker==6.1.3 # via flytekit docker-image-py==0.1.12 # via flytekit docstring-parser==0.15 # via flytekit -flyteidl==1.3.5 +flyteidl==1.5.13 # via flytekit -flytekit==1.3.1 - # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in -fonttools==4.38.0 +flytekit==1.8.2 + # via -r requirements.in +fonttools==4.41.1 # via matplotlib +frozenlist==1.4.0 + # via + # aiohttp + # aiosignal +fsspec==2023.6.0 + # via + # adlfs + # flytekit + # gcsfs + # s3fs +gcsfs==2023.6.0 + # via flytekit gitdb==4.0.10 # via gitpython -gitpython==3.1.30 +gitpython==3.1.35 # via flytekit -googleapis-common-protos==1.58.0 +google-api-core==2.11.1 + # via + # google-cloud-core + # google-cloud-storage +google-auth==2.22.0 + # via + # gcsfs + # google-api-core + # google-auth-oauthlib + # google-cloud-core + # google-cloud-storage + # kubernetes +google-auth-oauthlib==1.0.0 + # via gcsfs +google-cloud-core==2.3.3 + # via google-cloud-storage +google-cloud-storage==2.10.0 + # via gcsfs +google-crc32c==1.5.0 + # via google-resumable-media +google-resumable-media==2.5.0 + # via google-cloud-storage +googleapis-common-protos==1.60.0 # via # flyteidl # flytekit + # google-api-core # grpcio-status -grpcio==1.51.1 +grpcio==1.53.0 # via # flytekit # grpcio-status -grpcio-status==1.51.1 +grpcio-status==1.53.0 # via flytekit idna==3.4 - # via requests -importlib-metadata==6.0.0 # via - # click + # requests + # yarl +importlib-metadata==6.8.0 + # via # flytekit # keyring 
-importlib-resources==5.10.2 - # via keyring -jaraco-classes==3.2.3 +isodate==0.6.1 + # via azure-storage-blob +jaraco-classes==3.3.0 # via keyring jeepney==0.8.0 # via # keyring # secretstorage jinja2==3.1.2 - # via - # cookiecutter - # jinja2-time -jinja2-time==0.2.0 # via cookiecutter -joblib==1.2.0 +jmespath==1.0.1 + # via botocore +joblib==1.3.1 # via - # -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in + # -r requirements.in # flytekit -keyring==23.13.1 +keyring==24.2.0 # via flytekit kiwisolver==1.4.4 # via matplotlib -markupsafe==2.1.2 +kubernetes==27.2.0 + # via flytekit +markdown-it-py==3.0.0 + # via rich +markupsafe==2.1.3 # via jinja2 -marshmallow==3.19.0 +marshmallow==3.20.1 # via # dataclasses-json # marshmallow-enum # marshmallow-jsonschema marshmallow-enum==1.5.1 - # via dataclasses-json + # via + # dataclasses-json + # flytekit marshmallow-jsonschema==0.13.0 # via flytekit -matplotlib==3.5.3 - # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in -more-itertools==9.0.0 +matplotlib==3.7.2 + # via -r requirements.in +mdurl==0.1.2 + # via markdown-it-py +more-itertools==10.0.0 # via jaraco-classes -mypy-extensions==0.4.3 +msal==1.23.0 + # via + # azure-datalake-store + # azure-identity + # msal-extensions +msal-extensions==1.0.0 + # via azure-identity +multidict==6.0.4 + # via + # aiohttp + # yarl +mypy-extensions==1.0.0 # via typing-inspect -natsort==8.2.0 +natsort==8.4.0 # via flytekit -numpy==1.21.6 +numpy==1.25.2 # via + # contourpy # flytekit # matplotlib - # opencv-python # pandas # pyarrow -opencv-python==4.7.0.68 - # via -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in -packaging==23.0 +oauthlib==3.2.2 + # via + # kubernetes + # requests-oauthlib +packaging==23.1 # via # docker # marshmallow # matplotlib -pandas==1.3.5 +pandas==1.5.3 # via flytekit -pillow==9.4.0 +pillow==10.0.0 # via matplotlib -protobuf==4.21.12 +portalocker==2.7.0 + # via msal-extensions +protobuf==4.23.4 # via # flyteidl + # google-api-core # googleapis-common-protos # grpcio-status # protoc-gen-swagger protoc-gen-swagger==0.1.0 # via flyteidl -py==1.11.0 - # via retry pyarrow==10.0.1 # via flytekit +pyasn1==0.5.0 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via google-auth pycparser==2.21 # via cffi -pyopenssl==23.0.0 +pygments==2.15.1 + # via rich +pyjwt[crypto]==2.8.0 + # via msal +pyopenssl==23.2.0 # via flytekit pyparsing==3.0.9 # via matplotlib python-dateutil==2.8.2 # via # arrow + # botocore # croniter # flytekit + # kubernetes # matplotlib # pandas -python-json-logger==2.0.4 +python-json-logger==2.0.7 # via flytekit -python-slugify==8.0.0 +python-slugify==8.0.1 # via cookiecutter pytimeparse==1.1.8 # via flytekit -pytz==2022.7.1 +pytz==2023.3 # via # flytekit # pandas -pyyaml==6.0 +pyyaml==6.0.1 # via # cookiecutter # flytekit -regex==2022.10.31 + # kubernetes + # responses +regex==2023.6.3 # via docker-image-py -requests==2.28.2 +requests==2.31.0 # via + # azure-core + # azure-datalake-store # cookiecutter # docker # flytekit + # gcsfs + # google-api-core + # google-cloud-storage + # kubernetes + # msal + # requests-oauthlib # responses -responses==0.22.0 +requests-oauthlib==1.3.1 + # via + # google-auth-oauthlib + # kubernetes +responses==0.23.3 # via flytekit -retry==0.9.2 +rich==13.5.2 + # via + # flytekit + # rich-click +rich-click==1.6.1 + # via flytekit +rsa==4.9 + # via google-auth +s3fs==2023.6.0 # via flytekit secretstorage==3.3.3 # via keyring -singledispatchmethod==1.0 - 
# via flytekit six==1.16.0 - # via python-dateutil + # via + # azure-core + # azure-identity + # google-auth + # isodate + # kubernetes + # python-dateutil smmap==5.0.0 # via gitdb sortedcontainers==2.4.0 @@ -198,38 +338,39 @@ statsd==3.3.0 # via flytekit text-unidecode==1.3 # via python-slugify -toml==0.10.2 - # via responses -types-toml==0.10.8.1 +types-pyyaml==6.0.12.11 # via responses -typing-extensions==4.4.0 +typing-extensions==4.7.1 # via - # arrow + # azure-core + # azure-storage-blob # flytekit - # gitpython - # importlib-metadata - # kiwisolver - # responses # typing-inspect -typing-inspect==0.8.0 +typing-inspect==0.9.0 # via dataclasses-json -urllib3==1.26.14 +urllib3==1.26.16 # via + # botocore # docker # flytekit + # google-auth + # kubernetes # requests # responses -websocket-client==1.5.0 - # via docker -wheel==0.38.4 +websocket-client==1.6.1 # via - # -r tests/flytekit/integration/remote/mock_flyte_repo/workflows/requirements.in + # docker + # kubernetes +wheel==0.41.0 + # via + # -r requirements.in # flytekit -wrapt==1.14.1 +wrapt==1.15.0 # via + # aiobotocore # deprecated # flytekit -zipp==3.12.0 - # via - # importlib-metadata - # importlib-resources +yarl==1.9.2 + # via aiohttp +zipp==3.16.2 + # via importlib-metadata diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 3466a48d92..84f9746d2f 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -173,10 +173,20 @@ def test_execute_python_task(flyteclient, flyte_workflows_register, flyte_remote remote = FlyteRemote(Config.auto(), PROJECT, "development") execution = remote.execute( - t1, inputs={"a": 10}, version=f"v{VERSION}", wait=True, overwrite_cache=True, envs={"foo": "bar"} + t1, + inputs={"a": 10}, + version=f"v{VERSION}", + wait=True, + overwrite_cache=True, + envs={"foo": "bar"}, + tags=["flyte"], + cluster_pool="gpu", ) assert execution.outputs["t1_int_output"] == 12 assert execution.outputs["c"] == "world" + assert execution.spec.envs == {"foo": "bar"} + assert execution.spec.tags == ["flyte"] + assert execution.spec.cluster_assignment.cluster_pool == "gpu" def test_execute_python_workflow_and_launch_plan(flyteclient, flyte_workflows_register, flyte_remote_env): diff --git a/tests/flytekit/unit/cli/pyflyte/default_arguments/dataclass_wf.py b/tests/flytekit/unit/cli/pyflyte/default_arguments/dataclass_wf.py index d9ba207cf2..a88cfb93ea 100644 --- a/tests/flytekit/unit/cli/pyflyte/default_arguments/dataclass_wf.py +++ b/tests/flytekit/unit/cli/pyflyte/default_arguments/dataclass_wf.py @@ -1,13 +1,12 @@ from dataclasses import dataclass -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin from flytekit import task, workflow -@dataclass_json @dataclass -class DataclassA: +class DataclassA(DataClassJsonMixin): a: str b: int diff --git a/tests/flytekit/unit/cli/pyflyte/test_package.py b/tests/flytekit/unit/cli/pyflyte/test_package.py index 4d8251fc57..d3981dec72 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_package.py +++ b/tests/flytekit/unit/cli/pyflyte/test_package.py @@ -4,6 +4,7 @@ from click.testing import CliRunner import flytekit +import flytekit.clis.sdk_in_container.utils import flytekit.configuration import flytekit.tools.serialize_helpers from flytekit import TaskMetadata @@ -120,7 +121,7 @@ def test_package(): def test_pkgs(): - pp = pyflyte.validate_package(None, None, ["a.b", "a.c,b.a", "cc.a"]) + pp = 
flytekit.clis.sdk_in_container.utils.validate_package(None, None, ["a.b", "a.c,b.a", "cc.a"]) assert pp == ["a.b", "a.c", "b.a", "cc.a"] diff --git a/tests/flytekit/unit/cli/pyflyte/test_register.py b/tests/flytekit/unit/cli/pyflyte/test_register.py index 0a371b76d1..63e98e5302 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_register.py +++ b/tests/flytekit/unit/cli/pyflyte/test_register.py @@ -7,7 +7,7 @@ from flytekit.clients.friendly import SynchronousFlyteClient from flytekit.clis.sdk_in_container import pyflyte -from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context +from flytekit.clis.sdk_in_container.helpers import get_and_save_remote_with_click_context, get_remote from flytekit.configuration import Config from flytekit.core import context_manager from flytekit.remote.remote import FlyteRemote @@ -28,6 +28,22 @@ def my_workflow(x: int, y: int) -> int: return sum(x=square(z=x), y=square(z=y)) """ +shell_task = """ +from flytekit.extras.tasks.shell import ShellTask + +t = ShellTask( + name="test", + script="echo 'Hello World'", + ) +""" + + +@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote") +def test_get_remote(mock_remote): + r = get_remote(None, "p", "d") + assert r is not None + mock_remote.assert_called_once_with(Config.for_sandbox(), default_project="p", default_domain="d") + @mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote") def test_saving_remote(mock_remote): @@ -69,6 +85,26 @@ def test_register_with_no_output_dir_passed(mock_client, mock_remote): shutil.rmtree("core1") +@mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) +@mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) +def test_register_shell_task(mock_client, mock_remote): + mock_remote._client = mock_client + mock_remote.return_value._version_from_hash.return_value = "dummy_version_from_hash" + mock_remote.return_value.fast_package.return_value = "dummy_md5_bytes", "dummy_native_url" + runner = CliRunner() + context_manager.FlyteEntities.entities.clear() + with runner.isolated_filesystem(): + out = subprocess.run(["git", "init"], capture_output=True) + assert out.returncode == 0 + os.makedirs("core2", exist_ok=True) + with open(os.path.join("core2", "shell_task.py"), "w") as f: + f.write(shell_task) + f.close() + result = runner.invoke(pyflyte.main, ["register", "core2"]) + assert "Successfully registered 2 entities" in result.output + shutil.rmtree("core2") + + @mock.patch("flytekit.clis.sdk_in_container.helpers.FlyteRemote", spec=FlyteRemote) @mock.patch("flytekit.clients.friendly.SynchronousFlyteClient", spec=SynchronousFlyteClient) def test_non_fast_register(mock_client, mock_remote): diff --git a/tests/flytekit/unit/cli/pyflyte/test_run.py b/tests/flytekit/unit/cli/pyflyte/test_run.py index f7cf2f4662..ac5654a0ce 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_run.py +++ b/tests/flytekit/unit/cli/pyflyte/test_run.py @@ -1,39 +1,17 @@ -import functools import json import os import pathlib import sys -import tempfile -import typing -from datetime import datetime, timedelta -from enum import Enum -import click import mock import pytest -import yaml from click.testing import CliRunner -from flytekit import FlyteContextManager from flytekit.clis.sdk_in_container import pyflyte -from flytekit.clis.sdk_in_container.constants import CTX_CONFIG_FILE -from flytekit.clis.sdk_in_container.helpers import FLYTE_REMOTE_INSTANCE_KEY -from flytekit.clis.sdk_in_container.run import ( - 
REMOTE_FLAG_KEY, - RUN_LEVEL_PARAMS_KEY, - DateTimeType, - DurationParamType, - FileParamType, - FlyteLiteralConverter, - JsonParamType, - get_entities_in_file, - run_command, -) +from flytekit.clis.sdk_in_container.run import RunLevelParams, get_entities_in_file, run_command from flytekit.configuration import Config, Image, ImageConfig from flytekit.core.task import task -from flytekit.core.type_engine import TypeEngine from flytekit.image_spec.image_spec import ImageBuildEngine, ImageSpecBuilder -from flytekit.models.types import SimpleType from flytekit.remote import FlyteRemote WORKFLOW_FILE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "workflow.py") @@ -51,7 +29,7 @@ def remote(): def test_pyflyte_run_wf(remote): - with mock.patch("flytekit.clis.sdk_in_container.helpers.get_and_save_remote_with_click_context"): + with mock.patch("flytekit.clis.sdk_in_container.helpers.get_remote"): runner = CliRunner() module_path = WORKFLOW_FILE result = runner.invoke(pyflyte.main, ["run", module_path, "my_wf", "--help"], catch_exceptions=False) @@ -134,7 +112,6 @@ def test_pyflyte_run_cli(): ], catch_exceptions=False, ) - print(result.stdout) assert result.exit_code == 0 @@ -155,7 +132,6 @@ def test_union_type1(input): ], catch_exceptions=False, ) - print(result.stdout) assert result.exit_code == 0 @@ -165,13 +141,25 @@ def test_union_type1(input): ) def test_union_type2(input): runner = CliRunner() - env = '{"foo": "bar"}' + env = "foo=bar" result = runner.invoke( pyflyte.main, - ["run", "--overwrite-cache", "--envs", env, os.path.join(DIR_NAME, "workflow.py"), "test_union2", "--a", input], + [ + "run", + "--overwrite-cache", + "--envvars", + env, + "--tag", + "flyte", + "--tag", + "hello", + os.path.join(DIR_NAME, "workflow.py"), + "test_union2", + "--a", + input, + ], catch_exceptions=False, ) - print(result.stdout) assert result.exit_code == 0 @@ -194,9 +182,18 @@ def test_union_type_with_invalid_input(): def test_get_entities_in_file(): e = get_entities_in_file(WORKFLOW_FILE, False) - assert e.workflows == ["my_wf"] - assert e.tasks == ["get_subset_df", "print_all", "show_sd", "test_union1", "test_union2"] - assert e.all() == ["my_wf", "get_subset_df", "print_all", "show_sd", "test_union1", "test_union2"] + assert e.workflows == ["my_wf", "wf_with_none"] + assert e.tasks == ["get_subset_df", "print_all", "show_sd", "task_with_optional", "test_union1", "test_union2"] + assert e.all() == [ + "my_wf", + "wf_with_none", + "get_subset_df", + "print_all", + "show_sd", + "task_with_optional", + "test_union1", + "test_union2", + ] @pytest.mark.parametrize( @@ -221,11 +218,11 @@ def test_nested_workflow(working_dir, wf_path, monkeypatch: pytest.MonkeyPatch): wf_path, "wf_id", "--m", - "wow", + "Running Execution on local.", ], catch_exceptions=False, ) - assert result.stdout.strip() == "wow" + assert result.stdout.strip() == "Running Execution on local.\nRunning Execution on local." 
assert result.exit_code == 0 @@ -245,7 +242,6 @@ def test_list_default_arguments(wf_path): ], catch_exceptions=False, ) - print(result.stdout) assert result.exit_code == 0 @@ -297,7 +293,7 @@ def test_list_default_arguments(wf_path): ], ) @pytest.mark.skipif( - os.environ["GITHUB_ACTIONS"] == "true" and sys.platform == "darwin", + os.environ.get("GITHUB_ACTIONS") == "true" and sys.platform == "darwin", reason="Github macos-latest image does not have docker installed as per https://github.com/orgs/community/discussions/25777", ) def test_pyflyte_run_run(mock_image, image_string, leaf_configuration_file_name, final_image_config): @@ -310,7 +306,7 @@ def build_image(self, img): ImageBuildEngine.register("test", TestImageSpecBuilder()) @task - def a(): + def tk(): ... mock_click_ctx = mock.MagicMock() @@ -318,22 +314,18 @@ def a(): image_tuple = (image_string,) image_config = ImageConfig.validate_image(None, "", image_tuple) - run_level_params = { - "project": "p", - "domain": "d", - "image_config": image_config, - } - pp = pathlib.Path.joinpath( pathlib.Path(__file__).parent.parent.parent, "configuration/configs/", leaf_configuration_file_name ) - obj = { - RUN_LEVEL_PARAMS_KEY: run_level_params, - REMOTE_FLAG_KEY: True, - FLYTE_REMOTE_INSTANCE_KEY: mock_remote, - CTX_CONFIG_FILE: str(pp), - } + obj = RunLevelParams( + project="p", + domain="d", + image_config=image_config, + remote=True, + config_file=str(pp), + ) + obj._remote = mock_remote mock_click_ctx.obj = obj def check_image(*args, **kwargs): @@ -341,115 +333,27 @@ def check_image(*args, **kwargs): mock_remote.register_script.side_effect = check_image - run_command(mock_click_ctx, a)() - + run_command(mock_click_ctx, tk)() -def test_file_param(): - m = mock.MagicMock() - l = FileParamType().convert(__file__, m, m) - assert l.local - r = FileParamType().convert("https://tmp/file", m, m) - assert r.local is False - -class Color(Enum): - RED = "red" - GREEN = "green" - BLUE = "blue" - - -@pytest.mark.parametrize( - "python_type, python_value", - [ - (typing.Union[typing.List[int], str, Color], "flyte"), - (typing.Union[typing.List[int], str, Color], "red"), - (typing.Union[typing.List[int], str, Color], [1, 2, 3]), - (typing.List[int], [1, 2, 3]), - (typing.Dict[str, int], {"flyte": 2}), - ], -) -def test_literal_converter(python_type, python_value): - get_upload_url_fn = functools.partial( - FlyteRemote(Config.auto()).client.get_upload_signed_url, project="p", domain="d" - ) - click_ctx = click.Context(click.Command("test_command"), obj={"remote": True}) - ctx = FlyteContextManager.current_context() - lt = TypeEngine.to_literal_type(python_type) - - lc = FlyteLiteralConverter( - click_ctx, ctx, literal_type=lt, python_type=python_type, get_upload_url_fn=get_upload_url_fn +@pytest.mark.parametrize("a_val", ["foo", "1", None]) +def test_pyflyte_run_with_none(a_val): + runner = CliRunner() + args = [ + "run", + WORKFLOW_FILE, + "wf_with_none", + ] + if a_val is not None: + args.extend(["--a", a_val]) + result = runner.invoke( + pyflyte.main, + args, + catch_exceptions=False, ) - - assert lc.convert(click_ctx, ctx, python_value) == TypeEngine.to_literal(ctx, python_value, python_type, lt) - - -def test_enum_converter(): - pt = typing.Union[str, Color] - - get_upload_url_fn = functools.partial(FlyteRemote(Config.auto()).client.get_upload_signed_url) - click_ctx = click.Context(click.Command("test_command"), obj={"remote": True}) - ctx = FlyteContextManager.current_context() - lt = TypeEngine.to_literal_type(pt) - lc = 
FlyteLiteralConverter(click_ctx, ctx, literal_type=lt, python_type=pt, get_upload_url_fn=get_upload_url_fn) - union_lt = lc.convert(click_ctx, ctx, "red").scalar.union - - assert union_lt.stored_type.simple == SimpleType.STRING - assert union_lt.stored_type.enum_type is None - - pt = typing.Union[Color, str] - lt = TypeEngine.to_literal_type(typing.Union[Color, str]) - lc = FlyteLiteralConverter(click_ctx, ctx, literal_type=lt, python_type=pt, get_upload_url_fn=get_upload_url_fn) - union_lt = lc.convert(click_ctx, ctx, "red").scalar.union - - assert union_lt.stored_type.simple is None - assert union_lt.stored_type.enum_type.values == ["red", "green", "blue"] - - -def test_duration_type(): - t = DurationParamType() - assert t.convert(value="1 day", param=None, ctx=None) == timedelta(days=1) - - with pytest.raises(click.BadParameter): - t.convert(None, None, None) - - -def test_datetime_type(): - t = DateTimeType() - - assert t.convert("2020-01-01", None, None) == datetime(2020, 1, 1) - - now = datetime.now() - v = t.convert("now", None, None) - assert v.day == now.day - assert v.month == now.month - - -def test_json_type(): - t = JsonParamType() - assert t.convert(value='{"a": "b"}', param=None, ctx=None) == {"a": "b"} - - with pytest.raises(click.BadParameter): - t.convert(None, None, None) - - # test that it loads a json file - with tempfile.NamedTemporaryFile("w", delete=False) as f: - json.dump({"a": "b"}, f) - f.flush() - assert t.convert(value=f.name, param=None, ctx=None) == {"a": "b"} - - # test that if the file is not a valid json, it raises an error - with tempfile.NamedTemporaryFile("w", delete=False) as f: - f.write("asdf") - f.flush() - with pytest.raises(click.BadParameter): - t.convert(value=f.name, param="asdf", ctx=None) - - # test if the file does not exist - with pytest.raises(click.BadParameter): - t.convert(value="asdf", param=None, ctx=None) - - # test if the file is yaml and ends with .yaml it works correctly - with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f: - yaml.dump({"a": "b"}, f) - f.flush() - assert t.convert(value=f.name, param=None, ctx=None) == {"a": "b"} + output = result.stdout.strip().split("\n")[-1].strip() + if a_val is None: + assert output == "default" + else: + assert output == a_val + assert result.exit_code == 0 diff --git a/tests/flytekit/unit/cli/pyflyte/workflow.py b/tests/flytekit/unit/cli/pyflyte/workflow.py index 311f141a22..0b6ba98540 100644 --- a/tests/flytekit/unit/cli/pyflyte/workflow.py +++ b/tests/flytekit/unit/cli/pyflyte/workflow.py @@ -4,7 +4,7 @@ from dataclasses import dataclass import pandas as pd -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin from typing_extensions import Annotated from flytekit import kwtypes, task, workflow @@ -28,9 +28,8 @@ def show_sd(in_sd: StructuredDataset): print(df) -@dataclass_json @dataclass -class MyDataclass(object): +class MyDataclass(DataClassJsonMixin): i: int a: typing.List[str] @@ -99,3 +98,13 @@ def my_wf( show_sd(in_sd=image) print_all(a=a, b=b, c=c, d=d, e=e, f=f, g=g, h=h, i=i, j=j, k=k, l=l, m=m, n=n, o=o, p=p) return x + + +@task +def task_with_optional(a: typing.Optional[str]) -> str: + return "default" if a is None else a + + +@workflow +def wf_with_none(a: typing.Optional[str] = None) -> str: + return task_with_optional(a=a) diff --git a/tests/flytekit/unit/clients/auth/test_authenticator.py b/tests/flytekit/unit/clients/auth/test_authenticator.py index 32709e1eaa..82ffa654dd 100644 --- 
a/tests/flytekit/unit/clients/auth/test_authenticator.py +++ b/tests/flytekit/unit/clients/auth/test_authenticator.py @@ -67,8 +67,15 @@ def test_command_authenticator(mock_subprocess: MagicMock): authn.refresh_credentials() -@patch("flytekit.clients.auth.token_client.requests") -def test_client_creds_authenticator(mock_requests): +@patch("flytekit.clients.auth.token_client.requests.Session") +def test_client_creds_authenticator(mock_session): + session = MagicMock() + response = MagicMock() + response.status_code = 200 + response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") + session.post.return_value = response + mock_session.return_value = session + authn = ClientCredentialsAuthenticator( ENDPOINT, client_id="client", @@ -77,13 +84,11 @@ def test_client_creds_authenticator(mock_requests): http_proxy_url="https://my-proxy:31111", ) - response = MagicMock() - response.status_code = 200 - response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") - mock_requests.post.return_value = response authn.refresh_credentials() expected_scopes = static_cfg_store.get_client_config().scopes + assert authn._creds + assert authn._creds.access_token == "abc" assert authn._scopes == expected_scopes @@ -113,9 +118,17 @@ def test_device_flow_authenticator(poll_mock: MagicMock, device_mock: MagicMock, assert authn._creds -@patch("flytekit.clients.auth.token_client.requests") -def test_client_creds_authenticator_with_custom_scopes(mock_requests): +@patch("flytekit.clients.auth.token_client.requests.Session") +def test_client_creds_authenticator_with_custom_scopes(mock_session): expected_scopes = ["foo", "baz"] + + session = MagicMock() + response = MagicMock() + response.status_code = 200 + response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") + session.post.return_value = response + mock_session.return_value = session + authn = ClientCredentialsAuthenticator( ENDPOINT, client_id="client", @@ -124,11 +137,9 @@ def test_client_creds_authenticator_with_custom_scopes(mock_requests): scopes=expected_scopes, verify=True, ) - response = MagicMock() - response.status_code = 200 - response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") - mock_requests.post.return_value = response + authn.refresh_credentials() assert authn._creds + assert authn._creds.access_token == "abc" assert authn._scopes == expected_scopes diff --git a/tests/flytekit/unit/clients/auth/test_token_client.py b/tests/flytekit/unit/clients/auth/test_token_client.py index d0e75ec88a..a7e9c9c280 100644 --- a/tests/flytekit/unit/clients/auth/test_token_client.py +++ b/tests/flytekit/unit/clients/auth/test_token_client.py @@ -22,12 +22,14 @@ def test_get_basic_authorization_header(): assert header == "Basic Y2xpZW50X2lkOmFiYyUyNSUyNSUyNCUzRiU1QyUyRiU1QyUyRg==" -@patch("flytekit.clients.auth.token_client.requests") -def test_get_token(mock_requests): +@patch("flytekit.clients.auth.token_client.requests.Session") +def test_get_token(mock_session): + session = MagicMock() response = MagicMock() response.status_code = 200 response.json.return_value = json.loads("""{"access_token": "abc", "expires_in": 60}""") - mock_requests.post.return_value = response + session.post.return_value = response + mock_session.return_value = session access, expiration = get_token( "https://corp.idp.net", client_id="abc123", scopes=["my_scope"], http_proxy_url="http://proxy:3000", verify=True ) @@ -35,11 +37,13 @@ def test_get_token(mock_requests): 
assert expiration == 60 -@patch("flytekit.clients.auth.token_client.requests") -def test_get_device_code(mock_requests): +@patch("flytekit.clients.auth.token_client.requests.Session") +def test_get_device_code(mock_session): + session = MagicMock() response = MagicMock() response.ok = False - mock_requests.post.return_value = response + session.post.return_value = response + mock_session.return_value = session with pytest.raises(AuthenticationError): get_device_code("test.com", "test", http_proxy_url="http://proxy:3000") @@ -51,18 +55,21 @@ def test_get_device_code(mock_requests): "expires_in": 600, "interval": 5, } - mock_requests.post.return_value = response + session.post.return_value = response c = get_device_code("test.com", "test", http_proxy_url="http://proxy:3000") assert c assert c.device_code == "code" -@patch("flytekit.clients.auth.token_client.requests") -def test_poll_token_endpoint(mock_requests): +@patch("flytekit.clients.auth.token_client.requests.Session") +def test_poll_token_endpoint(mock_session): + session = MagicMock() response = MagicMock() response.ok = False response.json.return_value = {"error": error_auth_pending} - mock_requests.post.return_value = response + + session.post.return_value = response + mock_session.return_value = session r = DeviceCodeResponse(device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=1) with pytest.raises(AuthenticationError): @@ -71,8 +78,9 @@ def test_poll_token_endpoint(mock_requests): response = MagicMock() response.ok = True response.json.return_value = {"access_token": "abc", "expires_in": 60} - mock_requests.post.return_value = response + session.post.return_value = response r = DeviceCodeResponse(device_code="x", user_code="y", verification_uri="v", expires_in=1, interval=0) t, e = poll_token_endpoint(r, "test.com", "test", http_proxy_url="http://proxy:3000", verify=True) - assert t - assert e + + assert t == "abc" + assert e == 60 diff --git a/tests/flytekit/unit/clients/test_auth_helper.py b/tests/flytekit/unit/clients/test_auth_helper.py index 3bd57918f4..9578f81b3e 100644 --- a/tests/flytekit/unit/clients/test_auth_helper.py +++ b/tests/flytekit/unit/clients/test_auth_helper.py @@ -1,7 +1,9 @@ import os.path +from http import HTTPStatus from unittest.mock import MagicMock, patch import pytest +import requests from flyteidl.service.auth_pb2 import OAuth2MetadataResponse, PublicClientAuthConfigResponse from flytekit.clients.auth.authenticator import ( @@ -16,8 +18,10 @@ from flytekit.clients.auth_helper import ( RemoteClientConfigStore, get_authenticator, + get_session, load_cert, upgrade_channel_to_authenticated, + upgrade_channel_to_proxy_authenticated, wrap_exceptions_channel, ) from flytekit.clients.grpc_utils.auth_interceptor import AuthUnaryInterceptor @@ -76,7 +80,7 @@ def get_client_config(**kwargs) -> ClientConfigStore: authorization_endpoint=OAUTH_AUTHORIZE, redirect_uri=REDIRECT_URI, client_id=CLIENT_ID, - **kwargs + **kwargs, ) return cfg_store @@ -160,8 +164,45 @@ def test_upgrade_channel_to_auth(): assert isinstance(out_ch._interceptor, AuthUnaryInterceptor) # noqa +def test_upgrade_channel_to_proxy_auth(): + ch = MagicMock() + out_ch = upgrade_channel_to_proxy_authenticated( + PlatformConfig( + auth_mode="Pkce", + proxy_command=["echo", "foo-bar"], + ), + ch, + ) + assert isinstance(out_ch._interceptor, AuthUnaryInterceptor) + assert isinstance(out_ch._interceptor._authenticator, CommandAuthenticator) + + def test_load_cert(): cert_file = os.path.join(os.path.dirname(__file__), 
"testdata", "rootCACert.pem") f = load_cert(cert_file) assert f print(f) + + +def test_get_proxy_authenticated_session(): + """Test that proxy auth headers are added to http requests if the proxy command is provided in the platform config.""" + expected_token = "foo-bar" + platform_config = PlatformConfig( + endpoint="http://my-flyte-deployment.com", + proxy_command=["echo", expected_token], + ) + + with patch("requests.adapters.HTTPAdapter.send") as mock_send: + mock_response = requests.Response() + mock_response.status_code = HTTPStatus.UNAUTHORIZED + mock_response._content = b"{}" + mock_send.return_value = mock_response + + session = get_session(platform_config) + request = requests.Request("GET", platform_config.endpoint) + prepared_request = session.prepare_request(request) + + # Send the request to trigger the addition of the proxy auth headers + session.send(prepared_request) + + assert prepared_request.headers["proxy-authorization"] == f"Bearer {expected_token}" diff --git a/tests/flytekit/unit/configuration/configs/creds_secret_env_var.yaml b/tests/flytekit/unit/configuration/configs/creds_secret_env_var.yaml new file mode 100644 index 0000000000..e0d4748460 --- /dev/null +++ b/tests/flytekit/unit/configuration/configs/creds_secret_env_var.yaml @@ -0,0 +1,13 @@ +admin: + # For GRPC endpoints you might want to use dns:///flyte.myexample.com + endpoint: dns:///flyte.mycorp.io + clientSecretEnvVar: FAKE_SECRET_NAME + insecure: true + clientId: propeller + scopes: + - all +storage: + connection: + access-key: minio + endpoint: http://localhost:30084 + secret-key: miniostorage diff --git a/tests/flytekit/unit/configuration/test_internal.py b/tests/flytekit/unit/configuration/test_internal.py index 97e30b5612..5c1da14a5b 100644 --- a/tests/flytekit/unit/configuration/test_internal.py +++ b/tests/flytekit/unit/configuration/test_internal.py @@ -2,7 +2,7 @@ import mock -from flytekit.configuration import PlatformConfig, get_config_file, read_file_if_exists +from flytekit.configuration import AuthType, PlatformConfig, get_config_file, read_file_if_exists from flytekit.configuration.internal import AWS, Credentials, Images @@ -45,6 +45,25 @@ def test_client_secret_location(): # Assert that secret in platform config does not contain a newline platform_cfg = PlatformConfig.auto(cfg) assert platform_cfg.client_credentials_secret == "hello" + assert platform_cfg.auth_mode == AuthType.CLIENTSECRET.value + + +@mock.patch.dict("os.environ") +def test_client_secret_env_var(): + cfg = get_config_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/sample.yaml")) + secret_env_var = Credentials.CLIENT_CREDENTIALS_SECRET_ENV_VAR.read(cfg) + assert secret_env_var is None + + cfg = get_config_file( + os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs/creds_secret_env_var.yaml") + ) + secret_env_var = Credentials.CLIENT_CREDENTIALS_SECRET_ENV_VAR.read(cfg) + assert secret_env_var == "FAKE_SECRET_NAME" + + os.environ["FAKE_SECRET_NAME"] = "fake_secret_value" + platform_cfg = PlatformConfig.auto(cfg) + assert platform_cfg.client_credentials_secret == "fake_secret_value" + assert platform_cfg.auth_mode == AuthType.CLIENTSECRET.value def test_read_file_if_exists(): diff --git a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py new file mode 100644 index 0000000000..2de15667d5 --- /dev/null +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -0,0 +1,232 @@ +import functools +from collections import 
OrderedDict +from typing import List + +import pytest + +from flytekit import task, workflow +from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings +from flytekit.core.array_node_map_task import ArrayNodeMapTask +from flytekit.core.task import TaskMetadata +from flytekit.experimental import map_task as array_node_map_task +from flytekit.tools.translator import get_serializable + + +@pytest.fixture +def serialization_settings(): + default_img = Image(name="default", fqn="test", tag="tag") + return SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + + +def test_map(serialization_settings): + @task + def say_hello(name: str) -> str: + return f"hello {name}!" + + @workflow + def wf() -> List[str]: + return array_node_map_task(say_hello)(name=["abc", "def"]) + + res = wf() + assert res is not None + + +def test_execution(serialization_settings): + @task + def say_hello(name: str) -> str: + return f"hello {name}!" + + @task + def create_input_list() -> List[str]: + return ["earth", "mars"] + + @workflow + def wf() -> List[str]: + xs = array_node_map_task(say_hello)(name=create_input_list()) + return array_node_map_task(say_hello)(name=xs) + + assert wf() == ["hello hello earth!!", "hello hello mars!!"] + + +def test_serialization(serialization_settings): + @task + def t1(a: int) -> int: + return a + 1 + + arraynode_maptask = array_node_map_task(t1, metadata=TaskMetadata(retries=2)) + task_spec = get_serializable(OrderedDict(), serialization_settings, arraynode_maptask) + + assert task_spec.template.metadata.retries.retries == 2 + assert task_spec.template.custom["minSuccessRatio"] == 1.0 + assert task_spec.template.type == "python-task" + assert task_spec.template.task_type_version == 1 + assert task_spec.template.container.args == [ + "pyflyte-map-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", + "--experimental", + "--resolver", + "ArrayNodeMapTaskResolver", + "--", + "vars", + "", + "resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "task-module", + "tests.flytekit.unit.core.test_array_node_map_task", + "task-name", + "t1", + ] + + +def test_fast_serialization(serialization_settings): + serialization_settings.fast_serialization_settings = FastSerializationSettings(enabled=True) + + @task + def t1(a: int) -> int: + return a + 1 + + arraynode_maptask = array_node_map_task(t1, metadata=TaskMetadata(retries=2)) + task_spec = get_serializable(OrderedDict(), serialization_settings, arraynode_maptask) + + assert task_spec.template.container.args == [ + "pyflyte-fast-execute", + "--additional-distribution", + "{{ .remote_package_path }}", + "--dest-dir", + "{{ .dest_dir }}", + "--", + "pyflyte-map-execute", + "--inputs", + "{{.input}}", + "--output-prefix", + "{{.outputPrefix}}", + "--raw-output-data-prefix", + "{{.rawOutputDataPrefix}}", + "--checkpoint-path", + "{{.checkpointOutputPrefix}}", + "--prev-checkpoint", + "{{.prevCheckpointPrefix}}", + "--experimental", + "--resolver", + "ArrayNodeMapTaskResolver", + "--", + "vars", + "", + "resolver", + "flytekit.core.python_auto_container.default_task_resolver", + "task-module", + "tests.flytekit.unit.core.test_array_node_map_task", + "task-name", + 
"t1", + ] + + +@pytest.mark.parametrize( + "kwargs1, kwargs2, same", + [ + ({}, {}, True), + ({}, {"concurrency": 2}, False), + ({}, {"min_successes": 3}, False), + ({}, {"min_success_ratio": 0.2}, False), + ({}, {"concurrency": 10, "min_successes": 999, "min_success_ratio": 0.2}, False), + ({"concurrency": 1}, {"concurrency": 2}, False), + ({"concurrency": 42}, {"concurrency": 42}, True), + ({"min_successes": 1}, {"min_successes": 2}, False), + ({"min_successes": 42}, {"min_successes": 42}, True), + ({"min_success_ratio": 0.1}, {"min_success_ratio": 0.2}, False), + ({"min_success_ratio": 0.42}, {"min_success_ratio": 0.42}, True), + ({"min_success_ratio": 0.42}, {"min_success_ratio": 0.42}, True), + ( + { + "concurrency": 1, + "min_successes": 2, + "min_success_ratio": 0.42, + }, + { + "concurrency": 1, + "min_successes": 2, + "min_success_ratio": 0.99, + }, + False, + ), + ], +) +def test_metadata_in_task_name(kwargs1, kwargs2, same): + @task + def say_hello(name: str) -> str: + return f"hello {name}!" + + t1 = array_node_map_task(say_hello, **kwargs1) + t2 = array_node_map_task(say_hello, **kwargs2) + + assert (t1.name == t2.name) is same + + +def test_inputs_outputs_length(): + @task + def many_inputs(a: int, b: str, c: float) -> str: + return f"{a} - {b} - {c}" + + m = array_node_map_task(many_inputs) + assert m.python_interface.inputs == {"a": List[int], "b": List[str], "c": List[float]} + assert ( + m.name + == "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_4ee240ef5cf979dbc133fb30035cb874-arraynode" + ) + r_m = ArrayNodeMapTask(many_inputs) + assert str(r_m.python_interface) == str(m.python_interface) + + p1 = functools.partial(many_inputs, c=1.0) + m = array_node_map_task(p1) + assert m.python_interface.inputs == {"a": List[int], "b": List[str], "c": float} + assert ( + m.name + == "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_352fcdea8523a83134b51bbf5793f14e-arraynode" + ) + r_m = ArrayNodeMapTask(many_inputs, bound_inputs=set("c")) + assert str(r_m.python_interface) == str(m.python_interface) + + p2 = functools.partial(p1, b="hello") + m = array_node_map_task(p2) + assert m.python_interface.inputs == {"a": List[int], "b": str, "c": float} + assert ( + m.name + == "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_e224ba3a5b00e08083d541a6ca99b179-arraynode" + ) + r_m = ArrayNodeMapTask(many_inputs, bound_inputs={"c", "b"}) + assert str(r_m.python_interface) == str(m.python_interface) + + p3 = functools.partial(p2, a=1) + m = array_node_map_task(p3) + assert m.python_interface.inputs == {"a": int, "b": str, "c": float} + assert ( + m.name + == "tests.flytekit.unit.core.test_array_node_map_task.map_many_inputs_f080e60be9d6faedeef0c74834d6812a-arraynode" + ) + r_m = ArrayNodeMapTask(many_inputs, bound_inputs={"a", "c", "b"}) + assert str(r_m.python_interface) == str(m.python_interface) + + with pytest.raises(TypeError): + m(a=[1, 2, 3]) + + @task + def many_outputs(a: int) -> (int, str): + return a, f"{a}" + + with pytest.raises(ValueError): + _ = array_node_map_task(many_outputs) diff --git a/tests/flytekit/unit/core/test_complex_nesting.py b/tests/flytekit/unit/core/test_complex_nesting.py index c8c643cc67..7534197f98 100644 --- a/tests/flytekit/unit/core/test_complex_nesting.py +++ b/tests/flytekit/unit/core/test_complex_nesting.py @@ -4,7 +4,7 @@ from typing import List import pytest -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin from flytekit.configuration import Image, 
ImageConfig, SerializationSettings from flytekit.core.context_manager import ExecutionState, FlyteContextManager @@ -14,31 +14,27 @@ from flytekit.types.file import FlyteFile -@dataclass_json @dataclass -class MyProxyConfiguration: +class MyProxyConfiguration(DataClassJsonMixin): # File and directory paths kept as 'str' so Flyte doesn't manage these static resources splat_data_dir: str apriori_file: str -@dataclass_json @dataclass -class MyProxyParameters: +class MyProxyParameters(DataClassJsonMixin): id: str job_i_step: int -@dataclass_json @dataclass -class MyAprioriConfiguration: +class MyAprioriConfiguration(DataClassJsonMixin): static_data_dir: FlyteDirectory external_data_dir: FlyteDirectory -@dataclass_json @dataclass -class MyInput: +class MyInput(DataClassJsonMixin): main_product: FlyteFile apriori_config: MyAprioriConfiguration proxy_config: MyProxyConfiguration diff --git a/tests/flytekit/unit/core/test_conditions.py b/tests/flytekit/unit/core/test_conditions.py index 7b0b292baa..0da2467109 100644 --- a/tests/flytekit/unit/core/test_conditions.py +++ b/tests/flytekit/unit/core/test_conditions.py @@ -194,6 +194,25 @@ def decompose() -> int: assert decompose() == 20 +def test_condition_is_none(): + @task + def return_true() -> typing.Optional[None]: + return None + + @workflow + def failed() -> int: + return 10 + + @workflow + def success() -> int: + return 20 + + @workflow + def decompose_unary() -> int: + result = return_true() + return conditional("test").if_(result.is_none()).then(success()).else_().then(failed()) + + def test_subworkflow_condition_serialization(): """Test that subworkflows are correctly extracted from serialized workflows with conditionals.""" diff --git a/tests/flytekit/unit/core/test_context_manager.py b/tests/flytekit/unit/core/test_context_manager.py index b7932663a5..2ec7eb8e19 100644 --- a/tests/flytekit/unit/core/test_context_manager.py +++ b/tests/flytekit/unit/core/test_context_manager.py @@ -1,3 +1,4 @@ +import base64 import os from datetime import datetime @@ -166,6 +167,15 @@ def test_secrets_manager_file(tmpdir: py.path.local): w.write("my-password") assert sec.get("group", "test") == "my-password" assert sec.group.test == "my-password" + + base64_string = "R2Vla3NGb3JHZWV ==" + base64_bytes = base64_string.encode("ascii") + base64_str = base64.b64encode(base64_bytes) + with open(f, "wb") as w: + w.write(base64_str) + assert sec.get("group", "test") != base64_str + assert sec.get("group", "test", encode_mode="rb") == base64_str + del os.environ["FLYTE_SECRETS_DEFAULT_DIR"] diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py index 2d61b58d8c..667445321b 100644 --- a/tests/flytekit/unit/core/test_data.py +++ b/tests/flytekit/unit/core/test_data.py @@ -8,9 +8,14 @@ import mock import pytest -from flytekit.configuration import Config, S3Config +from flytekit.configuration import Config, DataConfig, S3Config from flytekit.core.context_manager import FlyteContextManager -from flytekit.core.data_persistence import FileAccessProvider, default_local_file_access_provider, s3_setup_args +from flytekit.core.data_persistence import ( + FileAccessProvider, + default_local_file_access_provider, + get_fsspec_storage_options, + s3_setup_args, +) from flytekit.types.directory.types import FlyteDirectory local = fsspec.filesystem("file") @@ -221,6 +226,73 @@ def test_s3_setup_args_env_aws(mock_os, mock_get_config_file): assert kwargs == {"cache_regions": True} + +@mock.patch("flytekit.configuration.get_config_file") 
+@mock.patch("os.environ") +def test_get_fsspec_storage_options_gcs(mock_os, mock_get_config_file): + mock_get_config_file.return_value = None + ee = { + "FLYTE_GCP_GSUTIL_PARALLELISM": "False", + } + mock_os.get.side_effect = lambda x, y: ee.get(x) + storage_options = get_fsspec_storage_options("gs", DataConfig.auto()) + assert storage_options == {} + + +@mock.patch("flytekit.configuration.get_config_file") +@mock.patch("os.environ") +def test_get_fsspec_storage_options_gcs_with_overrides(mock_os, mock_get_config_file): + mock_get_config_file.return_value = None + ee = { + "FLYTE_GCP_GSUTIL_PARALLELISM": "False", + } + mock_os.get.side_effect = lambda x, y: ee.get(x) + storage_options = get_fsspec_storage_options("gs", DataConfig.auto(), anonymous=True, other_argument="value") + assert storage_options == {"token": "anon", "other_argument": "value"} + + +@mock.patch("flytekit.configuration.get_config_file") +@mock.patch("os.environ") +def test_get_fsspec_storage_options_azure(mock_os, mock_get_config_file): + mock_get_config_file.return_value = None + ee = { + "FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname", + "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey", + "FLYTE_AZURE_TENANT_ID": "tenantid", + "FLYTE_AZURE_CLIENT_ID": "clientid", + "FLYTE_AZURE_CLIENT_SECRET": "clientsecret", + } + mock_os.get.side_effect = lambda x, y: ee.get(x) + storage_options = get_fsspec_storage_options("abfs", DataConfig.auto()) + assert storage_options == { + "account_name": "accountname", + "account_key": "accountkey", + "client_id": "clientid", + "client_secret": "clientsecret", + "tenant_id": "tenantid", + "anon": False, + } + + +@mock.patch("flytekit.configuration.get_config_file") +@mock.patch("os.environ") +def test_get_fsspec_storage_options_azure_with_overrides(mock_os, mock_get_config_file): + mock_get_config_file.return_value = None + ee = { + "FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname", + "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey", + } + mock_os.get.side_effect = lambda x, y: ee.get(x) + storage_options = get_fsspec_storage_options( + "abfs", DataConfig.auto(), anonymous=True, account_name="other_accountname", other_argument="value" + ) + assert storage_options == { + "account_name": "other_accountname", + "account_key": "accountkey", + "anon": True, + "other_argument": "value", + } + + def test_crawl_local_nt(source_folder): """ running this to see what it prints diff --git a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py index 27b407c1ce..2fc8b6c452 100644 --- a/tests/flytekit/unit/core/test_data_persistence.py +++ b/tests/flytekit/unit/core/test_data_persistence.py @@ -1,3 +1,8 @@ +import os + +import mock +from azure.identity import ClientSecretCredential, DefaultAzureCredential + from flytekit.core.data_persistence import FileAccessProvider @@ -14,3 +19,39 @@ def test_is_remote(): assert fp.is_remote("/tmp/foo/bar") is False assert fp.is_remote("file://foo/bar") is False assert fp.is_remote("s3://my-bucket/foo/bar") is True + + +def test_initialise_azure_file_provider_with_account_key(): + with mock.patch.dict( + os.environ, + {"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname", "FLYTE_AZURE_STORAGE_ACCOUNT_KEY": "accountkey"}, + ): + fp = FileAccessProvider("/tmp", "abfs://container/path/within/container") + assert fp.get_filesystem().account_name == "accountname" + assert fp.get_filesystem().account_key == "accountkey" + assert fp.get_filesystem().sync_credential is None + + +def 
test_initialise_azure_file_provider_with_service_principal(): + with mock.patch.dict( + os.environ, + { + "FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname", + "FLYTE_AZURE_CLIENT_SECRET": "clientsecret", + "FLYTE_AZURE_CLIENT_ID": "clientid", + "FLYTE_AZURE_TENANT_ID": "tenantid", + }, + ): + fp = FileAccessProvider("/tmp", "abfs://container/path/within/container") + assert fp.get_filesystem().account_name == "accountname" + assert isinstance(fp.get_filesystem().sync_credential, ClientSecretCredential) + assert fp.get_filesystem().client_secret == "clientsecret" + assert fp.get_filesystem().client_id == "clientid" + assert fp.get_filesystem().tenant_id == "tenantid" + + +def test_initialise_azure_file_provider_with_default_credential(): + with mock.patch.dict(os.environ, {"FLYTE_AZURE_STORAGE_ACCOUNT_NAME": "accountname"}): + fp = FileAccessProvider("/tmp", "abfs://container/path/within/container") + assert fp.get_filesystem().account_name == "accountname" + assert isinstance(fp.get_filesystem().sync_credential, DefaultAzureCredential) diff --git a/tests/flytekit/unit/core/test_dataclass.py b/tests/flytekit/unit/core/test_dataclass.py index db49d2312c..34350ca40b 100644 --- a/tests/flytekit/unit/core/test_dataclass.py +++ b/tests/flytekit/unit/core/test_dataclass.py @@ -1,16 +1,15 @@ from dataclasses import dataclass from typing import List -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin from flytekit.core.task import task from flytekit.core.workflow import workflow def test_dataclass(): - @dataclass_json @dataclass - class AppParams(object): + class AppParams(DataClassJsonMixin): snapshotDate: str region: str preprocess: bool diff --git a/tests/flytekit/unit/core/test_local_cache.py b/tests/flytekit/unit/core/test_local_cache.py index 2ee5e34674..1569a258f4 100644 --- a/tests/flytekit/unit/core/test_local_cache.py +++ b/tests/flytekit/unit/core/test_local_cache.py @@ -6,7 +6,7 @@ import pandas import pandas as pd import pytest -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin from pytest import fixture from typing_extensions import Annotated @@ -164,9 +164,8 @@ def my_wf() -> FlyteSchema: def test_wf_custom_types(): - @dataclass_json @dataclass - class MyCustomType(object): + class MyCustomType(DataClassJsonMixin): x: int y: str diff --git a/tests/flytekit/unit/core/test_partials.py b/tests/flytekit/unit/core/test_partials.py index 24e3908d1d..1e7b5d43ee 100644 --- a/tests/flytekit/unit/core/test_partials.py +++ b/tests/flytekit/unit/core/test_partials.py @@ -7,10 +7,12 @@ import flytekit.configuration from flytekit.configuration import Image, ImageConfig +from flytekit.core.array_node_map_task import ArrayNodeMapTaskResolver from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.map_task import MapTaskResolver, map_task from flytekit.core.task import TaskMetadata, task from flytekit.core.workflow import workflow +from flytekit.experimental import map_task as array_node_map_task from flytekit.tools.translator import gather_dependent_entities, get_serializable default_img = Image(name="default", fqn="test", tag="tag") @@ -73,8 +75,15 @@ def my_wf_2(a: int) -> int: assert len(wf_2_spec.template.nodes) == 2 -def test_map_task_types(): - @task(cache=True, cache_version="1") +@pytest.mark.parametrize( + "map_task_fn", + [ + map_task, + array_node_map_task, + ], +) +def test_map_task_types(map_task_fn): + @task def t3(a: int, b: str, c: float) -> str: return str(a) + b + str(c) @@ -83,8 
+92,8 @@ def t3(a: int, b: str, c: float) -> str: t3_bind_c1 = partial(t3_bind_b1, c=3.14) t3_bind_c2 = partial(t3_bind_b2, c=2.78) - mt1 = map_task(t3_bind_c1, metadata=TaskMetadata(cache=True, cache_version="1")) - mt2 = map_task(t3_bind_c2, metadata=TaskMetadata(cache=True, cache_version="1")) + mt1 = map_task_fn(t3_bind_c1, metadata=TaskMetadata(cache=True, cache_version="1")) + mt2 = map_task_fn(t3_bind_c2, metadata=TaskMetadata(cache=True, cache_version="1")) @task def print_lists(i: typing.List[str], j: typing.List[str]): @@ -101,8 +110,8 @@ def wf_out(a: typing.List[int]): @workflow def wf_in(a: typing.List[int]): - mt_in1 = map_task(t3_bind_c1, metadata=TaskMetadata(cache=True, cache_version="1")) - mt_in2 = map_task(t3_bind_c2, metadata=TaskMetadata(cache=True, cache_version="1")) + mt_in1 = map_task_fn(t3_bind_c1, metadata=TaskMetadata(cache=True, cache_version="1")) + mt_in2 = map_task_fn(t3_bind_c2, metadata=TaskMetadata(cache=True, cache_version="1")) i = mt_in1(a=a) j = mt_in2(a=[3, 4, 5]) print_lists(i=i, j=j) @@ -113,39 +122,62 @@ def wf_in(a: typing.List[int]): wf_spec = get_serializable(od, serialization_settings, wf_in) tts, _, _ = gather_dependent_entities(od) assert len(tts) == 2 # one map task + the print task - assert ( - wf_spec.template.nodes[0].task_node.reference_id.name == wf_spec.template.nodes[1].task_node.reference_id.name - ) + if map_task_fn == array_node_map_task: + assert ( + wf_spec.template.nodes[0].array_node.node.task_node.reference_id.name + == wf_spec.template.nodes[1].array_node.node.task_node.reference_id.name + ) + elif map_task_fn == map_task: + assert ( + wf_spec.template.nodes[0].task_node.reference_id.name + == wf_spec.template.nodes[1].task_node.reference_id.name + ) + else: + raise ValueError("Unexpected map task fn") assert wf_spec.template.nodes[0].inputs[0].binding.promise is not None # comes from wf input assert wf_spec.template.nodes[1].inputs[0].binding.collection is not None # bound to static list assert wf_spec.template.nodes[1].inputs[1].binding.scalar is not None # these are bound assert wf_spec.template.nodes[1].inputs[2].binding.scalar is not None -def test_lists_cannot_be_used_in_partials(): +@pytest.mark.parametrize( + "map_task_fn", + [ + map_task, + array_node_map_task, + ], +) +def test_lists_cannot_be_used_in_partials(map_task_fn): @task def t(a: int, b: typing.List[str]) -> str: return str(a) + str(b) with pytest.raises(ValueError): - map_task(partial(t, b=["hello", "world"]))(a=[1, 2, 3]) + map_task_fn(partial(t, b=["hello", "world"]))(a=[1, 2, 3]) @task def t_multilist(a: int, b: typing.List[float], c: typing.List[int]) -> str: return str(a) + str(b) + str(c) with pytest.raises(ValueError): - map_task(partial(t_multilist, b=[3.14, 12.34, 9876.5432], c=[42, 99]))(a=[1, 2, 3, 4]) + map_task_fn(partial(t_multilist, b=[3.14, 12.34, 9876.5432], c=[42, 99]))(a=[1, 2, 3, 4]) @task def t_list_of_lists(a: typing.List[typing.List[float]], b: int) -> str: return str(a) + str(b) with pytest.raises(ValueError): - map_task(partial(t_list_of_lists, a=[[3.14]]))(b=[1, 2, 3, 4]) + map_task_fn(partial(t_list_of_lists, a=[[3.14]]))(b=[1, 2, 3, 4]) -def test_everything(): +@pytest.mark.parametrize( + "map_task_fn", + [ + map_task, + array_node_map_task, + ], +) +def test_everything(map_task_fn): @task def get_static_list() -> typing.List[float]: return [3.14, 2.718] @@ -167,11 +199,16 @@ def t3(a: int, b: str, c: typing.List[float], d: typing.List[float], a2: pd.Data # TODO: partial lists are not supported yet. 
# t3_bind_b1 = partial(t3, b="hello") # t3_bind_c1 = partial(t3_bind_b1, c=[6.674, 1.618, 6.626], d=[1.0]) - # mt1 = map_task(t3_bind_c1) + # mt1 = map_task_fn(t3_bind_c1) - mt1 = map_task(t3_bind_b2) + mt1 = map_task_fn(t3_bind_b2) - mr = MapTaskResolver() + if map_task_fn == array_node_map_task: + mr = ArrayNodeMapTaskResolver() + elif map_task_fn == map_task: + mr = MapTaskResolver() + else: + raise ValueError("Unexpected map task fn") aa = mr.loader_args(serialization_settings, mt1) # Check bound vars aa = aa[1].split(",") @@ -188,14 +225,14 @@ def print_lists(i: typing.List[str], j: typing.List[str], k: typing.List[str]) - @dynamic def dt1(a: typing.List[int], a2: typing.List[pd.DataFrame], sl: typing.List[float]) -> str: i = mt1(a=a, a2=a2, c=[[1.1, 2.0, 3.0], [1.1, 2.0, 3.0]], d=[sl, sl]) - mt_in2 = map_task(t3_bind_b2) + mt_in2 = map_task_fn(t3_bind_b2) dfs = get_list_of_pd(s=3) j = mt_in2(a=[3, 4, 5], a2=dfs, c=[[1.0], [2.0], [3.0]], d=[sl, sl, sl]) # Test a2 bound to a fixed dataframe t3_bind_a2 = partial(t3_bind_b2, a2=a2[0]) - mt_in3 = map_task(t3_bind_a2) + mt_in3 = map_task_fn(t3_bind_a2) aa = mr.loader_args(serialization_settings, mt_in3) # Check bound vars diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index c1eb15912b..6a487b464a 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -2,7 +2,7 @@ from dataclasses import dataclass import pytest -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin from typing_extensions import Annotated from flytekit import LaunchPlan, task, workflow @@ -92,9 +92,8 @@ def wf(i: int, j: int): create_and_link_node_from_remote(ctx, lp, _inputs_not_allowed={"i"}, _ignorable_inputs={"j"}, j=15) -@dataclass_json @dataclass -class MyDataclass(object): +class MyDataclass(DataClassJsonMixin): i: int a: typing.List[str] diff --git a/tests/flytekit/unit/core/test_resources.py b/tests/flytekit/unit/core/test_resources.py index 1a3bf64dee..25a637b2d6 100644 --- a/tests/flytekit/unit/core/test_resources.py +++ b/tests/flytekit/unit/core/test_resources.py @@ -66,3 +66,16 @@ def test_convert_limits(resource_dict: Dict[str, str], expected_resource_name: _ assert limit.name == expected_resource_name assert limit.value == expected_resource_value assert len(resources_model.requests) == 0 + + +def test_incorrect_type_resources(): + with pytest.raises(AssertionError): + Resources(cpu=1) # type: ignore + with pytest.raises(AssertionError): + Resources(mem=1) # type: ignore + with pytest.raises(AssertionError): + Resources(gpu=1) # type: ignore + with pytest.raises(AssertionError): + Resources(storage=1) # type: ignore + with pytest.raises(AssertionError): + Resources(ephemeral_storage=1) # type: ignore diff --git a/tests/flytekit/unit/core/test_structured_dataset.py b/tests/flytekit/unit/core/test_structured_dataset.py index eaba8b6343..5124193b27 100644 --- a/tests/flytekit/unit/core/test_structured_dataset.py +++ b/tests/flytekit/unit/core/test_structured_dataset.py @@ -493,3 +493,17 @@ def to_html(self, input: str) -> str: with pytest.raises(NotImplementedError, match="Could not find a renderer for in"): StructuredDatasetTransformerEngine().to_html(FlyteContextManager.current_context(), 3, int) + + +def test_list_of_annotated(): + WineDataset = Annotated[ + StructuredDataset, + kwtypes( + alcohol=float, + malic_acid=float, + ), + ] + + @task + def no_op(data: WineDataset) -> typing.List[WineDataset]: + return [data] diff 
--git a/tests/flytekit/unit/core/test_structured_dataset_handlers.py b/tests/flytekit/unit/core/test_structured_dataset_handlers.py index 4b9d183ad8..b26349ceeb 100644 --- a/tests/flytekit/unit/core/test_structured_dataset_handlers.py +++ b/tests/flytekit/unit/core/test_structured_dataset_handlers.py @@ -1,5 +1,6 @@ import typing +import mock import pandas as pd import pyarrow as pa import pytest @@ -50,6 +51,52 @@ def test_csv(): assert df.equals(df2) +@mock.patch("pandas.DataFrame.to_parquet") +@mock.patch("pandas.read_parquet") +@mock.patch("flytekit.types.structured.basic_dfs.get_fsspec_storage_options") +def test_pandas_to_parquet_azure_storage_options(mock_get_fsspec_storage_options, mock_read_parquet, mock_to_parquet): + df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + encoder = basic_dfs.PandasToParquetEncodingHandler() + decoder = basic_dfs.ParquetToPandasDecodingHandler() + + mock_get_fsspec_storage_options.return_value = {"account_name": "accountname_from_storage_options"} + ctx = context_manager.FlyteContextManager.current_context() + sd = StructuredDataset(dataframe=df, uri="abfs://container/parquet_df") + sd_type = StructuredDatasetType(format="parquet") + sd_lit = encoder.encode(ctx, sd, sd_type) + mock_to_parquet.assert_called_once() + write_storage_options = mock_to_parquet.call_args.kwargs["storage_options"] + assert write_storage_options == {"account_name": "accountname_from_storage_options"} + + decoder.decode(ctx, sd_lit, StructuredDatasetMetadata(sd_type)) + mock_read_parquet.assert_called_once() + read_storage_options = mock_read_parquet.call_args.kwargs["storage_options"] + assert read_storage_options == {"account_name": "accountname_from_storage_options"} + + +@mock.patch("pandas.DataFrame.to_csv") +@mock.patch("pandas.read_csv") +@mock.patch("flytekit.types.structured.basic_dfs.get_fsspec_storage_options") +def test_pandas_to_csv_azure_storage_options(mock_get_fsspec_storage_options, mock_read_csv, mock_to_csv): + df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + encoder = basic_dfs.PandasToCSVEncodingHandler() + decoder = basic_dfs.CSVToPandasDecodingHandler() + + mock_get_fsspec_storage_options.return_value = {"account_name": "accountname_from_storage_options"} + ctx = context_manager.FlyteContextManager.current_context() + sd = StructuredDataset(dataframe=df, uri="abfs://container/csv_df") + sd_type = StructuredDatasetType(format="csv") + sd_lit = encoder.encode(ctx, sd, sd_type) + mock_to_csv.assert_called_once() + write_storage_options = mock_to_csv.call_args.kwargs["storage_options"] + assert write_storage_options == {"account_name": "accountname_from_storage_options"} + + decoder.decode(ctx, sd_lit, StructuredDatasetMetadata(sd_type)) + mock_read_csv.assert_called_once() + read_storage_options = mock_read_csv.call_args.kwargs["storage_options"] + assert read_storage_options == {"account_name": "accountname_from_storage_options"} + + def test_base_isnt_instantiable(): with pytest.raises(TypeError): StructuredDatasetEncoder(pd.DataFrame, "", "") diff --git a/tests/flytekit/unit/core/test_type_delayed.py b/tests/flytekit/unit/core/test_type_delayed.py index 3e6824788d..a47a0b88f8 100644 --- a/tests/flytekit/unit/core/test_type_delayed.py +++ b/tests/flytekit/unit/core/test_type_delayed.py @@ -3,7 +3,7 @@ import typing from dataclasses import dataclass -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin from typing_extensions import Annotated # type: ignore from flytekit.core
import context_manager @@ -11,9 +11,8 @@ from flytekit.core.type_engine import TypeEngine -@dataclass_json @dataclass -class Foo(object): +class Foo(DataClassJsonMixin): x: int y: str z: typing.Dict[str, str] diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 7332d01631..14d5835142 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -1,3 +1,4 @@ +import dataclasses import datetime import json import os @@ -19,6 +20,7 @@ from google.protobuf import struct_pb2 as _struct from marshmallow_enum import LoadDumpOptions from marshmallow_jsonschema import JSONSchema +from mashumaro.mixins.json import DataClassJSONMixin from pandas._testing import assert_frame_equal from typing_extensions import Annotated, get_args, get_origin @@ -39,7 +41,8 @@ TypeTransformer, TypeTransformerFailedError, UnionTransformer, - convert_json_schema_to_python_class, + convert_marshmallow_json_schema_to_python_class, + convert_mashumaro_json_schema_to_python_class, dataclass_from_dict, get_underlying_type, is_annotated, @@ -90,6 +93,7 @@ def test_type_resolution(): assert type(TypeEngine.get_transformer(dict)) == DictTransformer assert type(TypeEngine.get_transformer(int)) == SimpleTransformer + assert type(TypeEngine.get_transformer(datetime.date)) == SimpleTransformer assert type(TypeEngine.get_transformer(os.PathLike)) == FlyteFilePathTransformer assert type(TypeEngine.get_transformer(FlytePickle)) == FlytePickleTransformer @@ -148,15 +152,13 @@ def test_list_of_dict_getting_python_value(): def test_list_of_single_dataclass(): - @dataclass_json - @dataclass() - class Bar(object): + @dataclass + class Bar(DataClassJsonMixin): v: typing.Optional[typing.List[int]] w: typing.Optional[typing.List[float]] - @dataclass_json - @dataclass() - class Foo(object): + @dataclass + class Foo(DataClassJsonMixin): a: typing.Optional[typing.List[str]] b: Bar @@ -172,6 +174,31 @@ class Foo(object): assert pv[0].b == Bar(v=[1, 2, 99], w=[3.1415, 2.7182]) +@dataclass +class Bar(DataClassJSONMixin): + v: typing.Optional[typing.List[int]] + w: typing.Optional[typing.List[float]] + + +@dataclass +class Foo(DataClassJSONMixin): + a: typing.Optional[typing.List[str]] + b: Bar + + +def test_list_of_single_dataclassjsonmixin(): + foo = Foo(a=["abc", "def"], b=Bar(v=[1, 2, 99], w=[3.1415, 2.7182])) + generic = _json_format.Parse(typing.cast(DataClassJSONMixin, foo).to_json(), _struct.Struct()) + lv = Literal(collection=LiteralCollection(literals=[Literal(scalar=Scalar(generic=generic))])) + + transformer = TypeEngine.get_transformer(typing.List) + ctx = FlyteContext.current_context() + + pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[Foo]) + assert pv[0].a == ["abc", "def"] + assert pv[0].b == Bar(v=[1, 2, 99], w=[3.1415, 2.7182]) + + def test_annotated_type(): class JsonTypeTransformer(TypeTransformer[T]): LiteralType = LiteralType( @@ -218,18 +245,16 @@ def __class_getitem__(cls, item: Type[T]): def test_list_of_dataclass_getting_python_value(): - @dataclass_json - @dataclass() - class Bar(object): + @dataclass + class Bar(DataClassJsonMixin): v: typing.Union[int, None] w: typing.Optional[str] x: float y: str z: typing.Dict[str, bool] - @dataclass_json - @dataclass() - class Foo(object): + @dataclass + class Foo(DataClassJsonMixin): u: typing.Optional[int] v: typing.Optional[int] w: int @@ -245,7 +270,7 @@ class Foo(object): ctx = FlyteContext.current_context() schema = 
JSONSchema().dump(typing.cast(DataClassJsonMixin, Foo).schema()) - foo_class = convert_json_schema_to_python_class(schema["definitions"], "FooSchema") + foo_class = convert_marshmallow_json_schema_to_python_class(schema["definitions"], "FooSchema") guessed_pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[foo_class]) pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[Foo]) @@ -265,6 +290,67 @@ class Foo(object): assert guessed_pv[0].z.y == pv[0].z.y assert guessed_pv[0].z.z == pv[0].z.z assert pv[0] == dataclass_from_dict(Foo, asdict(guessed_pv[0])) + assert dataclasses.is_dataclass(foo_class) + + +@dataclass +class Bar_getting_python_value(DataClassJSONMixin): + v: typing.Union[int, None] + w: typing.Optional[str] + x: float + y: str + z: typing.Dict[str, bool] + + +@dataclass +class Foo_getting_python_value(DataClassJSONMixin): + u: typing.Optional[int] + v: typing.Optional[int] + w: int + x: typing.List[int] + y: typing.Dict[str, str] + z: Bar_getting_python_value + + +def test_list_of_dataclassjsonmixin_getting_python_value(): + foo = Foo_getting_python_value( + u=5, + v=None, + w=1, + x=[1], + y={"hello": "10"}, + z=Bar_getting_python_value(v=3, w=None, x=1.0, y="hello", z={"world": False}), + ) + generic = _json_format.Parse(typing.cast(DataClassJSONMixin, foo).to_json(), _struct.Struct()) + lv = Literal(collection=LiteralCollection(literals=[Literal(scalar=Scalar(generic=generic))])) + + transformer = TypeEngine.get_transformer(typing.List) + ctx = FlyteContext.current_context() + + from mashumaro.jsonschema import build_json_schema + + schema = build_json_schema(typing.cast(DataClassJSONMixin, Foo_getting_python_value)).to_dict() + foo_class = convert_mashumaro_json_schema_to_python_class(schema, "FooSchema") + + guessed_pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[foo_class]) + pv = transformer.to_python_value(ctx, lv, expected_python_type=typing.List[Foo_getting_python_value]) + assert isinstance(guessed_pv, list) + assert guessed_pv[0].u == pv[0].u + assert guessed_pv[0].v == pv[0].v + assert guessed_pv[0].w == pv[0].w + assert guessed_pv[0].x == pv[0].x + assert guessed_pv[0].y == pv[0].y + assert guessed_pv[0].z.x == pv[0].z.x + assert type(guessed_pv[0].u) == int + assert guessed_pv[0].v is None + assert type(guessed_pv[0].w) == int + assert type(guessed_pv[0].z.v) == int + assert type(guessed_pv[0].z.x) == float + assert guessed_pv[0].z.v == pv[0].z.v + assert guessed_pv[0].z.y == pv[0].z.y + assert guessed_pv[0].z.z == pv[0].z.z + assert pv[0] == dataclass_from_dict(Foo_getting_python_value, asdict(guessed_pv[0])) + assert dataclasses.is_dataclass(foo_class) def test_file_no_downloader_default(): @@ -301,6 +387,16 @@ def test_dir_no_downloader_default(): assert pv.download() == local_dir +def test_dir_with_batch_size(): + flyte_dir = Annotated[FlyteDirectory, BatchSize(100)] + val = flyte_dir("s3://bucket/key") + transformer = TypeEngine.get_transformer(flyte_dir) + ctx = FlyteContext.current_context() + lt = transformer.get_literal_type(flyte_dir) + lv = transformer.to_literal(ctx, val, flyte_dir, lt) + assert val.path == transformer.to_python_value(ctx, lv, flyte_dir).remote_source + + def test_dict_transformer(): d = DictTransformer() @@ -323,6 +419,7 @@ def recursive_assert(lit: LiteralType, expected: LiteralType, expected_depth: in recursive_assert(d.get_literal_type(typing.Dict[str, int]), LiteralType(simple=SimpleType.INTEGER)) recursive_assert(d.get_literal_type(typing.Dict[str, 
datetime.datetime]), LiteralType(simple=SimpleType.DATETIME)) recursive_assert(d.get_literal_type(typing.Dict[str, datetime.timedelta]), LiteralType(simple=SimpleType.DURATION)) + recursive_assert(d.get_literal_type(typing.Dict[str, datetime.date]), LiteralType(simple=SimpleType.DATETIME)) recursive_assert(d.get_literal_type(typing.Dict[str, dict]), LiteralType(simple=SimpleType.STRUCT)) recursive_assert( d.get_literal_type(typing.Dict[str, typing.Dict[str, str]]), @@ -377,21 +474,41 @@ def recursive_assert(lit: LiteralType, expected: LiteralType, expected_depth: in ) -def test_convert_json_schema_to_python_class(): - @dataclass_json +def test_convert_marshmallow_json_schema_to_python_class(): @dataclass - class Foo(object): + class Foo(DataClassJsonMixin): x: int y: str schema = JSONSchema().dump(typing.cast(DataClassJsonMixin, Foo).schema()) - foo_class = convert_json_schema_to_python_class(schema["definitions"], "FooSchema") + foo_class = convert_marshmallow_json_schema_to_python_class(schema["definitions"], "FooSchema") + foo = foo_class(x=1, y="hello") + foo.x = 2 + assert foo.x == 2 + assert foo.y == "hello" + with pytest.raises(AttributeError): + _ = foo.c + assert dataclasses.is_dataclass(foo_class) + + +def test_convert_mashumaro_json_schema_to_python_class(): + @dataclass + class Foo(DataClassJSONMixin): + x: int + y: str + + # schema = JSONSchema().dump(typing.cast(DataClassJSONMixin, Foo).schema()) + from mashumaro.jsonschema import build_json_schema + + schema = build_json_schema(typing.cast(DataClassJSONMixin, Foo)).to_dict() + foo_class = convert_mashumaro_json_schema_to_python_class(schema, "FooSchema") foo = foo_class(x=1, y="hello") foo.x = 2 assert foo.x == 2 assert foo.y == "hello" with pytest.raises(AttributeError): _ = foo.c + assert dataclasses.is_dataclass(foo_class) def test_list_transformer(): @@ -490,40 +607,35 @@ def test_zero_floats(): assert TypeEngine.to_python_value(ctx, l1, float) == 0 -@dataclass_json @dataclass -class InnerStruct(object): +class InnerStruct(DataClassJsonMixin): a: int b: typing.Optional[str] c: typing.List[int] -@dataclass_json @dataclass -class TestStruct(object): +class TestStruct(DataClassJsonMixin): s: InnerStruct m: typing.Dict[str, str] -@dataclass_json @dataclass -class TestStructB(object): +class TestStructB(DataClassJsonMixin): s: InnerStruct m: typing.Dict[int, str] n: typing.Optional[typing.List[typing.List[int]]] = None o: typing.Optional[typing.Dict[int, typing.Dict[int, int]]] = None -@dataclass_json @dataclass -class TestStructC(object): +class TestStructC(DataClassJsonMixin): s: InnerStruct m: typing.Dict[str, int] -@dataclass_json @dataclass -class TestStructD(object): +class TestStructD(DataClassJsonMixin): s: InnerStruct m: typing.Dict[str, typing.List[int]] @@ -533,9 +645,8 @@ def __init__(self): self._a = "Hello" -@dataclass_json @dataclass -class UnsupportedNestedStruct(object): +class UnsupportedNestedStruct(DataClassJsonMixin): a: int s: UnsupportedSchemaType @@ -590,6 +701,94 @@ def test_dataclass_transformer(): assert t.metadata is None +@dataclass +class InnerStruct_transformer(DataClassJSONMixin): + a: int + b: typing.Optional[str] + c: typing.List[int] + + +@dataclass +class TestStruct_transformer(DataClassJSONMixin): + s: InnerStruct_transformer + m: typing.Dict[str, str] + + +@dataclass +class TestStructB_transformer(DataClassJSONMixin): + s: InnerStruct_transformer + m: typing.Dict[int, str] + n: typing.Optional[typing.List[typing.List[int]]] = None + o: typing.Optional[typing.Dict[int, typing.Dict[int, 
int]]] = None + + +@dataclass +class TestStructC_transformer(DataClassJSONMixin): + s: InnerStruct_transformer + m: typing.Dict[str, int] + + +@dataclass +class TestStructD_transformer(DataClassJSONMixin): + s: InnerStruct_transformer + m: typing.Dict[str, typing.List[int]] + + +@dataclass +class UnsupportedSchemaType_transformer: + _a: str = "Hello" + + +@dataclass +class UnsupportedNestedStruct_transformer(DataClassJSONMixin): + a: int + s: UnsupportedSchemaType_transformer + + +def test_dataclass_transformer_with_dataclassjsonmixin(): + schema = { + "type": "object", + "title": "TestStruct_transformer", + "properties": { + "s": { + "type": "object", + "title": "InnerStruct_transformer", + "properties": { + "a": {"type": "integer"}, + "b": {"anyOf": [{"type": "string"}, {"type": "null"}]}, + "c": {"type": "array", "items": {"type": "integer"}}, + }, + "additionalProperties": False, + "required": ["a", "b", "c"], + }, + "m": {"type": "object", "additionalProperties": {"type": "string"}, "propertyNames": {"type": "string"}}, + }, + "additionalProperties": False, + "required": ["s", "m"], + } + + tf = DataclassTransformer() + t = tf.get_literal_type(TestStruct_transformer) + assert t is not None + assert t.simple is not None + assert t.simple == SimpleType.STRUCT + assert t.metadata is not None + assert t.metadata == schema + + t = TypeEngine.to_literal_type(TestStruct_transformer) + assert t is not None + assert t.simple is not None + assert t.simple == SimpleType.STRUCT + assert t.metadata is not None + assert t.metadata == schema + + t = tf.get_literal_type(UnsupportedNestedStruct) + assert t is not None + assert t.simple is not None + assert t.simple == SimpleType.STRUCT + assert t.metadata is None + + def test_dataclass_int_preserving(): ctx = FlyteContext.current_context() @@ -621,14 +820,12 @@ def test_dataclass_int_preserving(): def test_optional_flytefile_in_dataclass(mock_upload_dir): mock_upload_dir.return_value = True - @dataclass_json @dataclass - class A(object): + class A(DataClassJsonMixin): a: int - @dataclass_json @dataclass - class TestFileStruct(object): + class TestFileStruct(DataClassJsonMixin): a: FlyteFile b: typing.Optional[FlyteFile] b_prime: typing.Optional[FlyteFile] @@ -701,19 +898,101 @@ class TestFileStruct(object): assert o.i_prime == A(a=99) +@dataclass +class A_optional_flytefile(DataClassJSONMixin): + a: int + + +@dataclass +class TestFileStruct_optional_flytefile(DataClassJSONMixin): + a: FlyteFile + b: typing.Optional[FlyteFile] + b_prime: typing.Optional[FlyteFile] + c: typing.Union[FlyteFile, None] + d: typing.List[FlyteFile] + e: typing.List[typing.Optional[FlyteFile]] + e_prime: typing.List[typing.Optional[FlyteFile]] + f: typing.Dict[str, FlyteFile] + g: typing.Dict[str, typing.Optional[FlyteFile]] + g_prime: typing.Dict[str, typing.Optional[FlyteFile]] + h: typing.Optional[FlyteFile] = None + h_prime: typing.Optional[FlyteFile] = None + i: typing.Optional[A_optional_flytefile] = None + i_prime: typing.Optional[A_optional_flytefile] = field(default_factory=lambda: A_optional_flytefile(a=99)) + + +@mock.patch("flytekit.core.data_persistence.FileAccessProvider.put_data") +def test_optional_flytefile_in_dataclassjsonmixin(mock_upload_dir): + mock_upload_dir.return_value = True + + remote_path = "s3://tmp/file" + with tempfile.TemporaryFile() as f: + f.write(b"abc") + f1 = FlyteFile("f1", remote_path=remote_path) + o = TestFileStruct_optional_flytefile( + a=f1, + b=f1, + b_prime=None, + c=f1, + d=[f1], + e=[f1], + e_prime=[None], + f={"a": f1}, + 
g={"a": f1}, + g_prime={"a": None}, + h=f1, + i=A_optional_flytefile(a=42), + ) + + ctx = FlyteContext.current_context() + tf = DataclassTransformer() + lt = tf.get_literal_type(TestFileStruct_optional_flytefile) + lv = tf.to_literal(ctx, o, TestFileStruct_optional_flytefile, lt) + + assert lv.scalar.generic["a"].fields["path"].string_value == remote_path + assert lv.scalar.generic["b"].fields["path"].string_value == remote_path + assert lv.scalar.generic["b_prime"] is None + assert lv.scalar.generic["c"].fields["path"].string_value == remote_path + assert lv.scalar.generic["d"].values[0].struct_value.fields["path"].string_value == remote_path + assert lv.scalar.generic["e"].values[0].struct_value.fields["path"].string_value == remote_path + assert lv.scalar.generic["e_prime"].values[0].WhichOneof("kind") == "null_value" + assert lv.scalar.generic["f"]["a"].fields["path"].string_value == remote_path + assert lv.scalar.generic["g"]["a"].fields["path"].string_value == remote_path + assert lv.scalar.generic["g_prime"]["a"] is None + assert lv.scalar.generic["h"].fields["path"].string_value == remote_path + assert lv.scalar.generic["h_prime"] is None + assert lv.scalar.generic["i"].fields["a"].number_value == 42 + assert lv.scalar.generic["i_prime"].fields["a"].number_value == 99 + + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct_optional_flytefile) + + assert o.a.path == ot.a.remote_source + assert o.b.path == ot.b.remote_source + assert ot.b_prime is None + assert o.c.path == ot.c.remote_source + assert o.d[0].path == ot.d[0].remote_source + assert o.e[0].path == ot.e[0].remote_source + assert o.e_prime == [None] + assert o.f["a"].path == ot.f["a"].remote_source + assert o.g["a"].path == ot.g["a"].remote_source + assert o.g_prime == {"a": None} + assert o.h.path == ot.h.remote_source + assert ot.h_prime is None + assert o.i == ot.i + assert o.i_prime == A_optional_flytefile(a=99) + + def test_flyte_file_in_dataclass(): - @dataclass_json @dataclass - class TestInnerFileStruct(object): + class TestInnerFileStruct(DataClassJsonMixin): a: JPEGImageFile b: typing.List[FlyteFile] c: typing.Dict[str, FlyteFile] d: typing.List[FlyteFile] e: typing.Dict[str, FlyteFile] - @dataclass_json @dataclass - class TestFileStruct(object): + class TestFileStruct(DataClassJsonMixin): a: FlyteFile b: TestInnerFileStruct @@ -746,19 +1025,64 @@ class TestFileStruct(object): assert not ctx.file_access.is_remote(ot.b.e["hello"].path) +@dataclass +class TestInnerFileStruct_flyte_file(DataClassJSONMixin): + a: JPEGImageFile + b: typing.List[FlyteFile] + c: typing.Dict[str, FlyteFile] + d: typing.List[FlyteFile] + e: typing.Dict[str, FlyteFile] + + +@dataclass +class TestFileStruct_flyte_file(DataClassJSONMixin): + a: FlyteFile + b: TestInnerFileStruct_flyte_file + + +def test_flyte_file_in_dataclassjsonmixin(): + remote_path = "s3://tmp/file" + f1 = FlyteFile(remote_path) + f2 = FlyteFile("/tmp/file") + f2._remote_source = remote_path + o = TestFileStruct_flyte_file( + a=f1, + b=TestInnerFileStruct_flyte_file( + a=JPEGImageFile("s3://tmp/file.jpeg"), b=[f1], c={"hello": f1}, d=[f2], e={"hello": f2} + ), + ) + + ctx = FlyteContext.current_context() + tf = DataclassTransformer() + lt = tf.get_literal_type(TestFileStruct_flyte_file) + lv = tf.to_literal(ctx, o, TestFileStruct_flyte_file, lt) + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct_flyte_file) + assert ot.a._downloader is not noop + assert ot.b.a._downloader is not noop + assert ot.b.b[0]._downloader is not noop + 
assert ot.b.c["hello"]._downloader is not noop + + assert o.a.path == ot.a.remote_source + assert o.b.a.path == ot.b.a.remote_source + assert o.b.b[0].path == ot.b.b[0].remote_source + assert o.b.c["hello"].path == ot.b.c["hello"].remote_source + assert ot.b.d[0].remote_source == remote_path + assert not ctx.file_access.is_remote(ot.b.d[0].path) + assert ot.b.e["hello"].remote_source == remote_path + assert not ctx.file_access.is_remote(ot.b.e["hello"].path) + + def test_flyte_directory_in_dataclass(): - @dataclass_json @dataclass - class TestInnerFileStruct(object): + class TestInnerFileStruct(DataClassJsonMixin): a: TensorboardLogs b: typing.List[FlyteDirectory] c: typing.Dict[str, FlyteDirectory] d: typing.List[FlyteDirectory] e: typing.Dict[str, FlyteDirectory] - @dataclass_json @dataclass - class TestFileStruct(object): + class TestFileStruct(DataClassJsonMixin): a: FlyteDirectory b: TestInnerFileStruct @@ -794,20 +1118,68 @@ class TestFileStruct(object): assert o.b.e["hello"].path == ot.b.e["hello"].remote_source +@dataclass +class TestInnerFileStruct_flyte_directory(DataClassJSONMixin): + a: TensorboardLogs + b: typing.List[FlyteDirectory] + c: typing.Dict[str, FlyteDirectory] + d: typing.List[FlyteDirectory] + e: typing.Dict[str, FlyteDirectory] + + +@dataclass +class TestFileStruct_flyte_directory(DataClassJSONMixin): + a: FlyteDirectory + b: TestInnerFileStruct_flyte_directory + + +def test_flyte_directory_in_dataclassjsonmixin(): + remote_path = "s3://tmp/file" + tempdir = tempfile.mkdtemp(prefix="flyte-") + f1 = FlyteDirectory(tempdir) + f1._remote_source = remote_path + f2 = FlyteDirectory(remote_path) + o = TestFileStruct_flyte_directory( + a=f1, + b=TestInnerFileStruct_flyte_directory( + a=TensorboardLogs("s3://tensorboard"), b=[f1], c={"hello": f1}, d=[f2], e={"hello": f2} + ), + ) + + ctx = FlyteContext.current_context() + tf = DataclassTransformer() + lt = tf.get_literal_type(TestFileStruct_flyte_directory) + lv = tf.to_literal(ctx, o, TestFileStruct_flyte_directory, lt) + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=TestFileStruct_flyte_directory) + + assert ot.a._downloader is not noop + assert ot.b.a._downloader is not noop + assert ot.b.b[0]._downloader is not noop + assert ot.b.c["hello"]._downloader is not noop + + assert o.a.remote_directory == ot.a.remote_directory + assert not ctx.file_access.is_remote(ot.a.path) + assert o.b.a.path == ot.b.a.remote_source + assert o.b.b[0].remote_directory == ot.b.b[0].remote_directory + assert not ctx.file_access.is_remote(ot.b.b[0].path) + assert o.b.c["hello"].remote_directory == ot.b.c["hello"].remote_directory + assert not ctx.file_access.is_remote(ot.b.c["hello"].path) + assert o.b.d[0].path == ot.b.d[0].remote_source + assert o.b.e["hello"].path == ot.b.e["hello"].remote_source + + def test_structured_dataset_in_dataclass(): df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) People = Annotated[StructuredDataset, "parquet", kwtypes(Name=str, Age=int)] - @dataclass_json @dataclass - class InnerDatasetStruct(object): + class InnerDatasetStruct(DataClassJsonMixin): a: StructuredDataset b: typing.List[Annotated[StructuredDataset, "parquet"]] c: typing.Dict[str, Annotated[StructuredDataset, kwtypes(Name=str, Age=int)]] - @dataclass_json @dataclass - class DatasetStruct(object): + class DatasetStruct(DataClassJsonMixin): a: People b: InnerDatasetStruct @@ -830,6 +1202,41 @@ class DatasetStruct(object): assert "parquet" == ot.b.c["hello"].file_format +@dataclass +class 
InnerDatasetStruct_dataclassjsonmixin(DataClassJSONMixin): + a: StructuredDataset + b: typing.List[Annotated[StructuredDataset, "parquet"]] + c: typing.Dict[str, Annotated[StructuredDataset, kwtypes(Name=str, Age=int)]] + + +def test_structured_dataset_in_dataclassjsonmixin(): + df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) + People = Annotated[StructuredDataset, "parquet"] + + @dataclass + class DatasetStruct_dataclassjsonmixin(DataClassJSONMixin): + a: People + b: InnerDatasetStruct_dataclassjsonmixin + + sd = StructuredDataset(dataframe=df, file_format="parquet") + o = DatasetStruct_dataclassjsonmixin(a=sd, b=InnerDatasetStruct_dataclassjsonmixin(a=sd, b=[sd], c={"hello": sd})) + + ctx = FlyteContext.current_context() + tf = DataclassTransformer() + lt = tf.get_literal_type(DatasetStruct_dataclassjsonmixin) + lv = tf.to_literal(ctx, o, DatasetStruct_dataclassjsonmixin, lt) + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=DatasetStruct_dataclassjsonmixin) + + assert_frame_equal(df, ot.a.open(pd.DataFrame).all()) + assert_frame_equal(df, ot.b.a.open(pd.DataFrame).all()) + assert_frame_equal(df, ot.b.b[0].open(pd.DataFrame).all()) + assert_frame_equal(df, ot.b.c["hello"].open(pd.DataFrame).all()) + assert "parquet" == ot.a.file_format + assert "parquet" == ot.b.a.file_format + assert "parquet" == ot.b.b[0].file_format + assert "parquet" == ot.b.c["hello"].file_format + + # Enums should have string values class Color(Enum): RED = "red" @@ -895,6 +1302,9 @@ def test_enum_type(): assert t.enum_type.values assert t.enum_type.values == [c.value for c in Color] + g = TypeEngine.guess_python_type(t) + assert [e.value for e in g] == [e.value for e in Color] + ctx = FlyteContextManager.current_context() lv = TypeEngine.to_literal(ctx, Color.RED, Color, TypeEngine.to_literal_type(Color)) assert lv @@ -953,15 +1363,13 @@ def test_union_type(): def test_assert_dataclass_type(): - @dataclass_json @dataclass - class Args(object): + class Args(DataClassJsonMixin): x: int y: typing.Optional[str] - @dataclass_json @dataclass - class Schema(object): + class Schema(DataClassJsonMixin): x: typing.Optional[Args] = None pt = Schema @@ -971,9 +1379,8 @@ class Schema(object): DataclassTransformer().assert_type(gt, pv) DataclassTransformer().assert_type(Schema, pv) - @dataclass_json @dataclass - class Bar(object): + class Bar(DataClassJsonMixin): x: int pv = Bar(x=3) @@ -983,6 +1390,37 @@ class Bar(object): DataclassTransformer().assert_type(gt, pv) +@dataclass +class ArgsAssert(DataClassJSONMixin): + x: int + y: typing.Optional[str] + + +@dataclass +class SchemaArgsAssert(DataClassJSONMixin): + x: typing.Optional[ArgsAssert] + + +def test_assert_dataclassjsonmixin_type(): + pt = SchemaArgsAssert + lt = TypeEngine.to_literal_type(pt) + gt = TypeEngine.guess_python_type(lt) + pv = SchemaArgsAssert(x=ArgsAssert(x=3, y="hello")) + DataclassTransformer().assert_type(gt, pv) + DataclassTransformer().assert_type(SchemaArgsAssert, pv) + + @dataclass + class Bar(DataClassJSONMixin): + x: int + + pv = Bar(x=3) + with pytest.raises( + TypeTransformerFailedError, + match="Type of Val '' is not an instance of ", + ): + DataclassTransformer().assert_type(gt, pv) + + def test_union_transformer(): assert UnionTransformer.is_optional_type(typing.Optional[int]) assert not UnionTransformer.is_optional_type(str) @@ -1278,9 +1716,8 @@ def __init__(self, number: int): def test_enum_in_dataclass(): - @dataclass_json @dataclass - class Datum(object): + class Datum(DataClassJsonMixin): x: int y: Color @@ 
-1299,6 +1736,28 @@ class Datum(object): assert datum.y.value == pv.y +def test_enum_in_dataclassjsonmixin(): + @dataclass + class Datum(DataClassJSONMixin): + x: int + y: Color + + lt = TypeEngine.to_literal_type(Datum) + from mashumaro.jsonschema import build_json_schema + + schema = build_json_schema(typing.cast(DataClassJSONMixin, Datum)).to_dict() + assert lt.metadata == schema + + transformer = DataclassTransformer() + ctx = FlyteContext.current_context() + datum = Datum(5, Color.RED) + lv = transformer.to_literal(ctx, datum, Datum, lt) + gt = transformer.guess_python_type(lt) + pv = transformer.to_python_value(ctx, lv, expected_python_type=gt) + assert datum.x == pv.x + assert datum.y.value == pv.y + + @pytest.mark.parametrize( "python_value,python_types,expected_literal_map", [ @@ -1535,16 +1994,14 @@ def test_multiple_annotations(): TestSchema = FlyteSchema[kwtypes(some_str=str)] # type: ignore -@dataclass_json @dataclass -class InnerResult: +class InnerResult(DataClassJsonMixin): number: int schema: TestSchema # type: ignore -@dataclass_json @dataclass -class Result: +class Result(DataClassJsonMixin): result: InnerResult schema: TestSchema # type: ignore @@ -1563,10 +2020,54 @@ def test_schema_in_dataclass(): assert o == ot +@dataclass +class InnerResult_dataclassjsonmixin(DataClassJSONMixin): + number: int + schema: TestSchema # type: ignore + + +@dataclass +class Result_dataclassjsonmixin(DataClassJSONMixin): + result: InnerResult_dataclassjsonmixin + schema: TestSchema # type: ignore + + +def test_schema_in_dataclassjsonmixin(): + schema = TestSchema() + df = pd.DataFrame(data={"some_str": ["a", "b", "c"]}) + schema.open().write(df) + o = Result_dataclassjsonmixin(result=InnerResult_dataclassjsonmixin(number=1, schema=schema), schema=schema) + ctx = FlyteContext.current_context() + tf = DataclassTransformer() + lt = tf.get_literal_type(Result_dataclassjsonmixin) + lv = tf.to_literal(ctx, o, Result_dataclassjsonmixin, lt) + ot = tf.to_python_value(ctx, lv=lv, expected_python_type=Result_dataclassjsonmixin) + + assert o == ot + + def test_guess_of_dataclass(): - @dataclass_json - @dataclass() - class Foo(object): + @dataclass + class Foo(DataClassJsonMixin): + x: int + y: str + z: typing.Dict[str, int] + + def hello(self): + ...
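For reference, a minimal sketch of the pattern the new mixin-based tests cover (illustrative names, assuming the mashumaro support exercised above): a dataclass only needs to inherit DataClassJSONMixin, without the @dataclass_json decorator, for flytekit's dataclass transformer to handle it.

import typing
from dataclasses import dataclass

from mashumaro.mixins.json import DataClassJSONMixin

from flytekit import task


@dataclass
class Record(DataClassJSONMixin):
    x: int
    tags: typing.Dict[str, str]


@task
def summarize(r: Record) -> str:
    # Record round-trips through a protobuf Struct literal at the task boundary
    return f"{r.x}:{len(r.tags)}"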
+ + lt = TypeEngine.to_literal_type(Foo) + foo = Foo(1, "hello", {"world": 3}) + lv = TypeEngine.to_literal(FlyteContext.current_context(), foo, Foo, lt) + lit_dict = {"a": lv} + lr = LiteralsResolver(lit_dict) + assert lr.get("a", Foo) == foo + assert hasattr(lr.get("a", Foo), "hello") is True + + +def test_guess_of_dataclassjsonmixin(): + @dataclass + class Foo(DataClassJSONMixin): x: int y: str z: typing.Dict[str, int] @@ -1727,3 +2228,107 @@ def test_get_underlying_type(t, expected): def test_dict_get(): assert DictTransformer.get_dict_types(None) == (None, None) + + +def test_DataclassTransformer_get_literal_type(): + @dataclass + class MyDataClassMashumaro(DataClassJsonMixin): + x: int + + @dataclass_json + @dataclass + class MyDataClass: + x: int + + de = DataclassTransformer() + + literal_type = de.get_literal_type(MyDataClass) + assert literal_type is not None + + literal_type = de.get_literal_type(MyDataClassMashumaro) + assert literal_type is not None + + invalid_json_str = "{ unbalanced_braces" + with pytest.raises(Exception): + Literal(scalar=Scalar(generic=_json_format.Parse(invalid_json_str, _struct.Struct()))) + + +def test_DataclassTransformer_to_literal(): + @dataclass + class MyDataClassMashumaro(DataClassJsonMixin): + x: int + + @dataclass_json + @dataclass + class MyDataClass: + x: int + + transformer = DataclassTransformer() + ctx = FlyteContext.current_context() + + my_dat_class_mashumaro = MyDataClassMashumaro(5) + my_data_class = MyDataClass(5) + + lv_mashumaro = transformer.to_literal(ctx, my_dat_class_mashumaro, MyDataClassMashumaro, MyDataClassMashumaro) + assert lv_mashumaro is not None + assert lv_mashumaro.scalar.generic["x"] == 5 + + lv = transformer.to_literal(ctx, my_data_class, MyDataClass, MyDataClass) + assert lv is not None + assert lv.scalar.generic["x"] == 5 + + +def test_DataclassTransformer_to_python_value(): + @dataclass + class MyDataClassMashumaro(DataClassJsonMixin): + x: int + + @dataclass_json + @dataclass + class MyDataClass: + x: int + + de = DataclassTransformer() + + json_str = '{ "x" : 5 }' + mock_literal = Literal(scalar=Scalar(generic=_json_format.Parse(json_str, _struct.Struct()))) + + result = de.to_python_value(FlyteContext.current_context(), mock_literal, MyDataClass) + assert isinstance(result, MyDataClass) + assert result.x == 5 + + result = de.to_python_value(FlyteContext.current_context(), mock_literal, MyDataClassMashumaro) + assert isinstance(result, MyDataClassMashumaro) + assert result.x == 5 + + +def test_DataclassTransformer_guess_python_type(): + @dataclass + class DatumMashumaro(DataClassJSONMixin): + x: int + y: Color + + @dataclass_json + @dataclass + class Datum(DataClassJSONMixin): + x: int + y: Color + + transformer = DataclassTransformer() + ctx = FlyteContext.current_context() + + lt = TypeEngine.to_literal_type(Datum) + datum = Datum(5, Color.RED) + lv = transformer.to_literal(ctx, datum, Datum, lt) + gt = transformer.guess_python_type(lt) + pv = transformer.to_python_value(ctx, lv, expected_python_type=gt) + assert datum.x == pv.x + assert datum.y.value == pv.y + + lt = TypeEngine.to_literal_type(DatumMashumaro) + datum_mashumaro = DatumMashumaro(5, Color.RED) + lv = transformer.to_literal(ctx, datum_mashumaro, DatumMashumaro, lt) + gt = transformer.guess_python_type(lt) + pv = transformer.to_python_value(ctx, lv, expected_python_type=gt) + assert datum_mashumaro.x == pv.x + assert datum_mashumaro.y.value == pv.y diff --git a/tests/flytekit/unit/core/test_type_hints.py 
b/tests/flytekit/unit/core/test_type_hints.py index 875c56a4b7..4e32070f9f 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -13,7 +13,7 @@ import pandas import pandas as pd import pytest -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin from google.protobuf.struct_pb2 import Struct from pandas._testing import assert_frame_equal from typing_extensions import Annotated, get_origin @@ -387,15 +387,13 @@ def test_user_demo_test(mock_sql): def test_flyte_file_in_dataclass(): - @dataclass_json @dataclass - class InnerFileStruct(object): + class InnerFileStruct(DataClassJsonMixin): a: FlyteFile b: PNGImageFile - @dataclass_json @dataclass - class FileStruct(object): + class FileStruct(DataClassJsonMixin): a: FlyteFile b: InnerFileStruct @@ -437,15 +435,13 @@ def wf(path: str) -> (os.PathLike, FlyteFile): def test_flyte_directory_in_dataclass(): - @dataclass_json @dataclass - class InnerFileStruct(object): + class InnerFileStruct(DataClassJsonMixin): a: FlyteDirectory b: TensorboardLogs - @dataclass_json @dataclass - class FileStruct(object): + class FileStruct(DataClassJsonMixin): a: FlyteDirectory b: InnerFileStruct @@ -470,14 +466,12 @@ def wf(path: str) -> os.PathLike: def test_structured_dataset_in_dataclass(): df = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]}) - @dataclass_json @dataclass - class InnerDatasetStruct(object): + class InnerDatasetStruct(DataClassJsonMixin): a: StructuredDataset - @dataclass_json @dataclass - class DatasetStruct(object): + class DatasetStruct(DataClassJsonMixin): a: StructuredDataset b: InnerDatasetStruct @@ -1087,9 +1081,8 @@ def t1(a: int) -> MyCustomType: def test_wf_custom_types(): - @dataclass_json @dataclass - class MyCustomType(object): + class MyCustomType(DataClassJsonMixin): x: int y: str @@ -1137,9 +1130,8 @@ def wf(a: int) -> typing.Dict[str, Foo]: def test_dataclass_more(): - @dataclass_json @dataclass - class Datum(object): + class Datum(DataClassJsonMixin): x: int y: str z: typing.Dict[int, str] @@ -1166,9 +1158,8 @@ class Color(Enum): GREEN = "green" BLUE = "blue" - @dataclass_json @dataclass - class Datum(object): + class Datum(DataClassJsonMixin): x: int y: Color @@ -1186,15 +1177,13 @@ def wf(x: int) -> Datum: def test_flyte_schema_dataclass(): TestSchema = FlyteSchema[kwtypes(some_str=str)] - @dataclass_json @dataclass - class InnerResult: + class InnerResult(DataClassJsonMixin): number: int schema: TestSchema - @dataclass_json @dataclass - class Result: + class Result(DataClassJsonMixin): result: InnerResult schema: TestSchema @@ -1514,16 +1503,14 @@ def t2() -> dict: def test_guess_dict4(): - @dataclass_json @dataclass - class Foo(object): + class Foo(DataClassJsonMixin): x: int y: str z: typing.Dict[str, str] - @dataclass_json @dataclass - class Bar(object): + class Bar(DataClassJsonMixin): x: int y: dict z: Foo @@ -1863,3 +1850,58 @@ def wf() -> pandas.DataFrame: expected_df = pandas.DataFrame({"column_1": [5, 7, 9]}) assert expected_df.equals(df) + + +def test_ref_as_key_name(): + class MyOutput(typing.NamedTuple): + # to make sure flytekit itself doesn't use this string + ref: str + + @task + def produce_things() -> MyOutput: + return MyOutput(ref="ref") + + @workflow + def run_things() -> MyOutput: + return produce_things() + + assert run_things().ref == "ref" + + +def test_promise_not_allowed_in_overrides(): + @task + def t1(a: int) -> int: + return a + 1 + + @workflow + def my_wf(a: int, cpu: str) -> int: + return 
t1(a=a).with_overrides(requests=Resources(cpu=cpu)) + + with pytest.raises(AssertionError): + my_wf(a=1, cpu=1) + + +def test_promise_illegal_resources(): + @task + def t1(a: int) -> int: + return a + 1 + + @workflow + def my_wf(a: int) -> int: + return t1(a=a).with_overrides(requests=Resources(cpu=1)) # type: ignore + + with pytest.raises(AssertionError): + my_wf(a=1) + + +def test_promise_illegal_retries(): + @task + def t1(a: int) -> int: + return a + 1 + + @workflow + def my_wf(a: int, retries: int) -> int: + return t1(a=a).with_overrides(retries=retries) + + with pytest.raises(AssertionError): + my_wf(a=1, retries=1) diff --git a/tests/flytekit/unit/experimental/__init__.py b/tests/flytekit/unit/experimental/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/experimental/test_eager_workflows.py b/tests/flytekit/unit/experimental/test_eager_workflows.py new file mode 100644 index 0000000000..50ed29063c --- /dev/null +++ b/tests/flytekit/unit/experimental/test_eager_workflows.py @@ -0,0 +1,273 @@ +import asyncio +import os +import typing +from pathlib import Path + +import hypothesis.strategies as st +import pandas as pd +import pytest +from hypothesis import given, settings + +from flytekit import dynamic, task, workflow +from flytekit.core.type_engine import TypeTransformerFailedError +from flytekit.experimental import EagerException, eager +from flytekit.types.directory import FlyteDirectory +from flytekit.types.file import FlyteFile +from flytekit.types.structured import StructuredDataset + +DEADLINE = 2000 +INTEGER_ST = st.integers(min_value=-10_000_000, max_value=10_000_000) + + +@task +def add_one(x: int) -> int: + return x + 1 + + +@task +def double(x: int) -> int: + return x * 2 + + +@task +def gt_0(x: int) -> bool: + return x > 0 + + +@task +def raises_exc(x: int) -> int: + if x == 0: + raise TypeError + return x + + +@dynamic +def dynamic_wf(x: int) -> int: + out = add_one(x=x) + return double(x=out) + + +@given(x_input=INTEGER_ST) +@settings(deadline=DEADLINE, max_examples=5) +def test_simple_eager_workflow(x_input: int): + """Testing simple eager workflow with just tasks.""" + + @eager + async def eager_wf(x: int) -> int: + out = await add_one(x=x) + return await double(x=out) + + result = asyncio.run(eager_wf(x=x_input)) + assert result == (x_input + 1) * 2 + + +@given(x_input=INTEGER_ST) +@settings(deadline=DEADLINE, max_examples=5) +def test_conditional_eager_workflow(x_input: int): + """Test eager workflow with conditional logic.""" + + @eager + async def eager_wf(x: int) -> int: + if await gt_0(x=x): + return -1 + return 1 + + result = asyncio.run(eager_wf(x=x_input)) + if x_input > 0: + assert result == -1 + else: + assert result == 1 + + +@given(x_input=INTEGER_ST) +@settings(deadline=DEADLINE, max_examples=5) +def test_try_except_eager_workflow(x_input: int): + """Test eager workflow with try/except logic.""" + + @eager + async def eager_wf(x: int) -> int: + try: + return await raises_exc(x=x) + except EagerException: + return -1 + + result = asyncio.run(eager_wf(x=x_input)) + if x_input == 0: + assert result == -1 + else: + assert result == x_input + + +@given(x_input=INTEGER_ST, n_input=st.integers(min_value=1, max_value=20)) +@settings(deadline=DEADLINE, max_examples=5) +def test_gather_eager_workflow(x_input: int, n_input: int): + """Test eager workflow with asyncio gather.""" + + @eager + async def eager_wf(x: int, n: int) -> typing.List[int]: + results = await asyncio.gather(*[add_one(x=x) for _ in range(n)]) + return
results + + results = asyncio.run(eager_wf(x=x_input, n=n_input)) + assert results == [x_input + 1 for _ in range(n_input)] + + +@given(x_input=INTEGER_ST) +@settings(deadline=DEADLINE, max_examples=5) +def test_eager_workflow_with_dynamic_exception(x_input: int): + """Test eager workflow with dynamic workflow is not supported.""" + + @eager + async def eager_wf(x: int) -> typing.List[int]: + return await dynamic_wf(x=x) + + with pytest.raises(EagerException, match="Eager workflows currently do not work with dynamic workflows"): + asyncio.run(eager_wf(x=x_input)) + + +@eager +async def nested_eager_wf(x: int) -> int: + return await add_one(x=x) + + +@given(x_input=INTEGER_ST) +@settings(deadline=DEADLINE, max_examples=5) +def test_nested_eager_workflow(x_input: int): + """Testing running nested eager workflows.""" + + @eager + async def eager_wf(x: int) -> int: + out = await nested_eager_wf(x=x) + return await double(x=out) + + result = asyncio.run(eager_wf(x=x_input)) + assert result == (x_input + 1) * 2 + + +@given(x_input=INTEGER_ST) +@settings(deadline=DEADLINE, max_examples=5) +def test_eager_workflow_within_workflow(x_input: int): + """Testing running eager workflow within a static workflow.""" + + @eager + async def eager_wf(x: int) -> int: + return await add_one(x=x) + + @workflow + def wf(x: int) -> int: + out = eager_wf(x=x) + return double(x=out) + + result = wf(x=x_input) + assert result == (x_input + 1) * 2 + + +@workflow +def subworkflow(x: int) -> int: + return add_one(x=x) + + +@given(x_input=INTEGER_ST) +@settings(deadline=DEADLINE, max_examples=5) +def test_workflow_within_eager_workflow(x_input: int): + """Testing running a static workflow within an eager workflow.""" + + @eager + async def eager_wf(x: int) -> int: + out = await subworkflow(x=x) + return await double(x=out) + + result = asyncio.run(eager_wf(x=x_input)) + assert result == (x_input + 1) * 2 + + +@given(x_input=INTEGER_ST) +@settings(deadline=DEADLINE, max_examples=5) +def test_local_task_eager_workflow_exception(x_input: int): + """Testing simple eager workflow with a local function task doesn't work.""" + + @task + def local_task(x: int) -> int: + return x + + @eager + async def eager_wf_with_local(x: int) -> int: + return await local_task(x=x) + + with pytest.raises(TypeError): + asyncio.run(eager_wf_with_local(x=x_input)) + + +@given(x_input=INTEGER_ST) +@settings(deadline=DEADLINE, max_examples=5) +@pytest.mark.filterwarnings("ignore:coroutine 'AsyncEntity.__call__' was never awaited") +def test_local_workflow_within_eager_workflow_exception(x_input: int): + """Cannot call a locally-defined workflow within an eager workflow""" + + @workflow + def local_wf(x: int) -> int: + return add_one(x=x) + + @eager + async def eager_wf(x: int) -> int: + out = await local_wf(x=x) + return await double(x=out) + + with pytest.raises(TypeTransformerFailedError): + asyncio.run(eager_wf(x=x_input)) + + +@task +def create_structured_dataset() -> StructuredDataset: + df = pd.DataFrame({"a": [1, 2, 3]}) + return StructuredDataset(dataframe=df) + + +@task +def create_file() -> FlyteFile: + fname = "/tmp/flytekit_test_file" + with open(fname, "w") as fh: + fh.write("some data\n") + return FlyteFile(path=fname) + + +@task +def create_directory() -> FlyteDirectory: + dirname = "/tmp/flytekit_test_dir" + Path(dirname).mkdir(exist_ok=True, parents=True) + with open(os.path.join(dirname, "file"), "w") as tmp: + tmp.write("some data\n") + return FlyteDirectory(path=dirname) + + +def test_eager_workflow_with_offloaded_types(): + 
"""Test eager workflow that eager workflows work with offloaded types.""" + + @eager + async def eager_wf_structured_dataset() -> int: + dataset = await create_structured_dataset() + df = dataset.open(pd.DataFrame).all() + return df["a"].sum() + + @eager + async def eager_wf_flyte_file() -> str: + file = await create_file() + with open(file.path) as f: + data = f.read().strip() + return data + + @eager + async def eager_wf_flyte_directory() -> str: + directory = await create_directory() + with open(os.path.join(directory.path, "file")) as f: + data = f.read().strip() + return data + + result = asyncio.run(eager_wf_structured_dataset()) + assert result == 6 + + result = asyncio.run(eager_wf_flyte_file()) + assert result == "some data" + + result = asyncio.run(eager_wf_flyte_directory()) + assert result == "some data" diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index f88ae96987..e9555b2026 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -1,13 +1,15 @@ +import asyncio import json import typing from dataclasses import asdict, dataclass from datetime import timedelta -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import grpc import pytest from flyteidl.admin.agent_pb2 import ( PERMANENT_FAILURE, + RETRYABLE_FAILURE, RUNNING, SUCCEEDED, CreateTaskRequest, @@ -21,14 +23,22 @@ import flytekit.models.interface as interface_models from flytekit import PythonFunctionTask -from flytekit.extend.backend.agent_service import AgentService -from flytekit.extend.backend.base_agent import AgentBase, AgentRegistry, AsyncAgentExecutorMixin, is_terminal_state +from flytekit.extend.backend.agent_service import AsyncAgentService +from flytekit.extend.backend.base_agent import ( + AgentBase, + AgentRegistry, + AsyncAgentExecutorMixin, + convert_to_flyte_state, + get_agent_secret, + is_terminal_state, +) from flytekit.models import literals, task, types from flytekit.models.core.identifier import Identifier, ResourceType from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskTemplate dummy_id = "dummy_id" +loop = asyncio.get_event_loop() @dataclass @@ -38,7 +48,7 @@ class Metadata: class DummyAgent(AgentBase): def __init__(self): - super().__init__(task_type="dummy") + super().__init__(task_type="dummy", asynchronous=False) def create( self, @@ -56,7 +66,28 @@ def delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteT return DeleteTaskResponse() +class AsyncDummyAgent(AgentBase): + def __init__(self): + super().__init__(task_type="async_dummy", asynchronous=True) + + async def async_create( + self, + context: grpc.ServicerContext, + output_prefix: str, + task_template: TaskTemplate, + inputs: typing.Optional[LiteralMap] = None, + ) -> CreateTaskResponse: + return CreateTaskResponse(resource_meta=json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8")) + + async def async_get(self, context: grpc.ServicerContext, resource_meta: bytes) -> GetTaskResponse: + return GetTaskResponse(resource=Resource(state=SUCCEEDED)) + + async def async_delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteTaskResponse: + return DeleteTaskResponse() + + AgentRegistry.register(DummyAgent()) +AgentRegistry.register(AsyncDummyAgent()) task_id = Identifier(resource_type=ResourceType.TASK, project="project", domain="domain", name="t1", version="version") task_metadata = task.TaskMetadata( @@ -92,10 +123,18 @@ def 
delete(self, context: grpc.ServicerContext, resource_meta: bytes) -> DeleteT custom={}, ) +async_dummy_template = TaskTemplate( + id=task_id, + metadata=task_metadata, + interface=interfaces, + type="async_dummy", + custom={}, +) + def test_dummy_agent(): ctx = MagicMock(spec=grpc.ServicerContext) - agent = AgentRegistry.get_agent(ctx, "dummy") + agent = AgentRegistry.get_agent("dummy") metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") assert agent.create(ctx, "/tmp", dummy_template, task_inputs).resource_meta == metadata_bytes assert agent.get(ctx, metadata_bytes).resource.state == SUCCEEDED @@ -112,30 +151,56 @@ def __init__(self, **kwargs): t.execute() t._task_type = "non-exist-type" - with pytest.raises(Exception, match="Cannot run the task locally"): + with pytest.raises(Exception, match="Cannot find agent for task type: non-exist-type."): t.execute() -def test_agent_server(): - service = AgentService() +@pytest.mark.asyncio +async def test_async_dummy_agent(): + ctx = MagicMock(spec=grpc.ServicerContext) + agent = AgentRegistry.get_agent("async_dummy") + metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") + res = await agent.async_create(ctx, "/tmp", async_dummy_template, task_inputs) + assert res.resource_meta == metadata_bytes + res = await agent.async_get(ctx, metadata_bytes) + assert res.resource.state == SUCCEEDED + res = await agent.async_delete(ctx, metadata_bytes) + assert res == DeleteTaskResponse() + + +@pytest.mark.asyncio +async def run_agent_server(): + service = AsyncAgentService() ctx = MagicMock(spec=grpc.ServicerContext) request = CreateTaskRequest( inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=dummy_template.to_flyte_idl() ) - - metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") - assert service.CreateTask(request, ctx).resource_meta == metadata_bytes - assert ( - service.GetTask(GetTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx).resource.state - == SUCCEEDED - ) - assert ( - service.DeleteTask(DeleteTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx) - == DeleteTaskResponse() + async_request = CreateTaskRequest( + inputs=task_inputs.to_flyte_idl(), output_prefix="/tmp", template=async_dummy_template.to_flyte_idl() ) + fake_agent = "fake" + metadata_bytes = json.dumps(asdict(Metadata(job_id=dummy_id))).encode("utf-8") + + res = await service.CreateTask(request, ctx) + assert res.resource_meta == metadata_bytes + res = await service.GetTask(GetTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx) + assert res.resource.state == SUCCEEDED + res = await service.DeleteTask(DeleteTaskRequest(task_type="dummy", resource_meta=metadata_bytes), ctx) + assert isinstance(res, DeleteTaskResponse) + + res = await service.CreateTask(async_request, ctx) + assert res.resource_meta == metadata_bytes + res = await service.GetTask(GetTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx) + assert res.resource.state == SUCCEEDED + res = await service.DeleteTask(DeleteTaskRequest(task_type="async_dummy", resource_meta=metadata_bytes), ctx) + assert isinstance(res, DeleteTaskResponse) - res = service.GetTask(GetTaskRequest(task_type="fake", resource_meta=metadata_bytes), ctx) - assert res.resource.state == PERMANENT_FAILURE + res = await service.GetTask(GetTaskRequest(task_type=fake_agent, resource_meta=metadata_bytes), ctx) + assert res is None + + +def test_agent_server(): + loop.run_in_executor(None, run_agent_server) 
def test_is_terminal_state(): @@ -143,3 +208,25 @@ def test_is_terminal_state(): assert is_terminal_state(PERMANENT_FAILURE) assert is_terminal_state(PERMANENT_FAILURE) assert not is_terminal_state(RUNNING) + + +def test_convert_to_flyte_state(): + assert convert_to_flyte_state("FAILED") == RETRYABLE_FAILURE + assert convert_to_flyte_state("TIMEDOUT") == RETRYABLE_FAILURE + assert convert_to_flyte_state("CANCELED") == RETRYABLE_FAILURE + + assert convert_to_flyte_state("DONE") == SUCCEEDED + assert convert_to_flyte_state("SUCCEEDED") == SUCCEEDED + assert convert_to_flyte_state("SUCCESS") == SUCCEEDED + + assert convert_to_flyte_state("RUNNING") == RUNNING + + invalid_state = "INVALID_STATE" + with pytest.raises(Exception, match=f"Unrecognized state: {invalid_state.lower()}"): + convert_to_flyte_state(invalid_state) + + +@patch("flytekit.current_context") +def test_get_agent_secret(mocked_context): + mocked_context.return_value.secrets.get.return_value = "mocked token" + assert get_agent_secret("mocked key") == "mocked token" diff --git a/tests/flytekit/unit/extras/pytorch/test_checkpoint.py b/tests/flytekit/unit/extras/pytorch/test_checkpoint.py index 49ad083285..10721bf237 100644 --- a/tests/flytekit/unit/extras/pytorch/test_checkpoint.py +++ b/tests/flytekit/unit/extras/pytorch/test_checkpoint.py @@ -5,16 +5,15 @@ import torch.nn as nn import torch.nn.functional as F import torch.optim as optim -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin from flytekit import task, workflow from flytekit.core.type_engine import TypeTransformerFailedError from flytekit.extras.pytorch import PyTorchCheckpoint -@dataclass_json @dataclass -class Hyperparameters: +class Hyperparameters(DataClassJsonMixin): epochs: int loss: float diff --git a/tests/flytekit/unit/extras/sklearn/test_transformations.py b/tests/flytekit/unit/extras/sklearn/test_transformations.py index 39343f9180..9df16abaaf 100644 --- a/tests/flytekit/unit/extras/sklearn/test_transformations.py +++ b/tests/flytekit/unit/extras/sklearn/test_transformations.py @@ -34,6 +34,9 @@ def get_model(model_type: str) -> BaseEstimator: } x = np.random.normal(size=(10, 2)) y = np.random.randint(2, size=(10,)) + while len(set(y)) < 2: + y = np.random.randint(2, size=(10,)) + model = models_map[model_type]() model.fit(x, y) return model diff --git a/tests/flytekit/unit/extras/tasks/test_shell.py b/tests/flytekit/unit/extras/tasks/test_shell.py index 580cec4394..e70515ec73 100644 --- a/tests/flytekit/unit/extras/tasks/test_shell.py +++ b/tests/flytekit/unit/extras/tasks/test_shell.py @@ -5,7 +5,7 @@ from dataclasses import dataclass import pytest -from dataclasses_json import dataclass_json +from dataclasses_json import DataClassJsonMixin import flytekit from flytekit import kwtypes @@ -215,9 +215,8 @@ def test_reuse_variables_for_both_inputs_and_outputs(): def test_can_use_complex_types_for_inputs_to_f_string_template(): - @dataclass_json @dataclass - class InputArgs: + class InputArgs(DataClassJsonMixin): in_file: CSVFile t = ShellTask( diff --git a/tests/flytekit/unit/extras/tensorflow/model/test_transformations.py b/tests/flytekit/unit/extras/tensorflow/model/test_transformations.py index 392ab695c5..9d7aa12737 100644 --- a/tests/flytekit/unit/extras/tensorflow/model/test_transformations.py +++ b/tests/flytekit/unit/extras/tensorflow/model/test_transformations.py @@ -24,7 +24,7 @@ ) -def get_tf_model(): +def get_tf_model() -> tf.keras.Model: inputs = tf.keras.Input(shape=(32,)) outputs = 
tf.keras.layers.Dense(1)(inputs) tf_model = tf.keras.Model(inputs, outputs) diff --git a/tests/flytekit/unit/interaction/__init__.py b/tests/flytekit/unit/interaction/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/interaction/test_click_types.py b/tests/flytekit/unit/interaction/test_click_types.py new file mode 100644 index 0000000000..323e7b5c44 --- /dev/null +++ b/tests/flytekit/unit/interaction/test_click_types.py @@ -0,0 +1,159 @@ +import functools +import json +import tempfile +import typing +from datetime import datetime, timedelta +from enum import Enum + +import click +import mock +import pytest +import yaml + +from flytekit import FlyteContextManager +from flytekit.configuration import Config +from flytekit.core.type_engine import TypeEngine +from flytekit.interaction.click_types import ( + DateTimeType, + DurationParamType, + FileParamType, + FlyteLiteralConverter, + JsonParamType, + key_value_callback, +) +from flytekit.models.types import SimpleType +from flytekit.remote import FlyteRemote + + +def test_file_param(): + m = mock.MagicMock() + l = FileParamType().convert(__file__, m, m) + assert l.local + r = FileParamType().convert("https://tmp/file", m, m) + assert r.local is False + + +class Color(Enum): + RED = "red" + GREEN = "green" + BLUE = "blue" + + +@pytest.mark.parametrize( + "python_type, python_value", + [ + (typing.Union[typing.List[int], str, Color], "flyte"), + (typing.Union[typing.List[int], str, Color], "red"), + (typing.Union[typing.List[int], str, Color], [1, 2, 3]), + (typing.List[int], [1, 2, 3]), + (typing.Dict[str, int], {"flyte": 2}), + ], +) +def test_literal_converter(python_type, python_value): + get_upload_url_fn = functools.partial( + FlyteRemote(Config.auto()).client.get_upload_signed_url, project="p", domain="d" + ) + ctx = FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(python_type) + + lc = FlyteLiteralConverter( + ctx, + literal_type=lt, + python_type=python_type, + get_upload_url_fn=get_upload_url_fn, + is_remote=True, + ) + + click_ctx = click.Context(click.Command("test_command"), obj={"remote": True}) + assert lc.convert(click_ctx, ctx, python_value) == TypeEngine.to_literal(ctx, python_value, python_type, lt) + + +def test_enum_converter(): + pt = typing.Union[str, Color] + + get_upload_url_fn = functools.partial(FlyteRemote(Config.auto()).client.get_upload_signed_url) + click_ctx = click.Context(click.Command("test_command"), obj={"remote": True}) + ctx = FlyteContextManager.current_context() + lt = TypeEngine.to_literal_type(pt) + lc = FlyteLiteralConverter( + ctx, literal_type=lt, python_type=pt, get_upload_url_fn=get_upload_url_fn, is_remote=True + ) + union_lt = lc.convert(click_ctx, ctx, "red").scalar.union + + assert union_lt.stored_type.simple == SimpleType.STRING + assert union_lt.stored_type.enum_type is None + + pt = typing.Union[Color, str] + lt = TypeEngine.to_literal_type(typing.Union[Color, str]) + lc = FlyteLiteralConverter( + ctx, literal_type=lt, python_type=pt, get_upload_url_fn=get_upload_url_fn, is_remote=True + ) + union_lt = lc.convert(click_ctx, ctx, "red").scalar.union + + assert union_lt.stored_type.simple is None + assert union_lt.stored_type.enum_type.values == ["red", "green", "blue"] + + +def test_duration_type(): + t = DurationParamType() + assert t.convert(value="1 day", param=None, ctx=None) == timedelta(days=1) + + with pytest.raises(click.BadParameter): + t.convert(None, None, None) + + +def test_datetime_type(): + t = DateTimeType() 
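For reference, a minimal sketch of how these parameter types plug into a click command (illustrative only; it assumes they behave as ordinary click ParamTypes, which is what convert() is exercised for here).

import click

from flytekit.interaction.click_types import DurationParamType, JsonParamType


@click.command()
@click.option("--timeout", type=DurationParamType(), default="1 day")
@click.option("--config", type=JsonParamType(), default='{"a": "b"}')
def run(timeout, config):
    # click routes each raw value through convert(), so timeout arrives as a
    # datetime.timedelta and config as a parsed dict
    click.echo(f"{timeout} {config}")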
+ + assert t.convert("2020-01-01", None, None) == datetime(2020, 1, 1) + + now = datetime.now() + v = t.convert("now", None, None) + assert v.day == now.day + assert v.month == now.month + + +def test_json_type(): + t = JsonParamType() + assert t.convert(value='{"a": "b"}', param=None, ctx=None) == {"a": "b"} + + with pytest.raises(click.BadParameter): + t.convert(None, None, None) + + # test that it loads a json file + with tempfile.NamedTemporaryFile("w", delete=False) as f: + json.dump({"a": "b"}, f) + f.flush() + assert t.convert(value=f.name, param=None, ctx=None) == {"a": "b"} + + # test that if the file is not a valid json, it raises an error + with tempfile.NamedTemporaryFile("w", delete=False) as f: + f.write("asdf") + f.flush() + with pytest.raises(click.BadParameter): + t.convert(value=f.name, param="asdf", ctx=None) + + # test if the file does not exist + with pytest.raises(click.BadParameter): + t.convert(value="asdf", param=None, ctx=None) + + # test if the file is yaml and ends with .yaml it works correctly + with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f: + yaml.dump({"a": "b"}, f) + f.flush() + assert t.convert(value=f.name, param=None, ctx=None) == {"a": "b"} + + +def test_key_value_callback(): + """Write a test that verifies that the callback works correctly.""" + ctx = click.Context(click.Command("test_command"), obj={"remote": True}) + assert key_value_callback(ctx, "a", None) is None + assert key_value_callback(ctx, "a", ["a=b"]) == {"a": "b"} + assert key_value_callback(ctx, "a", ["a=b", "c=d"]) == {"a": "b", "c": "d"} + assert key_value_callback(ctx, "a", ["a=b", "c=d", "e=f"]) == {"a": "b", "c": "d", "e": "f"} + with pytest.raises(click.BadParameter): + key_value_callback(ctx, "a", ["a=b", "c"]) + with pytest.raises(click.BadParameter): + key_value_callback(ctx, "a", ["a=b", "c=d", "e"]) + with pytest.raises(click.BadParameter): + key_value_callback(ctx, "a", ["a=b", "c=d", "e=f", "g"]) diff --git a/tests/flytekit/unit/models/core/test_workflow.py b/tests/flytekit/unit/models/core/test_workflow.py index de83f66f78..6775d58940 100644 --- a/tests/flytekit/unit/models/core/test_workflow.py +++ b/tests/flytekit/unit/models/core/test_workflow.py @@ -228,6 +228,73 @@ def test_branch_node(): assert bn.if_else.case.then_node == obj +def test_branch_node_with_none(): + nm = _get_sample_node_metadata() + task = _workflow.TaskNode(reference_id=_generic_id) + bd = _literals.BindingData(scalar=_literals.Scalar(none_type=_literals.Void())) + lt = _literals.Literal(scalar=_literals.Scalar(primitive=_literals.Primitive(integer=99))) + bd2 = _literals.BindingData( + scalar=_literals.Scalar( + union=_literals.Union(value=lt, stored_type=_types.LiteralType(_types.SimpleType.INTEGER)) + ) + ) + binding = _literals.Binding(var="myvar", binding=bd) + binding2 = _literals.Binding(var="myothervar", binding=bd2) + + obj = _workflow.Node( + id="some:node:id", + metadata=nm, + inputs=[binding, binding2], + upstream_node_ids=[], + output_aliases=[], + task_node=task, + ) + + bn = _workflow.BranchNode( + _workflow.IfElseBlock( + case=_workflow.IfBlock( + condition=_condition.BooleanExpression( + comparison=_condition.ComparisonExpression( + _condition.ComparisonExpression.Operator.EQ, + _condition.Operand(scalar=_literals.Scalar(none_type=_literals.Void())), + _condition.Operand(primitive=_literals.Primitive(integer=2)), + ) + ), + then_node=obj, + ), + other=[ + _workflow.IfBlock( + condition=_condition.BooleanExpression( + 
conjunction=_condition.ConjunctionExpression( + _condition.ConjunctionExpression.LogicalOperator.AND, + _condition.BooleanExpression( + comparison=_condition.ComparisonExpression( + _condition.ComparisonExpression.Operator.EQ, + _condition.Operand(scalar=_literals.Scalar(none_type=_literals.Void())), + _condition.Operand(primitive=_literals.Primitive(integer=2)), + ) + ), + _condition.BooleanExpression( + comparison=_condition.ComparisonExpression( + _condition.ComparisonExpression.Operator.EQ, + _condition.Operand(scalar=_literals.Scalar(none_type=_literals.Void())), + _condition.Operand(primitive=_literals.Primitive(integer=2)), + ) + ), + ) + ), + then_node=obj, + ) + ], + else_node=obj, + ) + ) + + bn2 = _workflow.BranchNode.from_flyte_idl(bn.to_flyte_idl()) + assert bn == bn2 + assert bn.if_else.case.then_node == obj + + def test_task_node_overrides(): overrides = _workflow.TaskNodeOverrides( Resources( diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 94b03b044a..54b6627b8a 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -2,6 +2,7 @@ import pathlib import tempfile import typing +import uuid from collections import OrderedDict from datetime import datetime, timedelta @@ -9,11 +10,11 @@ import pytest from flyteidl.core import compiler_pb2 as _compiler_pb2 from flyteidl.service import dataproxy_pb2 -from mock import MagicMock, patch +from mock import ANY, MagicMock, patch import flytekit.configuration -from flytekit import CronSchedule, LaunchPlan, task, workflow -from flytekit.configuration import Config, DefaultImages, ImageConfig +from flytekit import CronSchedule, LaunchPlan, WorkflowFailurePolicy, task, workflow +from flytekit.configuration import Config, DefaultImages, Image, ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import FlyteContextManager from flytekit.core.type_engine import TypeEngine @@ -25,6 +26,7 @@ from flytekit.models.core.identifier import Identifier, ResourceType, WorkflowExecutionIdentifier from flytekit.models.execution import Execution from flytekit.models.task import Task +from flytekit.remote import FlyteTask from flytekit.remote.lazy_entity import LazyEntity from flytekit.remote.remote import FlyteRemote from flytekit.tools.translator import Options, get_serializable, get_serializable_launch_plan @@ -203,6 +205,18 @@ def test_more_stuff(mock_client): assert computed_v2 != computed_v3 +def test_get_extra_headers_azure_blob_storage(): + native_url = "abfs://flyte@storageaccount/container/path/to/file" + headers = FlyteRemote.get_extra_headers_for_protocol(native_url) + assert headers["x-ms-blob-type"] == "BlockBlob" + + +def test_get_extra_headers_s3(): + native_url = "s3://flyte@storageaccount/container/path/to/file" + headers = FlyteRemote.get_extra_headers_for_protocol(native_url) + assert headers == {} + + @patch("flytekit.remote.remote.SynchronousFlyteClient") def test_generate_console_http_domain_sandbox_rewrite(mock_client): _, temp_filename = tempfile.mkstemp(suffix=".yaml") @@ -341,8 +355,18 @@ def test_launch_backfill(remote): ), ) - wf = remote.launch_backfill("p", "d", start_date, end_date, "daily2", "v1", dry_run=True) + wf = remote.launch_backfill( + "p", + "d", + start_date, + end_date, + "daily2", + "v1", + dry_run=True, + failure_policy=WorkflowFailurePolicy.FAIL_IMMEDIATELY, + ) assert wf + assert wf.workflow_metadata.on_failure == 
WorkflowFailurePolicy.FAIL_IMMEDIATELY @mock.patch("flytekit.remote.remote.FlyteRemote.client") @@ -361,3 +385,53 @@ def test_local_server(mock_client): ) lr = rr.get("flyte://v1/flytesnacks/development/f6988c7bdad554a4da7a/n0/o") assert lr.get("hello", int) == 55 + + +@mock.patch("flytekit.remote.remote.uuid") +@mock.patch("flytekit.remote.remote.FlyteRemote.client") +def test_execution_name(mock_client, mock_uuid): + test_uuid = uuid.UUID("16fd2706-8baf-433b-82eb-8c7fada847da") + mock_uuid.uuid4.return_value = test_uuid + remote = FlyteRemote(config=Config.auto(), default_project="project", default_domain="domain") + + default_img = Image(name="default", fqn="test", tag="tag") + serialization_settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + tk_spec = get_serializable(OrderedDict(), serialization_settings, tk) + ft = FlyteTask.promote_from_model(tk_spec.template) + + remote._execute( + entity=ft, + inputs={"t": datetime.now(), "v": 0}, + execution_name="execution-test", + ) + remote._execute( + entity=ft, + inputs={"t": datetime.now(), "v": 0}, + execution_name_prefix="execution-test", + ) + remote._execute( + entity=ft, + inputs={"t": datetime.now(), "v": 0}, + ) + mock_client.create_execution.assert_has_calls( + [ + mock.call(ANY, ANY, "execution-test", ANY, ANY), + mock.call(ANY, ANY, "execution-test-" + test_uuid.hex[:19], ANY, ANY), + mock.call(ANY, ANY, "f" + test_uuid.hex[:19], ANY, ANY), + ] + ) + with pytest.raises( + ValueError, match="Only one of execution_name and execution_name_prefix can be set, but got both set" + ): + remote._execute( + entity=ft, + inputs={"t": datetime.now(), "v": 0}, + execution_name="execution-test", + execution_name_prefix="execution-test", + ) diff --git a/tests/flytekit/unit/sensor/__init__.py b/tests/flytekit/unit/sensor/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/flytekit/unit/sensor/test_file_sensor.py b/tests/flytekit/unit/sensor/test_file_sensor.py new file mode 100644 index 0000000000..f6a50836be --- /dev/null +++ b/tests/flytekit/unit/sensor/test_file_sensor.py @@ -0,0 +1,31 @@ +import tempfile + +from flytekit import task, workflow +from flytekit.configuration import ImageConfig, SerializationSettings +from flytekit.sensor.file_sensor import FileSensor +from tests.flytekit.unit.test_translator import default_img + + +def test_sensor_task(): + sensor = FileSensor(name="test_sensor") + assert sensor.task_type == "sensor" + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={"FOO": "baz"}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + assert sensor.get_custom(settings) == {"sensor_module": "flytekit.sensor.file_sensor", "sensor_name": "FileSensor"} + tmp_file = tempfile.NamedTemporaryFile() + + @task() + def t1(): + print("flyte") + + @workflow + def wf(): + sensor(tmp_file.name) >> t1() + + if __name__ == "__main__": + wf()
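Reviewer notes (not part of the patch). The key_value_callback cases in test_click_types.py exercise the callback directly; a minimal sketch of how it would typically be attached to a click option is below. The command, the --env option name, and the help text are illustrative assumptions, not something this diff adds.

    # Hypothetical wiring of key_value_callback into a click option; only the
    # callback itself comes from this PR, the rest is illustrative.
    import click

    from flytekit.interaction.click_types import key_value_callback


    @click.command()
    @click.option(
        "--env",
        multiple=True,
        callback=key_value_callback,
        help="Repeated KEY=VALUE pairs collected into a single dict.",
    )
    def run(env):
        # With no --env given, env stays falsy; otherwise it is a dict such as
        # {"FOO": "bar"}, and malformed entries like "FOO" raise click.BadParameter.
        click.echo(env)


    if __name__ == "__main__":
        run()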
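The launch_backfill change is easiest to read from the caller's side. The following sketch mirrors the updated test and assumes a reachable Flyte backend that knows a "daily2" launch plan; nothing beyond the test's own arguments is new.

    # Sketch of launch_backfill with the failure_policy argument added in this PR.
    # Requires a configured Flyte backend with a "daily2" launch plan registered.
    from datetime import datetime

    from flytekit import WorkflowFailurePolicy
    from flytekit.configuration import Config
    from flytekit.remote import FlyteRemote

    remote = FlyteRemote(config=Config.auto(), default_project="p", default_domain="d")
    bf_wf = remote.launch_backfill(
        "p",
        "d",
        datetime(2022, 12, 1),
        datetime(2022, 12, 5),
        "daily2",
        "v1",
        dry_run=True,
        failure_policy=WorkflowFailurePolicy.FAIL_IMMEDIATELY,
    )
    # dry_run=True returns the generated backfill workflow without registering or
    # launching it; the policy is carried on its metadata, as asserted in the test.
    assert bf_wf.workflow_metadata.on_failure == WorkflowFailurePolicy.FAIL_IMMEDIATELY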
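Finally, test_file_sensor.py only asserts serialization of the new FileSensor. A user-facing sketch of the pattern it compiles follows; the sensor name, downstream task, and workflow signature are illustrative, while the FileSensor import and the >> chaining come from the test.

    # Illustrative usage of the FileSensor introduced in this PR: the sensor node
    # gates the downstream task until the given file exists.
    from flytekit import task, workflow
    from flytekit.sensor.file_sensor import FileSensor

    sensor = FileSensor(name="wait_for_marker_file")


    @task
    def process() -> str:
        return "marker file detected"


    @workflow
    def wf(path: str) -> None:
        sensor(path) >> process()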