Skip to content

Commit

Permalink
Allow types for parameterized functions (#138)
Browse files Browse the repository at this point in the history
* Fix up pylance and flak8 errors

* Adding more tests of what we want covered

* Add a test to make sure the type based stuff is working well.

* Implement ast checking and constant wrapping for types and variables

* Add explicit AST checking

* Fix up pylance error

* Add check for bad variable and flake8 fixes

* Fix up flake8 and pylance errors

* Remove 3.7 testing support

* 3.8 and above use ast.Constant - get rid of 3.7 support.

* Fix up some minor test issues for python 3.8

* Merge branch 'master' into feat/pr_types
  • Loading branch information
gordonwatts authored Apr 25, 2024
1 parent b283b72 commit 2891ca1
Show file tree
Hide file tree
Showing 7 changed files with 240 additions and 160 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: [3.7, 3.8, 3.9, "3.10", 3.11, 3.12]
python-version: [3.8, 3.9, "3.10", 3.11, 3.12]

steps:
- uses: actions/checkout@v3
Expand Down
19 changes: 8 additions & 11 deletions func_adl/object_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from func_adl.util_types import unwrap_iterable

from .util_ast import as_ast, function_call, parse_as_ast
from .util_ast import as_ast, check_ast, function_call, parse_as_ast

# Attribute that will be used to store the executor reference
executor_attr_name = "_func_adl_executor"
Expand Down Expand Up @@ -66,7 +66,7 @@ class ObjectStream(Generic[T]):
`Iterable[T]`).
"""

def __init__(self, the_ast: ast.AST, item_type: Type = Any):
def __init__(self, the_ast: ast.AST, item_type: Type = Any): # type: ignore
r"""
Initialize the stream with the ast that will produce this stream of objects.
The user will almost never use this initializer.
Expand All @@ -88,9 +88,7 @@ def query_ast(self) -> ast.AST:
"""
return self._q_ast

def clone_with_new_ast(
self, new_ast: ast.AST, new_type: type[S]
) -> ObjectStream[S]:
def clone_with_new_ast(self, new_ast: ast.AST, new_type: type[S]) -> ObjectStream[S]:
clone = copy.deepcopy(self)
clone._q_ast = new_ast
clone._item_type = new_type
Expand Down Expand Up @@ -121,6 +119,7 @@ def SelectMany(
n_stream, n_ast, rtn_type = remap_from_lambda(
self, _local_simplification(parse_as_ast(func, "SelectMany"))
)
check_ast(n_ast)

return self.clone_with_new_ast(
function_call("SelectMany", [n_stream.query_ast, cast(ast.AST, n_ast)]),
Expand Down Expand Up @@ -149,14 +148,13 @@ def Select(self, f: Union[str, ast.Lambda, Callable[[T], S]]) -> ObjectStream[S]
n_stream, n_ast, rtn_type = remap_from_lambda(
self, _local_simplification(parse_as_ast(f, "Select"))
)
check_ast(n_ast)
return self.clone_with_new_ast(
function_call("Select", [n_stream.query_ast, cast(ast.AST, n_ast)]),
rtn_type,
)

def Where(
self, filter: Union[str, ast.Lambda, Callable[[T], bool]]
) -> ObjectStream[T]:
def Where(self, filter: Union[str, ast.Lambda, Callable[[T], bool]]) -> ObjectStream[T]:
r"""
Filter the object stream, allowing only items for which `filter` evaluates as true through.
Expand All @@ -177,6 +175,7 @@ def Where(
n_stream, n_ast, rtn_type = remap_from_lambda(
self, _local_simplification(parse_as_ast(filter, "Where"))
)
check_ast(n_ast)
if rtn_type != bool:
raise ValueError(f"The Where filter must return a boolean (not {rtn_type})")
return self.clone_with_new_ast(
Expand Down Expand Up @@ -329,9 +328,7 @@ def AsParquetFiles(
columns = [columns]

return ObjectStream[ReturnedDataPlaceHolder](
function_call(
"ResultParquet", [self._q_ast, as_ast(columns), as_ast(filename)]
)
function_call("ResultParquet", [self._q_ast, as_ast(columns), as_ast(filename)])
)

as_parquet = AsParquetFiles
Expand Down
29 changes: 16 additions & 13 deletions func_adl/type_based_replacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,11 @@
("function", Callable),
(
"processor_function",
Optional[Callable[[ObjectStream[U], ast.Call], Tuple[ObjectStream[U], ast.AST]]],
Optional[
Callable[
[ObjectStream[U], ast.Call], Tuple[ObjectStream[U], ast.AST] # type: ignore
]
],
),
],
)
Expand All @@ -56,8 +60,7 @@ def _load_default_global_functions():
"Define the python standard functions that map straight through"
# TODO: Add in other functions

def my_abs(x: float) -> float:
...
def my_abs(x: float) -> float: ... # noqa

_global_functions["abs"] = _FuncAdlFunction("abs", my_abs, None)

Expand Down Expand Up @@ -87,7 +90,7 @@ def register_func_adl_function(
Tuple[ObjectStream[T], ast.AST]]):
The processor function that can modify the stream, etc.
"""
info = _FuncAdlFunction(function.__name__, function, processor_function)
info = _FuncAdlFunction(function.__name__, function, processor_function) # type: ignore
_global_functions[info.name] = info


Expand Down Expand Up @@ -132,6 +135,7 @@ def MySqrt(x: float) -> float:
processor (Optional[Callable[[ObjectStream[W], ast.Call],
Tuple[ObjectStream[W], ast.AST]]], optional): [description]. Defaults to None.
"""

