From fd98e3a7be56c95d0ca64f4d59983df5b3812cee Mon Sep 17 00:00:00 2001 From: Len Strnad Date: Thu, 20 Mar 2025 15:59:51 -0600 Subject: [PATCH 01/12] init xarray plugin Signed-off-by: Len Strnad --- plugins/flytekit-xarray-zarr/README.md | 10 ++ .../flytekitplugins/xarray_zarr/__init__.py | 14 +++ .../xarray_zarr/xarray_transformers.py | 100 ++++++++++++++++++ plugins/flytekit-xarray-zarr/setup.py | 35 ++++++ .../tests/test_xarray_zarr_plugin.py | 48 +++++++++ 5 files changed, 207 insertions(+) create mode 100644 plugins/flytekit-xarray-zarr/README.md create mode 100644 plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/__init__.py create mode 100644 plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/xarray_transformers.py create mode 100644 plugins/flytekit-xarray-zarr/setup.py create mode 100644 plugins/flytekit-xarray-zarr/tests/test_xarray_zarr_plugin.py diff --git a/plugins/flytekit-xarray-zarr/README.md b/plugins/flytekit-xarray-zarr/README.md new file mode 100644 index 0000000000..b09ccf7b24 --- /dev/null +++ b/plugins/flytekit-xarray-zarr/README.md @@ -0,0 +1,10 @@ +# Flytekit Xarray Zarr Plugin +The Xarray Zarr plugin adds support to persist xarray datasets and dataarrays to zarr between tasks. If a dask cluster is present (see flytekitplugins-dask), it will attempt to connect to the distributed client. This prevents the need to explicitly connect to a distributed client within the task. + +If deck is enabled, we also render the datasets/dataarrays to html. + +To install the plugin, run the following command: + +```bash +pip install flytekitplugins-xarray-zarr +``` diff --git a/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/__init__.py b/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/__init__.py new file mode 100644 index 0000000000..b1bab7445f --- /dev/null +++ b/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/__init__.py @@ -0,0 +1,14 @@ +""" +.. currentmodule:: flytekitplugins.geopandas + +This package contains things that are useful when extending Flytekit. + +.. autosummary:: + :template: custom.rst + :toctree: generated/ + + GeoPandasDecodingHandler + GeoPandasEncodingHandler +""" + +from .xarray_transformers import XarrayDaZarrTypeTransformer, XarrayZarrTypeTransformer diff --git a/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/xarray_transformers.py b/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/xarray_transformers.py new file mode 100644 index 0000000000..ad31062741 --- /dev/null +++ b/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/xarray_transformers.py @@ -0,0 +1,100 @@ +import typing + +import xarray as xr +from dask.distributed import Client + +from flytekit import ( + Blob, + BlobMetadata, + BlobType, + FlyteContext, + Literal, + LiteralType, + Scalar, +) +from flytekit.extend import TypeEngine, TypeTransformer + + +class XarrayZarrTypeTransformer(TypeTransformer[xr.Dataset]): + _TYPE_INFO = BlobType(format="binary", dimensionality=BlobType.BlobDimensionality.MULTIPART) + + def __init__(self) -> None: + super().__init__(name="xarray-zarr", t=xr.Dataset) + + def get_literal_type(self, t: typing.Type[xr.Dataset]) -> LiteralType: + return LiteralType(blob=self._TYPE_INFO) + + def to_literal( + self, + ctx: FlyteContext, + python_val: xr.Dataset, + python_type: typing.Type[xr.Dataset], + expected: LiteralType, + ) -> Literal: + remote_dir = ctx.file_access.get_random_remote_path("data.zarr") + # this actually works if there is a dask cluster! Is this safe in geneal? + with Client(): + python_val.to_zarr(remote_dir, mode="w") + return Literal(scalar=Scalar(blob=Blob(uri=remote_dir, metadata=BlobMetadata(type=self._TYPE_INFO)))) + + def to_python_value( + self, + ctx: FlyteContext, + lv: Literal, + expected_python_type: typing.Type[xr.Dataset], + ) -> xr.Dataset: + return xr.open_zarr(lv.scalar.blob.uri) + + def to_html( + self, + ctx: FlyteContext, + python_val: xr.Dataset, + expected_python_type: LiteralType, + ) -> str: + assert isinstance(python_val, (xr.Dataset, xr.DataArray)) + return python_val._repr_html_() + + +class XarrayDaZarrTypeTransformer(TypeTransformer[xr.DataArray]): + _TYPE_INFO = BlobType(format="binary", dimensionality=BlobType.BlobDimensionality.MULTIPART) + + def __init__(self) -> None: + super().__init__(name="xarray-zarr-da", t=xr.DataArray) + + def get_literal_type(self, t: typing.Type[xr.DataArray]) -> LiteralType: + return LiteralType(blob=self._TYPE_INFO) + + def to_literal( + self, + ctx: FlyteContext, + python_val: xr.DataArray, + python_type: typing.Type[xr.DataArray], + expected: LiteralType, + ) -> Literal: + remote_dir = ctx.file_access.get_random_remote_path("data.zarr") + # this actually works if there is a dask cluster! Is this safe in geneal? + with Client(timeout=120): + python_val.to_zarr(remote_dir, mode="w") + return Literal(scalar=Scalar(blob=Blob(uri=remote_dir, metadata=BlobMetadata(type=self._TYPE_INFO)))) + + def to_python_value( + self, + ctx: FlyteContext, + lv: Literal, + expected_python_type: typing.Type[xr.DataArray], + ) -> xr.DataArray: + # xr.open_zarr always opens a dataset, so we take the first variable + return list(xr.open_zarr(lv.scalar.blob.uri).data_vars.values())[0] + + def to_html( + self, + ctx: FlyteContext, + python_val: xr.DataArray, + expected_python_type: LiteralType, + ) -> str: + assert isinstance(python_val, (xr.DataArray, xr.DataArray)) + return python_val._repr_html_() + + +TypeEngine.register(XarrayZarrTypeTransformer()) +TypeEngine.register(XarrayDaZarrTypeTransformer()) diff --git a/plugins/flytekit-xarray-zarr/setup.py b/plugins/flytekit-xarray-zarr/setup.py new file mode 100644 index 0000000000..b8acd91d3b --- /dev/null +++ b/plugins/flytekit-xarray-zarr/setup.py @@ -0,0 +1,35 @@ +from setuptools import setup + +PLUGIN_NAME = "xarray_zarr" + +microlib_name = f"flytekitplugins-{PLUGIN_NAME}" + +plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "xarray", "zarr", "distributed"] + +__version__ = "0.0.0+develop" + +setup( + name=microlib_name, + version=__version__, + author="flyteorg", + author_email="admin@flyte.org", + description="Xarray Zarr plugin for flytekit", + namespace_packages=["flytekitplugins"], + packages=[f"flytekitplugins.{PLUGIN_NAME}"], + install_requires=plugin_requires, + license="apache2", + python_requires=">=3.9", + classifiers=[ + "Intended Audience :: Science/Research", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, +) diff --git a/plugins/flytekit-xarray-zarr/tests/test_xarray_zarr_plugin.py b/plugins/flytekit-xarray-zarr/tests/test_xarray_zarr_plugin.py new file mode 100644 index 0000000000..fcf7d06065 --- /dev/null +++ b/plugins/flytekit-xarray-zarr/tests/test_xarray_zarr_plugin.py @@ -0,0 +1,48 @@ +from flytekit import task, workflow +import numpy as np +import dask.array as da +import xarray as xr + + +def _sample_dataset() -> xr.Dataset: + return xr.Dataset( + {"test": (("x", "y"), da.random.uniform(size=(32, 32)))}, + ) + + +def test_xarray_zarr_dataarray_plugin(): + + @task + def _generate_xarray() -> xr.DataArray: + return _sample_dataset()["test"] + + @task + def _consume_xarray(ds: xr.DataArray) -> xr.DataArray: + return ds + + @workflow + def _xarray_wf() -> xr.DataArray: + ds = _generate_xarray() + return _consume_xarray(ds=ds) + + array = _xarray_wf() + assert isinstance(array, xr.DataArray) + + +def test_xarray_zarr_dataset_plugin(): + + @task + def _generate_xarray() -> xr.Dataset: + return _sample_dataset() + + @task + def _consume_xarray(ds: xr.Dataset) -> xr.Dataset: + return ds + + @workflow + def _xarray_wf() -> xr.Dataset: + ds = _generate_xarray() + return _consume_xarray(ds=ds) + + array = _xarray_wf() + assert isinstance(array, xr.Dataset) From 2e02f3f3e52de8757ac6cd3e5624d4ec4344e09e Mon Sep 17 00:00:00 2001 From: Len Strnad Date: Thu, 20 Mar 2025 17:51:29 -0600 Subject: [PATCH 02/12] set timeout for 120s Signed-off-by: Len Strnad --- .../flytekitplugins/xarray_zarr/xarray_transformers.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/xarray_transformers.py b/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/xarray_transformers.py index ad31062741..b6619ccef0 100644 --- a/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/xarray_transformers.py +++ b/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/xarray_transformers.py @@ -32,8 +32,10 @@ def to_literal( expected: LiteralType, ) -> Literal: remote_dir = ctx.file_access.get_random_remote_path("data.zarr") - # this actually works if there is a dask cluster! Is this safe in geneal? - with Client(): + # Opening with the dask client will attach the client eliminating the + # need for users to connect to the client if a task tasks a xr.Dataset + # type. + with Client(timeout=120): python_val.to_zarr(remote_dir, mode="w") return Literal(scalar=Scalar(blob=Blob(uri=remote_dir, metadata=BlobMetadata(type=self._TYPE_INFO)))) @@ -72,7 +74,9 @@ def to_literal( expected: LiteralType, ) -> Literal: remote_dir = ctx.file_access.get_random_remote_path("data.zarr") - # this actually works if there is a dask cluster! Is this safe in geneal? + # Opening with the dask client will attach the client eliminating the + # need for users to connect to the client if a task tasks a xr.Dataset + # type. with Client(timeout=120): python_val.to_zarr(remote_dir, mode="w") return Literal(scalar=Scalar(blob=Blob(uri=remote_dir, metadata=BlobMetadata(type=self._TYPE_INFO)))) From 770a00f194e1754c474f572a8c50a5c6384716f0 Mon Sep 17 00:00:00 2001 From: Len Strnad Date: Thu, 20 Mar 2025 19:17:08 -0600 Subject: [PATCH 03/12] add example in docs Signed-off-by: Len Strnad --- plugins/flytekit-xarray-zarr/README.md | 39 +++++++++++++++++++++++++- 1 file changed, 38 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-xarray-zarr/README.md b/plugins/flytekit-xarray-zarr/README.md index b09ccf7b24..26b1d547bb 100644 --- a/plugins/flytekit-xarray-zarr/README.md +++ b/plugins/flytekit-xarray-zarr/README.md @@ -1,5 +1,5 @@ # Flytekit Xarray Zarr Plugin -The Xarray Zarr plugin adds support to persist xarray datasets and dataarrays to zarr between tasks. If a dask cluster is present (see flytekitplugins-dask), it will attempt to connect to the distributed client. This prevents the need to explicitly connect to a distributed client within the task. +The Xarray Zarr plugin adds support to persist xarray datasets and dataarrays to zarr between tasks. If a dask cluster is present (see flytekitplugins-dask), it will attempt to connect to the distributed client before we call `.to_zarr(url)` call. This prevents the need to explicitly connect to a distributed client within the task. If deck is enabled, we also render the datasets/dataarrays to html. @@ -8,3 +8,40 @@ To install the plugin, run the following command: ```bash pip install flytekitplugins-xarray-zarr ``` + +## Example + +```python +import dask.array as da +import xarray as xr +from flytekit import task, workflow +from flytekitplugins.dask import Dask, WorkerGroup + +dask_task = task( + task_config=Dask(workers=WorkerGroup(number_of_workers=6)), + enable_deck=True, # enables input/output html views of xarray objects +) + + +@dask_task +def generate_xarray_task() -> xr.Dataset: + return xr.Dataset( + { + "variable": ( + ("time", "x", "y"), + da.random.uniform(size=(1024, 1024, 1024)), + ) + }, + ) + + +@dask_task +def preprocess_xarray_task(ds: xr.Dataset) -> xr.Dataset: + return ds * 2 + + +@workflow +def test_xarray_workflow() -> xr.DataArray: + ds = generate_xarray_task() + return preprocess_xarray_task(ds=ds) +``` From 20e84735c1211782ffd1fc3b3532e3f5db96f025 Mon Sep 17 00:00:00 2001 From: Len Strnad Date: Thu, 20 Mar 2025 19:20:30 -0600 Subject: [PATCH 04/12] add plugin to plugin names Signed-off-by: Len Strnad --- .github/workflows/pythonbuild.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 439c94a29c..f7c62fd689 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -344,6 +344,7 @@ jobs: - flytekit-sqlalchemy - flytekit-vaex - flytekit-whylogs + - flytekit-xarray-zarr exclude: - python-version: 3.9 plugin-names: "flytekit-aws-sagemaker" From bc1dc21bbbeedbc7620032905c608d9bee8a7a20 Mon Sep 17 00:00:00 2001 From: Len Strnad Date: Thu, 20 Mar 2025 19:31:09 -0600 Subject: [PATCH 05/12] update `__init__.py` Signed-off-by: Len Strnad --- .../flytekitplugins/xarray_zarr/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/__init__.py b/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/__init__.py index b1bab7445f..9f8f8f9fc1 100644 --- a/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/__init__.py +++ b/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/__init__.py @@ -1,5 +1,5 @@ """ -.. currentmodule:: flytekitplugins.geopandas +.. currentmodule:: flytekitplugins.xarray_zarr This package contains things that are useful when extending Flytekit. @@ -7,8 +7,8 @@ :template: custom.rst :toctree: generated/ - GeoPandasDecodingHandler - GeoPandasEncodingHandler + XarrayDaZarrTypeTransformer + XarrayZarrTypeTransformer """ from .xarray_transformers import XarrayDaZarrTypeTransformer, XarrayZarrTypeTransformer From 97554e7fb7670eb74c79f33897dca212f256a087 Mon Sep 17 00:00:00 2001 From: Len Strnad Date: Thu, 20 Mar 2025 19:34:16 -0600 Subject: [PATCH 06/12] add lazy module type checking Signed-off-by: Len Strnad --- .../xarray_zarr/xarray_transformers.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/xarray_transformers.py b/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/xarray_transformers.py index b6619ccef0..e21fc4e407 100644 --- a/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/xarray_transformers.py +++ b/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/xarray_transformers.py @@ -1,8 +1,5 @@ import typing -import xarray as xr -from dask.distributed import Client - from flytekit import ( Blob, BlobMetadata, @@ -11,9 +8,17 @@ Literal, LiteralType, Scalar, + lazy_module, ) from flytekit.extend import TypeEngine, TypeTransformer +if typing.TYPE_CHECKING: + import xarray as xr + from dask.distributed import Client +else: + pandas = lazy_module("xarray") + pyarrow = lazy_module("dask.distributed") + class XarrayZarrTypeTransformer(TypeTransformer[xr.Dataset]): _TYPE_INFO = BlobType(format="binary", dimensionality=BlobType.BlobDimensionality.MULTIPART) From a7ae40c12c8e160dcb0d0b275a69dd3bbe6b96dc Mon Sep 17 00:00:00 2001 From: Len Strnad Date: Thu, 20 Mar 2025 19:36:45 -0600 Subject: [PATCH 07/12] add xarray zarr to sources in `setup.py` Signed-off-by: Len Strnad --- plugins/setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/setup.py b/plugins/setup.py index ff1d5c6c88..2899842bf4 100644 --- a/plugins/setup.py +++ b/plugins/setup.py @@ -49,6 +49,7 @@ "flytekitplugins-sqlalchemy": "flytekit-sqlalchemy", "flytekitplugins-vaex": "flytekit-vaex", "flytekitplugins-whylogs": "flytekit-whylogs", + "flytekitplugins-xarray-zarr": "flytekit-xarray-zarr", } From 47f768ff9b7ed4bc61b6685c71410dccd5b3c61b Mon Sep 17 00:00:00 2001 From: Len Strnad Date: Thu, 20 Mar 2025 19:41:42 -0600 Subject: [PATCH 08/12] match microlib_name 'ing convention to deck-standard Signed-off-by: Len Strnad --- .../flytekitplugins/xarray_zarr/__init__.py | 2 +- plugins/flytekit-xarray-zarr/setup.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/__init__.py b/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/__init__.py index 9f8f8f9fc1..1abdb0ee7d 100644 --- a/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/__init__.py +++ b/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/__init__.py @@ -1,5 +1,5 @@ """ -.. currentmodule:: flytekitplugins.xarray_zarr +.. currentmodule:: flytekitplugins.xarray This package contains things that are useful when extending Flytekit. diff --git a/plugins/flytekit-xarray-zarr/setup.py b/plugins/flytekit-xarray-zarr/setup.py index b8acd91d3b..611907d565 100644 --- a/plugins/flytekit-xarray-zarr/setup.py +++ b/plugins/flytekit-xarray-zarr/setup.py @@ -1,8 +1,8 @@ from setuptools import setup -PLUGIN_NAME = "xarray_zarr" +PLUGIN_NAME = "xarray" -microlib_name = f"flytekitplugins-{PLUGIN_NAME}" +microlib_name = f"flytekitplugins-{PLUGIN_NAME}-zarr" plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "xarray", "zarr", "distributed"] From 3ea448337f4e73455650b827897831f621c161c3 Mon Sep 17 00:00:00 2001 From: Len Strnad Date: Mon, 24 Mar 2025 09:33:09 -0600 Subject: [PATCH 09/12] alphabetical order plugin_requires Signed-off-by: Len Strnad --- plugins/flytekit-xarray-zarr/setup.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/plugins/flytekit-xarray-zarr/setup.py b/plugins/flytekit-xarray-zarr/setup.py index 611907d565..f36b2d14c6 100644 --- a/plugins/flytekit-xarray-zarr/setup.py +++ b/plugins/flytekit-xarray-zarr/setup.py @@ -4,7 +4,12 @@ microlib_name = f"flytekitplugins-{PLUGIN_NAME}-zarr" -plugin_requires = ["flytekit>=1.3.0b2,<2.0.0", "xarray", "zarr", "distributed"] +plugin_requires = [ + "dask[distributed]>=2022.10.2", + "flytekit>=1.3.0b2,<2.0.0", + "xarray", + "zarr", +] __version__ = "0.0.0+develop" From 2fcf85ce10de92c47d86fa9b82a84a90198dd149 Mon Sep 17 00:00:00 2001 From: Len Strnad Date: Mon, 24 Mar 2025 10:35:02 -0600 Subject: [PATCH 10/12] move runtime types outside type checking + rename Signed-off-by: Len Strnad --- .../{xarray_zarr => xarray}/__init__.py | 0 .../xarray_transformers.py | 21 +++++++------------ 2 files changed, 7 insertions(+), 14 deletions(-) rename plugins/flytekit-xarray-zarr/flytekitplugins/{xarray_zarr => xarray}/__init__.py (100%) rename plugins/flytekit-xarray-zarr/flytekitplugins/{xarray_zarr => xarray}/xarray_transformers.py (85%) diff --git a/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/__init__.py b/plugins/flytekit-xarray-zarr/flytekitplugins/xarray/__init__.py similarity index 100% rename from plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/__init__.py rename to plugins/flytekit-xarray-zarr/flytekitplugins/xarray/__init__.py diff --git a/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/xarray_transformers.py b/plugins/flytekit-xarray-zarr/flytekitplugins/xarray/xarray_transformers.py similarity index 85% rename from plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/xarray_transformers.py rename to plugins/flytekit-xarray-zarr/flytekitplugins/xarray/xarray_transformers.py index e21fc4e407..2304576f1c 100644 --- a/plugins/flytekit-xarray-zarr/flytekitplugins/xarray_zarr/xarray_transformers.py +++ b/plugins/flytekit-xarray-zarr/flytekitplugins/xarray/xarray_transformers.py @@ -1,5 +1,8 @@ import typing +import dask.distributed as dd + +import xarray as xr from flytekit import ( Blob, BlobMetadata, @@ -8,23 +11,15 @@ Literal, LiteralType, Scalar, - lazy_module, ) from flytekit.extend import TypeEngine, TypeTransformer -if typing.TYPE_CHECKING: - import xarray as xr - from dask.distributed import Client -else: - pandas = lazy_module("xarray") - pyarrow = lazy_module("dask.distributed") - class XarrayZarrTypeTransformer(TypeTransformer[xr.Dataset]): _TYPE_INFO = BlobType(format="binary", dimensionality=BlobType.BlobDimensionality.MULTIPART) def __init__(self) -> None: - super().__init__(name="xarray-zarr", t=xr.Dataset) + super().__init__(name="Xarray Dataset", t=xr.Dataset) def get_literal_type(self, t: typing.Type[xr.Dataset]) -> LiteralType: return LiteralType(blob=self._TYPE_INFO) @@ -40,7 +35,7 @@ def to_literal( # Opening with the dask client will attach the client eliminating the # need for users to connect to the client if a task tasks a xr.Dataset # type. - with Client(timeout=120): + with dd.Client(timeout=120): python_val.to_zarr(remote_dir, mode="w") return Literal(scalar=Scalar(blob=Blob(uri=remote_dir, metadata=BlobMetadata(type=self._TYPE_INFO)))) @@ -58,7 +53,6 @@ def to_html( python_val: xr.Dataset, expected_python_type: LiteralType, ) -> str: - assert isinstance(python_val, (xr.Dataset, xr.DataArray)) return python_val._repr_html_() @@ -66,7 +60,7 @@ class XarrayDaZarrTypeTransformer(TypeTransformer[xr.DataArray]): _TYPE_INFO = BlobType(format="binary", dimensionality=BlobType.BlobDimensionality.MULTIPART) def __init__(self) -> None: - super().__init__(name="xarray-zarr-da", t=xr.DataArray) + super().__init__(name="Xarray DataArray", t=xr.DataArray) def get_literal_type(self, t: typing.Type[xr.DataArray]) -> LiteralType: return LiteralType(blob=self._TYPE_INFO) @@ -82,7 +76,7 @@ def to_literal( # Opening with the dask client will attach the client eliminating the # need for users to connect to the client if a task tasks a xr.Dataset # type. - with Client(timeout=120): + with dd.Client(timeout=120): python_val.to_zarr(remote_dir, mode="w") return Literal(scalar=Scalar(blob=Blob(uri=remote_dir, metadata=BlobMetadata(type=self._TYPE_INFO)))) @@ -101,7 +95,6 @@ def to_html( python_val: xr.DataArray, expected_python_type: LiteralType, ) -> str: - assert isinstance(python_val, (xr.DataArray, xr.DataArray)) return python_val._repr_html_() From 492e31c6213a4f1a0047aacbf1355892733e10e9 Mon Sep 17 00:00:00 2001 From: Len Strnad Date: Mon, 24 Mar 2025 10:38:08 -0600 Subject: [PATCH 11/12] update readme Signed-off-by: Len Strnad --- plugins/flytekit-xarray-zarr/README.md | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/plugins/flytekit-xarray-zarr/README.md b/plugins/flytekit-xarray-zarr/README.md index 26b1d547bb..82e423982e 100644 --- a/plugins/flytekit-xarray-zarr/README.md +++ b/plugins/flytekit-xarray-zarr/README.md @@ -17,13 +17,11 @@ import xarray as xr from flytekit import task, workflow from flytekitplugins.dask import Dask, WorkerGroup -dask_task = task( + +@task( task_config=Dask(workers=WorkerGroup(number_of_workers=6)), - enable_deck=True, # enables input/output html views of xarray objects + enable_deck=True, ) - - -@dask_task def generate_xarray_task() -> xr.Dataset: return xr.Dataset( { @@ -35,13 +33,16 @@ def generate_xarray_task() -> xr.Dataset: ) -@dask_task +@task( + task_config=Dask(workers=WorkerGroup(number_of_workers=6)), + enable_deck=True, +) def preprocess_xarray_task(ds: xr.Dataset) -> xr.Dataset: return ds * 2 @workflow -def test_xarray_workflow() -> xr.DataArray: +def xarray_workflow() -> xr.Dataset: ds = generate_xarray_task() return preprocess_xarray_task(ds=ds) ``` From 53de8265ca6dbf1e1a9dee9640046e50f2231b5a Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Fri, 21 Mar 2025 08:32:47 -0700 Subject: [PATCH 12/12] Add LOCAL_DYNAMIC_TASK_EXECUTION mode (#3202) Signed-off-by: Kevin Su Signed-off-by: Len Strnad --- flytekit/core/context_manager.py | 3 + flytekit/core/node_creation.py | 2 +- flytekit/core/promise.py | 2 +- flytekit/core/python_function_task.py | 2 +- .../unit/core/test_dataclass_dynamic.py | 83 +++++++++++++++++++ 5 files changed, 89 insertions(+), 3 deletions(-) create mode 100644 tests/flytekit/unit/core/test_dataclass_dynamic.py diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index c8d4d92b40..6378f42706 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -561,6 +561,8 @@ class Mode(Enum): EAGER_LOCAL_EXECUTION = 6 + LOCAL_DYNAMIC_TASK_EXECUTION = 7 + mode: Optional[ExecutionState.Mode] working_dir: Union[os.PathLike, str] engine_dir: Optional[Union[os.PathLike, str]] @@ -622,6 +624,7 @@ def is_local_execution(self) -> bool: self.mode == ExecutionState.Mode.LOCAL_TASK_EXECUTION or self.mode == ExecutionState.Mode.LOCAL_WORKFLOW_EXECUTION or self.mode == ExecutionState.Mode.EAGER_LOCAL_EXECUTION + or self.mode == ExecutionState.Mode.LOCAL_DYNAMIC_TASK_EXECUTION ) diff --git a/flytekit/core/node_creation.py b/flytekit/core/node_creation.py index 3699078bf4..b9c74d1cb3 100644 --- a/flytekit/core/node_creation.py +++ b/flytekit/core/node_creation.py @@ -138,7 +138,7 @@ def create_node( # https://github.com/flyteorg/flytekit/blob/0815345faf0fae5dc26746a43d4bda4cc2cdf830/flytekit/core/python_function_task.py#L262 elif ctx.execution_state and ( ctx.execution_state.is_local_execution() - or ctx.execution_state.mode == ExecutionState.Mode.DYNAMIC_TASK_EXECUTION + or ctx.execution_state.mode == ExecutionState.Mode.LOCAL_DYNAMIC_TASK_EXECUTION ): if isinstance(entity, RemoteEntity): raise AssertionError(f"Remote entities are not yet runnable locally {entity.name}") diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index a9b8dd284b..4d96df6602 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -1539,7 +1539,7 @@ def flyte_entity_call_handler( else: raise ValueError(f"Received an output when workflow local execution expected None. Received: {result}") - if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.DYNAMIC_TASK_EXECUTION: + if ctx.execution_state and ctx.execution_state.mode == ExecutionState.Mode.LOCAL_DYNAMIC_TASK_EXECUTION: return result if (1 < expected_outputs == len(cast(Tuple[Promise], result))) or ( diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 206624fc19..48d29e6625 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -346,7 +346,7 @@ def dynamic_execute(self, task_function: Callable, **kwargs) -> Any: logger.debug(f"Executing Dynamic workflow, using raw inputs {kwargs}") self._create_and_cache_dynamic_workflow() if self.execution_mode == self.ExecutionBehavior.DYNAMIC: - es = ctx.new_execution_state().with_params(mode=ExecutionState.Mode.DYNAMIC_TASK_EXECUTION) + es = ctx.new_execution_state().with_params(mode=ExecutionState.Mode.LOCAL_DYNAMIC_TASK_EXECUTION) else: es = cast(ExecutionState, ctx.execution_state) with FlyteContextManager.with_context(ctx.with_execution_state(es)): diff --git a/tests/flytekit/unit/core/test_dataclass_dynamic.py b/tests/flytekit/unit/core/test_dataclass_dynamic.py new file mode 100644 index 0000000000..a121b530fe --- /dev/null +++ b/tests/flytekit/unit/core/test_dataclass_dynamic.py @@ -0,0 +1,83 @@ +import sys +from dataclasses import dataclass +from typing import List, Optional + +import pytest +from dataclasses_json import DataClassJsonMixin +from mashumaro.mixins.json import DataClassJSONMixin +from typing_extensions import Annotated + +from flytekit.core.task import task +from flytekit.core.type_engine import DataclassTransformer +from flytekit.core.workflow import workflow + +from dataclasses import dataclass, field +import base64 +import pickle +from flytekit.core.resources import Resources +from flytekit.core.dynamic_workflow_task import dynamic +from mashumaro.config import BaseConfig + + +class SimpleObjectOriginal: + def __init__(self, a: Optional[str] = None, b: Optional[int] = None): + self.a = a + self.b = b + + +def encode_object(obj: SimpleObjectOriginal) -> str: + s = base64.b64encode(pickle.dumps(obj)).decode("utf-8") + print(f"Encode object to {s}") + return s + + +def decode_object(object_value: str) -> SimpleObjectOriginal: + print(f"Decoding from string {object_value}") + return pickle.loads(base64.b64decode(object_value.encode("utf-8"))) + + +@dataclass +class SimpleObjectOriginalMixin(DataClassJsonMixin): + # This is a mixin that adds a SimpleObjectOriginal field to any dataclass. + + simple_object: SimpleObjectOriginal = field( + default_factory=SimpleObjectOriginal, + ) + + class Config(BaseConfig): + serialization_strategy = { + SimpleObjectOriginal: { + # you can use specific str values for datetime here as well + "deserialize": decode_object, + "serialize": encode_object, + }, + } + + +@dataclass +class ParentDC(SimpleObjectOriginalMixin): + parent_val: str = "" + + +@task +def generate_result() -> SimpleObjectOriginalMixin: + return SimpleObjectOriginalMixin(simple_object=SimpleObjectOriginal(a="a", b=1)) + + +@task +def check_result(obj: SimpleObjectOriginalMixin): + assert obj.simple_object is not None + +@task +def generate_int() -> int: + return 42 + + +def test_simple_object_yee(): + @dynamic + def my_dynamic(): + result = generate_result() + n = check_result(obj=result) + n.with_overrides(limits=Resources(cpu="3", mem="500Mi")) + + my_dynamic()