diff --git a/tests/test_type_based_replacement.py b/tests/test_type_based_replacement.py index 6f6b5be..30e3f6a 100644 --- a/tests/test_type_based_replacement.py +++ b/tests/test_type_based_replacement.py @@ -16,6 +16,7 @@ remap_by_types, remap_from_lambda, ) +from func_adl.util_types import is_iterable, unwrap_iterable class Track: @@ -547,7 +548,20 @@ def test_dictionary_bad_key(): def test_dictionary_through_Select(): """Make sure the Select statement carries the typing all the way through""" - assert False + + s = ast_lambda("e.Jets().Select(lambda j: {'pt': j.pt(), 'eta': j.eta()})") + objs = ObjectStream[Event](ast.Name(id="e", ctx=ast.Load())) + + _, _, expr_type = remap_by_types(objs, "e", Event, s) + + assert is_iterable(expr_type) + obj_itr = unwrap_iterable(expr_type) + assert isclass(obj_itr) + sig = inspect.signature(obj_itr.__init__) + assert len(sig.parameters) == 3 + assert "pt" in sig.parameters + j_info = sig.parameters["pt"] + assert j_info.annotation == float def test_indexed_tuple():