Skip to content

Commit

Permalink
Allow a hash method to be present for numpy arrays (flyteorg#2649)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <[email protected]>
Signed-off-by: mao3267 <[email protected]>
  • Loading branch information
demmerichs authored and mao3267 committed Aug 9, 2024
1 parent 901fe4f commit 9e0dbf5
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 16 deletions.
49 changes: 38 additions & 11 deletions flytekit/types/numpy/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,37 @@
from typing_extensions import Annotated, get_args, get_origin

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.core.hash import HashMethod
from flytekit.core.type_engine import (
TypeEngine,
TypeTransformer,
TypeTransformerFailedError,
)
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType


def extract_metadata(t: Type[np.ndarray]) -> Tuple[Type[np.ndarray], Dict[str, bool]]:
    """Split an optionally ``Annotated`` ndarray type into base type and metadata.

    For ``Annotated[base, *extras]`` every extra argument is inspected:

    * an ``OrderedDict`` (what the error messages call "kwtypes") is taken as
      the metadata mapping; at most one such mapping is allowed,
    * a ``HashMethod`` instance is accepted and skipped, it is not consumed
      by this function,
    * anything else is rejected.

    Args:
        t: The python type to inspect, either a bare type or an
            ``Annotated[...]`` wrapper around it.

    Returns:
        ``(base_type, metadata)``. ``metadata`` is an empty dict when ``t`` is
        not ``Annotated`` or carries no ``OrderedDict`` annotation.

    Raises:
        TypeTransformerFailedError: If more than one ``OrderedDict`` is
            supplied, or an annotation is neither an ``OrderedDict`` nor a
            ``HashMethod``.
    """
    metadata: dict = {}
    metadata_set = False

    if get_origin(t) is Annotated:
        # Annotated may carry any number of extras, so unpack them as a list
        # instead of assuming exactly one.
        base_type, *annotate_args = get_args(t)

        for aa in annotate_args:
            if isinstance(aa, OrderedDict):
                if metadata_set:
                    raise TypeTransformerFailedError(f"Metadata {metadata} is already specified, cannot use {aa}.")
                metadata = aa
                metadata_set = True
            elif isinstance(aa, HashMethod):
                continue
            else:
                raise TypeTransformerFailedError(f"The metadata for {t} must be of type kwtypes or HashMethod.")
        return base_type, metadata

    # Return the type itself if no metadata was found.
    return t, metadata


Expand All @@ -37,26 +54,36 @@ def __init__(self):
def get_literal_type(self, t: Type[np.ndarray]) -> LiteralType:
    """Return the literal type used for numpy arrays.

    The result is a single-part blob type whose format is
    ``self.NUMPY_ARRAY_FORMAT``. The argument ``t`` is not consulted.
    """
    return LiteralType(
        blob=_core_types.BlobType(
            format=self.NUMPY_ARRAY_FORMAT,
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
        )
    )

def to_literal(
    self,
    ctx: FlyteContext,
    python_val: np.ndarray,
    python_type: Type[np.ndarray],
    expected: LiteralType,
) -> Literal:
    """Serialize a numpy array to a ``.npy`` file and wrap it in a blob literal.

    Args:
        ctx: Context whose ``file_access`` supplies a local scratch path and
            uploads the written file.
        python_val: The array to persist.
        python_type: The declared type; any ``Annotated`` metadata on it is
            read through ``extract_metadata``.
        expected: The expected literal type (not used by this method).

    Returns:
        A ``Literal`` holding a blob whose ``uri`` is the path returned by
        ``ctx.file_access.put_raw_data``.
    """
    python_type, metadata = extract_metadata(python_type)

    meta = BlobMetadata(
        type=_core_types.BlobType(
            format=self.NUMPY_ARRAY_FORMAT,
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
        )
    )

    local_path = ctx.file_access.get_random_local_path() + ".npy"
    # Make sure the directory for the scratch file exists before writing.
    pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)

    # save numpy array to file
    # Pickling is opt-in: it is only enabled when the type's metadata asks for it.
    np.save(
        file=local_path,
        arr=python_val,
        allow_pickle=metadata.get("allow_pickle", False),
    )
    remote_path = ctx.file_access.put_raw_data(local_path)
    return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

Expand Down
46 changes: 41 additions & 5 deletions tests/flytekit/unit/types/numpy/test_ndarray.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import pytest
import numpy as np
from typing_extensions import Annotated

from flytekit import kwtypes, task, workflow
from flytekit import HashMethod, kwtypes, task, workflow
from flytekit.core.type_engine import TypeTransformerFailedError


@task
Expand Down Expand Up @@ -63,6 +65,35 @@ def t4(array: Annotated[np.ndarray, kwtypes(allow_pickle=True)]) -> int:
return array.size


def dummy_hash_array(arr: np.ndarray) -> str:
    """Return a constant placeholder hash; the array contents are ignored."""
    return "dummy"


@task
def t5_annotate_kwtypes_and_hash(
    array: Annotated[
        np.ndarray, kwtypes(allow_pickle=True), HashMethod(dummy_hash_array)
    ],
):
    """Accept an array annotated with both kwtypes metadata and a HashMethod."""


@task
def t6_annotate_kwtypes_twice(
    array: Annotated[
        np.ndarray, kwtypes(allow_pickle=True), kwtypes(allow_pickle=False)
    ],
):
    """Accept an array annotated with two kwtypes mappings.

    The workflow in this file expects calling this task to raise
    ``TypeTransformerFailedError`` ("... is already specified ...").
    """


@task
def t7_annotate_with_sth_strange(
    array: Annotated[np.ndarray, (1, 2, 3)],
):
    """Accept an array annotated with a plain tuple instead of kwtypes/HashMethod.

    The workflow in this file expects calling this task to raise
    ``TypeTransformerFailedError`` ("... must be of type kwtypes or HashMethod.").
    """


@workflow
def wf():
array_1d = generate_numpy_1d()
Expand All @@ -72,10 +103,15 @@ def wf():
t2(array=array_2d)
t3(array=array_1d)
t4(array=array_dtype_object)
try:
generate_numpy_fails()
except Exception as e:
assert isinstance(e, TypeError)
t5_annotate_kwtypes_and_hash(array=array_1d)

if array_1d.is_ready:
with pytest.raises(TypeTransformerFailedError, match=r"Metadata OrderedDict.*'allow_pickle'.*True.* is already specified, cannot use OrderedDict.*'allow_pickle'.*False.*\."):
t6_annotate_kwtypes_twice(array=array_1d)
with pytest.raises(TypeTransformerFailedError, match=r"The metadata for typing.Annotated.*numpy\.ndarray.*1, 2, 3.* must be of type kwtypes or HashMethod\."):
t7_annotate_with_sth_strange(array=array_1d)
with pytest.raises(TypeError, match=r"The metadata for typing.Annotated.*numpy\.ndarray.*'allow_pickle'.*True.* must be of type kwtypes or HashMethod\."):
generate_numpy_fails()


@workflow
Expand Down

0 comments on commit 9e0dbf5

Please sign in to comment.