Skip to content

Commit

Permalink
More instance generic checks (#2813) (#2817)
Browse files Browse the repository at this point in the history
* don't check sub-types



* update test



* lint



* forgot to switch back to instance generic



---------

Signed-off-by: Yee Hing Tong <[email protected]>
Co-authored-by: Yee Hing Tong <[email protected]>
  • Loading branch information
eapolinario and wild-endeavor authored Oct 15, 2024
1 parent eaa5cfe commit 54923af
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 15 deletions.
14 changes: 0 additions & 14 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,24 +150,10 @@ def type_assertions_enabled(self) -> bool:

def isinstance_generic(self, obj, generic_alias):
origin = get_origin(generic_alias) # list from list[int])
args = get_args(generic_alias) # (int,) from list[int]

if not isinstance(obj, origin):
raise TypeTransformerFailedError(f"Value '{obj}' is not of container type {origin}")

# Optionally check the type of elements if it's a collection like list or dict
if origin in {list, tuple, set}:
for item in obj:
self.assert_type(args[0], item)
return

if origin is dict:
key_type, value_type = args
for k, v in obj.items():
self.assert_type(key_type, k)
self.assert_type(value_type, v)
return

def assert_type(self, t: Type[T], v: T):
if sys.version_info >= (3, 10):
import types
Expand Down
18 changes: 18 additions & 0 deletions tests/flytekit/unit/core/test_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from flytekit.tools.translator import get_serializable_task
from flytekit.types.file import FlyteFile

pd = pytest.importorskip("pandas")

settings = flytekit.configuration.SerializationSettings(
project="test_proj",
domain="test_domain",
Expand Down Expand Up @@ -373,3 +375,19 @@ def dynamic_task() -> List[FlyteFile]:
) as new_ctx:
with pytest.raises(FlyteUserRuntimeException):
dynamic_task.dispatch_execute(new_ctx, input_literal_map)


def test_dyn_pd():
@task
def nested_task() -> pd.DataFrame: # type: ignore
return pd.DataFrame({"a": [1, 2, 3]})

@dynamic
def my_dynamic() -> list[pd.DataFrame]: # type: ignore
dfs = []
for i in range(3):
dfs.append(nested_task())
return dfs

list_pd = my_dynamic()
assert len(list_pd) == 3
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3481,7 +3481,7 @@ def test_generic_errors_and_empty():
with pytest.raises(TypeTransformerFailedError):
TypeEngine.to_literal(ctx, {"a": 3}, pt, lt)

with pytest.raises(TypeTransformerFailedError):
with pytest.raises(ValueError):
TypeEngine.to_literal(ctx, {3: "a"}, pt, lt)

# Test lists
Expand Down

0 comments on commit 54923af

Please sign in to comment.