Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Xarray zarr persistence support #3205

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ jobs:
- flytekit-sqlalchemy
- flytekit-vaex
- flytekit-whylogs
- flytekit-xarray-zarr
exclude:
- python-version: 3.9
plugin-names: "flytekit-aws-sagemaker"
Expand Down
48 changes: 48 additions & 0 deletions plugins/flytekit-xarray-zarr/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Flytekit Xarray Zarr Plugin
The Xarray Zarr plugin adds support for persisting xarray Datasets and DataArrays to zarr between tasks. If a dask cluster is present (see flytekitplugins-dask), the plugin attempts to connect to the distributed client before calling `.to_zarr(url)`. This removes the need to explicitly connect to a distributed client within the task.

If Flyte Deck is enabled, the Datasets/DataArrays are also rendered to HTML.

To install the plugin, run the following command:

```bash
pip install flytekitplugins-xarray-zarr
```

## Example

```python
import dask.array as da
import xarray as xr
from flytekit import task, workflow
from flytekitplugins.dask import Dask, WorkerGroup


@task(
    enable_deck=True,
    task_config=Dask(workers=WorkerGroup(number_of_workers=6)),
)
def generate_xarray_task() -> xr.Dataset:
    # One random dask-backed variable with dimensions (time, x, y).
    dims = ("time", "x", "y")
    values = da.random.uniform(size=(1024, 1024, 1024))
    return xr.Dataset({"variable": (dims, values)})


@task(
    enable_deck=True,
    task_config=Dask(workers=WorkerGroup(number_of_workers=6)),
)
def preprocess_xarray_task(ds: xr.Dataset) -> xr.Dataset:
    # Scale the incoming dataset by two.
    doubled = ds * 2
    return doubled


@workflow
def xarray_workflow() -> xr.Dataset:
    # Chain the two tasks: generate the dataset, then preprocess it.
    generated = generate_xarray_task()
    processed = preprocess_xarray_task(ds=generated)
    return processed
```
14 changes: 14 additions & 0 deletions plugins/flytekit-xarray-zarr/flytekitplugins/xarray/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""
.. currentmodule:: flytekitplugins.xarray

This package contains type transformers that persist xarray Datasets and DataArrays as zarr stores between Flytekit tasks.

.. autosummary::
:template: custom.rst
:toctree: generated/

XarrayDaZarrTypeTransformer
XarrayZarrTypeTransformer
"""

from .xarray_transformers import XarrayDaZarrTypeTransformer, XarrayZarrTypeTransformer
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import typing

import dask.distributed as dd

import xarray as xr
from flytekit import (
Blob,
BlobMetadata,
BlobType,
FlyteContext,
Literal,
LiteralType,
Scalar,
)
from flytekit.extend import TypeEngine, TypeTransformer


class XarrayZarrTypeTransformer(TypeTransformer[xr.Dataset]):
    """Type transformer that passes ``xr.Dataset`` values between tasks as zarr stores.

    A dataset is written with ``Dataset.to_zarr`` to a random remote path and
    represented as a multipart binary blob literal whose URI points at that
    store. It is read back with ``xr.open_zarr``.
    """

    # Blob type used both for the literal type and for every emitted literal.
    _TYPE_INFO = BlobType(format="binary", dimensionality=BlobType.BlobDimensionality.MULTIPART)

    def __init__(self) -> None:
        super().__init__(name="Xarray Dataset", t=xr.Dataset)

    def get_literal_type(self, t: typing.Type[xr.Dataset]) -> LiteralType:
        """Return the multipart binary blob literal type used for datasets."""
        return LiteralType(blob=self._TYPE_INFO)

    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: xr.Dataset,
        python_type: typing.Type[xr.Dataset],
        expected: LiteralType,
    ) -> Literal:
        """Write ``python_val`` to a zarr store and return a blob literal pointing at it.

        :param ctx: Flyte context; its ``file_access`` provides the remote path.
        :param python_val: Dataset to persist.
        :param python_type: Declared Python type (unused).
        :param expected: Expected literal type (unused).
        :return: Blob literal whose URI is the zarr store location.
        """
        remote_dir = ctx.file_access.get_random_remote_path("data.zarr")
        # Opening the dask client here attaches it for the duration of the
        # write, eliminating the need for users to connect to the client
        # themselves if a task takes an xr.Dataset type.
        with dd.Client(timeout=120):
            python_val.to_zarr(remote_dir, mode="w")
        blob = Blob(uri=remote_dir, metadata=BlobMetadata(type=self._TYPE_INFO))
        return Literal(scalar=Scalar(blob=blob))

    def to_python_value(
        self,
        ctx: FlyteContext,
        lv: Literal,
        expected_python_type: typing.Type[xr.Dataset],
    ) -> xr.Dataset:
        """Open the zarr store referenced by the blob literal as a dataset."""
        return xr.open_zarr(lv.scalar.blob.uri)

    def to_html(
        self,
        ctx: FlyteContext,
        python_val: xr.Dataset,
        expected_python_type: typing.Type[xr.Dataset],
    ) -> str:
        """Return xarray's HTML representation of the dataset."""
        return python_val._repr_html_()