# TODO: Do we really need to register this inside the decorator? Can we just register
# and return the function?
def decorate(function: Callable):
Expand Down Expand Up @@ -218,7 +222,7 @@ def register_func_adl_os_collection(c: C_TYPE) -> C_TYPE:
Returns:
[type]: [description]
"""
_g_collection_classes[c] = CollectionClassInfo(c)
_g_collection_classes[c] = CollectionClassInfo(c) # type: ignore
return c


Expand All @@ -236,18 +240,17 @@ class ObjectStreamInternalMethods(ObjectStream[StreamItem]):
to follow generics in python at runtime (think of this as poor man's type resolution).
"""

def __init__(self, a: ast.AST, item_type: Type):
super().__init__(a, item_type)
def __init__(self, a: ast.AST, item_type: Union[Type, object]):
super().__init__(a, item_type) # type: ignore

@property
def item_type(self) -> Type:
return self._item_type

def First(self) -> StreamItem:
return self.item_type
return self.item_type # type: ignore

def Count(self) -> int:
...
def Count(self) -> int: ... # noqa

# TODO: Add all other items that belong here

Expand Down Expand Up @@ -448,7 +451,7 @@ class _MethodTypeReturnInfo:
node: ast.Call

# Return type
return_type: Type
return_type: Union[Type, object]

# If full type resolution was done (e.g. lambda following), or
# if there were no lambda arguments to follow.
Expand Down Expand Up @@ -485,15 +488,15 @@ def remap_by_types(
class type_transformer(ast.NodeTransformer, Generic[S]):
def __init__(self, o_stream: ObjectStream[S]):
self._stream = o_stream
self._found_types: Dict[Union[str, object], type] = {var_name: var_type}
self._found_types: Dict[Union[str, object], Union[type, object]] = {var_name: var_type}

@property
def stream(self) -> ObjectStream[S]:
return self._stream # type: ignore

def lookup_type(self, name: Union[str, object]) -> Type:
"Return the type for a node, Any if we do not know about it"
return self._found_types.get(name, Any)
return self._found_types.get(name, Any) # type: ignore

def process_method_call_on_stream_obj(
self,
Expand Down
67 changes: 37 additions & 30 deletions func_adl/util_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,17 @@
from types import ModuleType
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union, cast

# Some functions to enable backwards compatibility.
# Capability may be degraded in older versions.
if sys.version_info >= (3, 8): # pragma: no cover

def as_literal(p: Union[str, int, float, bool, None]) -> ast.Constant:
return ast.Constant(value=p, kind=None)
def as_literal(p: Union[str, int, float, bool, None]) -> ast.Constant:
"""Convert a python constant into an AST constant node.
else: # pragma: no cover
Args:
p (Union[str, int, float, bool, None]): what should be wrapped
def as_literal(p: Union[str, int, float, bool, None]):
if isinstance(p, str):
return ast.Str(p)
elif isinstance(p, (int, float)):
return ast.Num(p)
elif isinstance(p, bool):
return ast.NameConstant(p)
elif p is None:
return ast.NameConstant(None)
else:
raise ValueError(f"Unknown type {type(p)} - do not know how to make a literal!")
Returns:
ast.Constant: The ast constant node that represents the value.
"""
return ast.Constant(value=p, kind=None)


def as_ast(p_var: Any) -> ast.AST:
Expand Down Expand Up @@ -333,19 +324,10 @@ def visit_Name(self, node: ast.Name) -> Any:

if node.id in self._lookup_dict:
v = self._lookup_dict[node.id]
if not callable(v):
# Modules should be sent on down to be dealt with by the
# backend.
if not isinstance(v, ModuleType):
legal_capture_types = [str, int, float, bool, complex, str, bytes]
if type(v) not in legal_capture_types:
raise ValueError(
f"Do not know how to capture data type '{type(v).__name__}' for "
f"variable '{node.id}' - only "
f"{', '.join([c.__name__ for c in legal_capture_types])} are "
"supported."
)
return as_literal(v)
if not callable(v) and not isinstance(v, ModuleType):
# If it is something we know how to make into a literal, we just send it down
# like that.
return as_literal(v)
return node

def visit_Lambda(self, node: ast.Lambda) -> Any:
Expand Down Expand Up @@ -728,3 +710,28 @@ def visit_Call(self, node: ast.Call):
callback(node.args[1]) # type: ignore

metadata_finder().visit(a)


g_legal_capture_types = (str, int, float, bool, complex, str, bytes, ModuleType)


def check_ast(a: ast.AST):
"""Check to make sure the ast does not have anything we can't send over the wire
in `qastle` or similar.
Args:
a (ast.AST): The AST to check
Raises:
ValueError: If something unsupported is found.
"""

class ConstantTypeChecker(ast.NodeVisitor):
def visit_Constant(self, node: ast.Constant):
if not isinstance(node.value, g_legal_capture_types):
raise ValueError(f"Invalid constant type: {type(node.value)} for {ast.dump(node)}")
self.generic_visit(node)

# Usage example:
checker = ConstantTypeChecker()
checker.visit(a)
19 changes: 19 additions & 0 deletions tests/test_object_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,25 @@ def test_two_simple_query():
assert ast.dump(r1) == ast.dump(r2)


def test_query_bad_variable():
class my_type:
def __init__(self, n):
self._n = 10

my_10 = my_type(10)

with pytest.raises(ValueError) as e:
(
my_event()
.SelectMany(lambda e: e.jets())
.Select(lambda j: j.pT() > my_10)
.AsROOTTTree("junk.root", "analysis", "jetPT")
.value()
)

assert "my_type" in str(e)


def test_with_types():
r1 = my_event_with_type().SelectMany(lambda e: e.Jets("jets"))
r = r1.Select(lambda j: j.eta()).value()
Expand Down
Loading

0 comments on commit 2891ca1

Please sign in to comment.