Skip to content

Commit

Permalink
Instance generic empty case (#2802)
Browse files Browse the repository at this point in the history
  • Loading branch information
wild-endeavor authored Oct 11, 2024
1 parent 014d08c commit eaa5cfe
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 7 deletions.
8 changes: 2 additions & 6 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,18 +159,14 @@ def isinstance_generic(self, obj, generic_alias):
if origin in {list, tuple, set}:
for item in obj:
self.assert_type(args[0], item)
return
raise TypeTransformerFailedError(f"Not all items in '{obj}' are of type {args[0]}")
return

if origin is dict:
key_type, value_type = args
for k, v in obj.items():
self.assert_type(key_type, k)
self.assert_type(value_type, v)
return
raise TypeTransformerFailedError(f"Not all values in '{obj}' are of type {value_type}")

return
return

def assert_type(self, t: Type[T], v: T):
if sys.version_info >= (3, 10):
Expand Down
1 change: 1 addition & 0 deletions flytekit/tools/script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def ls_files(
else:
all_files = list_all_files(source_path, deref_symlinks, ignore_group)

all_files.sort()
hasher = hashlib.md5()
for abspath in all_files:
relpath = os.path.relpath(abspath, source_path)
Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/cli/pyflyte/test_script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_list_dir(dummy_dir_structure):
files, d = ls_files(str(dummy_dir_structure), CopyFileDetection.ALL)
assert len(files) == 5
if os.name != "nt":
assert d == "c092f1b85f7c6b2a71881a946c00a855"
assert d == "b6907fd823a45e26c780a4ba62111243"


def test_list_filtered_on_modules(dummy_dir_structure):
Expand Down
29 changes: 29 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3466,3 +3466,32 @@ def test_option_list_with_pipe_2():

with pytest.raises(TypeTransformerFailedError):
TypeEngine.to_literal(ctx, [[{"a": "one"}], None, [{"b": 3}]], pt, lt)


@pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10, 585 requires >=3.9")
def test_generic_errors_and_empty():
# Test dictionaries
pt = dict[str, str]
lt = TypeEngine.to_literal_type(pt)

ctx = FlyteContextManager.current_context()
lit = TypeEngine.to_literal(ctx, {}, pt, lt)
lit = TypeEngine.to_literal(ctx, {"a": "b"}, pt, lt)

with pytest.raises(TypeTransformerFailedError):
TypeEngine.to_literal(ctx, {"a": 3}, pt, lt)

with pytest.raises(TypeTransformerFailedError):
TypeEngine.to_literal(ctx, {3: "a"}, pt, lt)

# Test lists
pt = list[str]
lt = TypeEngine.to_literal_type(pt)
lit = TypeEngine.to_literal(ctx, [], pt, lt)
lit = TypeEngine.to_literal(ctx, ["a"], pt, lt)

with pytest.raises(TypeTransformerFailedError):
TypeEngine.to_literal(ctx, {"a": 3}, pt, lt)

with pytest.raises(TypeTransformerFailedError):
TypeEngine.to_literal(ctx, [3], pt, lt)

0 comments on commit eaa5cfe

Please sign in to comment.