class XarrayDaZarrTypeTransformer(TypeTransformer[xr.DataArray]):
    """Type transformer that passes ``xr.DataArray`` values between tasks as zarr stores.

    A data array is written with ``DataArray.to_zarr`` to a random remote path
    and represented as a multipart binary blob literal whose URI points at
    that store. On the way back the store is opened as a dataset and its
    first data variable is returned.
    """

    # Blob type used both for the literal type and for every emitted literal.
    _TYPE_INFO = BlobType(format="binary", dimensionality=BlobType.BlobDimensionality.MULTIPART)

    def __init__(self) -> None:
        super().__init__(name="Xarray DataArray", t=xr.DataArray)

    def get_literal_type(self, t: typing.Type[xr.DataArray]) -> LiteralType:
        """Return the multipart binary blob literal type used for data arrays."""
        return LiteralType(blob=self._TYPE_INFO)

    def to_literal(
        self,
        ctx: FlyteContext,
        python_val: xr.DataArray,
        python_type: typing.Type[xr.DataArray],
        expected: LiteralType,
    ) -> Literal:
        """Write ``python_val`` to a zarr store and return a blob literal pointing at it.

        :param ctx: Flyte context; its ``file_access`` provides the remote path.
        :param python_val: Data array to persist.
        :param python_type: Declared Python type (unused).
        :param expected: Expected literal type (unused).
        :return: Blob literal whose URI is the zarr store location.
        """
        remote_dir = ctx.file_access.get_random_remote_path("data.zarr")
        # Opening the dask client here attaches it for the duration of the
        # write, eliminating the need for users to connect to the client
        # themselves if a task takes an xr.DataArray type.
        with dd.Client(timeout=120):
            python_val.to_zarr(remote_dir, mode="w")
        blob = Blob(uri=remote_dir, metadata=BlobMetadata(type=self._TYPE_INFO))
        return Literal(scalar=Scalar(blob=blob))

    def to_python_value(
        self,
        ctx: FlyteContext,
        lv: Literal,
        expected_python_type: typing.Type[xr.DataArray],
    ) -> xr.DataArray:
        """Open the zarr store referenced by the blob literal and return its first variable."""
        # xr.open_zarr always opens a dataset, so we take the first variable.
        # NOTE(review): an unnamed DataArray presumably comes back carrying the
        # variable name xarray chose when writing it - confirm.
        dataset = xr.open_zarr(lv.scalar.blob.uri)
        return list(dataset.data_vars.values())[0]

    def to_html(
        self,
        ctx: FlyteContext,
        python_val: xr.DataArray,
        expected_python_type: typing.Type[xr.DataArray],
    ) -> str:
        """Return xarray's HTML representation of the data array."""
        return python_val._repr_html_()


# Register both transformers with Flytekit's TypeEngine when this module is imported.
TypeEngine.register(XarrayZarrTypeTransformer())
TypeEngine.register(XarrayDaZarrTypeTransformer())
40 changes: 40 additions & 0 deletions plugins/flytekit-xarray-zarr/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from setuptools import setup

# Packaging metadata for the xarray/zarr Flytekit plugin.
PLUGIN_NAME = "xarray"

# Distribution name: "flytekitplugins-xarray-zarr".
microlib_name = f"flytekitplugins-{PLUGIN_NAME}-zarr"

# Runtime dependencies installed together with the plugin.
plugin_requires = [
    "dask[distributed]>=2022.10.2",
    "flytekit>=1.3.0b2,<2.0.0",
    "xarray",
    "zarr",
]

__version__ = "0.0.0+develop"

setup(
    name=microlib_name,
    version=__version__,
    author="flyteorg",
    author_email="admin@flyte.org",
    description="Xarray Zarr plugin for flytekit",
    namespace_packages=["flytekitplugins"],
    packages=[f"flytekitplugins.{PLUGIN_NAME}"],
    install_requires=plugin_requires,
    license="apache2",
    python_requires=">=3.9",
    classifiers=[
        "Intended Audience :: Science/Research",
        "Intended Audience :: Developers",
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Topic :: Scientific/Engineering",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Software Development",
        "Topic :: Software Development :: Libraries",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
    # Entry point under the "flytekit.plugins" group, pointing at the plugin package.
    entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]},
)
48 changes: 48 additions & 0 deletions plugins/flytekit-xarray-zarr/tests/test_xarray_zarr_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from flytekit import task, workflow
import numpy as np
import dask.array as da
import xarray as xr


def _sample_dataset() -> xr.Dataset:
    """Return a dataset with one random 32x32 dask-backed variable named ``test``."""
    dims = ("x", "y")
    values = da.random.uniform(size=(32, 32))
    return xr.Dataset({"test": (dims, values)})


def test_xarray_zarr_dataarray_plugin():
    """Pass an ``xr.DataArray`` through two tasks and check the type that comes back."""

    @task
    def _generate_xarray() -> xr.DataArray:
        # Use the single variable of the sample dataset as the data array.
        dataset = _sample_dataset()
        return dataset["test"]

    @task
    def _consume_xarray(ds: xr.DataArray) -> xr.DataArray:
        # Pass-through task: returns its input unchanged.
        return ds

    @workflow
    def _xarray_wf() -> xr.DataArray:
        generated = _generate_xarray()
        return _consume_xarray(ds=generated)

    result = _xarray_wf()
    assert isinstance(result, xr.DataArray)


def test_xarray_zarr_dataset_plugin():
    """Pass an ``xr.Dataset`` through two tasks and check the type that comes back."""

    @task
    def _generate_xarray() -> xr.Dataset:
        dataset = _sample_dataset()
        return dataset

    @task
    def _consume_xarray(ds: xr.Dataset) -> xr.Dataset:
        # Pass-through task: returns its input unchanged.
        return ds

    @workflow
    def _xarray_wf() -> xr.Dataset:
        generated = _generate_xarray()
        return _consume_xarray(ds=generated)

    result = _xarray_wf()
    assert isinstance(result, xr.Dataset)
1 change: 1 addition & 0 deletions plugins/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"flytekitplugins-sqlalchemy": "flytekit-sqlalchemy",
"flytekitplugins-vaex": "flytekit-vaex",
"flytekitplugins-whylogs": "flytekit-whylogs",
"flytekitplugins-xarray-zarr": "flytekit-xarray-zarr",
}


Expand Down
Loading