diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 800a6345c1..632976808d 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1525,8 +1525,7 @@ def __init__(self): @staticmethod def is_optional_type(t: Type) -> bool: - """Return True if `t` is a Union or Optional type.""" - return _is_union_type(t) or type(None) in get_args(t) + return _is_union_type(t) and type(None) in get_args(t) @staticmethod def get_sub_type_in_optional(t: Type[T]) -> Type[T]: diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index f7cc325b7a..223caaadf6 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -1707,6 +1707,8 @@ def test_union_transformer(): assert not UnionTransformer.is_optional_type(str) assert UnionTransformer.get_sub_type_in_optional(typing.Optional[int]) == int assert UnionTransformer.get_sub_type_in_optional(int | None) == int + assert not UnionTransformer.is_optional_type(typing.Union[int, str]) + assert UnionTransformer.is_optional_type(typing.Union[int, None]) def test_union_guess_type():