From 54923afab4ef50c59465135e6f3c71358fdb8f35 Mon Sep 17 00:00:00 2001 From: Eduardo Apolinario <653394+eapolinario@users.noreply.github.com> Date: Tue, 15 Oct 2024 13:38:12 -0400 Subject: [PATCH] More instance generic checks (#2813) (#2817) * don't check sub-types * update test * lint * forgot to switch back to instance generic --------- Signed-off-by: Yee Hing Tong Co-authored-by: Yee Hing Tong --- flytekit/core/type_engine.py | 14 -------------- tests/flytekit/unit/core/test_dynamic.py | 18 ++++++++++++++++++ tests/flytekit/unit/core/test_type_engine.py | 2 +- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index d7a6aca75d..800a6345c1 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -150,24 +150,10 @@ def type_assertions_enabled(self) -> bool: def isinstance_generic(self, obj, generic_alias): origin = get_origin(generic_alias) # list from list[int]) - args = get_args(generic_alias) # (int,) from list[int] if not isinstance(obj, origin): raise TypeTransformerFailedError(f"Value '{obj}' is not of container type {origin}") - # Optionally check the type of elements if it's a collection like list or dict - if origin in {list, tuple, set}: - for item in obj: - self.assert_type(args[0], item) - return - - if origin is dict: - key_type, value_type = args - for k, v in obj.items(): - self.assert_type(key_type, k) - self.assert_type(value_type, v) - return - def assert_type(self, t: Type[T], v: T): if sys.version_info >= (3, 10): import types diff --git a/tests/flytekit/unit/core/test_dynamic.py b/tests/flytekit/unit/core/test_dynamic.py index 72e4c9b244..80350334ff 100644 --- a/tests/flytekit/unit/core/test_dynamic.py +++ b/tests/flytekit/unit/core/test_dynamic.py @@ -19,6 +19,8 @@ from flytekit.tools.translator import get_serializable_task from flytekit.types.file import FlyteFile +pd = pytest.importorskip("pandas") + settings = flytekit.configuration.SerializationSettings( project="test_proj", domain="test_domain", @@ -373,3 +375,19 @@ def dynamic_task() -> List[FlyteFile]: ) as new_ctx: with pytest.raises(FlyteUserRuntimeException): dynamic_task.dispatch_execute(new_ctx, input_literal_map) + + +def test_dyn_pd(): + @task + def nested_task() -> pd.DataFrame: # type: ignore + return pd.DataFrame({"a": [1, 2, 3]}) + + @dynamic + def my_dynamic() -> list[pd.DataFrame]: # type: ignore + dfs = [] + for i in range(3): + dfs.append(nested_task()) + return dfs + + list_pd = my_dynamic() + assert len(list_pd) == 3 diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index e6b4acd485..f7cc325b7a 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3481,7 +3481,7 @@ def test_generic_errors_and_empty(): with pytest.raises(TypeTransformerFailedError): TypeEngine.to_literal(ctx, {"a": 3}, pt, lt) - with pytest.raises(TypeTransformerFailedError): + with pytest.raises(ValueError): TypeEngine.to_literal(ctx, {3: "a"}, pt, lt) # Test lists