class _extract_metadata(FuncADLNodeTransformer):
    '''Extract all the metadata from an expression, and remove the
    metadata nodes. Assume the metadata can all be bubbled to the top and
    has equal precedence.
    '''
    def __init__(self):
        super().__init__()
        # One entry per `MetaData` call, appended in the order the calls are visited.
        self._metadata = []

    @property
    def metadata(self) -> List[Dict[str, str]]:
        '''Returns the metadata found while scanning expressions
        in the order it was encountered.

        Returns:
            List[Dict[str, str]]: List of all metadata found.
        '''
        return self._metadata

    def visit_Call(self, node: ast.Call):
        '''Detect a `MetaData` call, and remove it, storing the
        information.

        A call written as `MetaData(<source>, <literal>)` is replaced by the
        result of visiting `<source>`; `<literal>` is evaluated with
        `ast.literal_eval` and appended to the collected metadata. Any other
        call is handed to the base class' `visit_Call`.

        Args:
            node (ast.Call): The call node to process.

        Returns:
            ast.AST: The ast without the call node (if need be).

        Raises:
            IndexError: A `MetaData` call has fewer than two positional arguments.
            ValueError: The second argument of a `MetaData` call is not a
                literal that `ast.literal_eval` accepts.
        '''
        if isinstance(node.func, ast.Name) and node.func.id == 'MetaData':
            # Record this call before descending into its source so that an
            # outer `MetaData` is listed ahead of any nested inside it.
            self._metadata.append(ast.literal_eval(node.args[1]))
            return self.visit(node.args[0])
        return super().visit_Call(node)


def extract_metadata(a: ast.AST) -> Tuple[ast.AST, List[Dict[str, str]]]:
    '''Returns the expression with extracted metadata and the metadata, in order
    from the outermost to the innermost `MetaData` expressions.

    Args:
        a (ast.AST): The AST potentially containing metadata definitions

    Returns:
        Tuple[ast.AST, List[Dict[str, str]]]: the AST without the metadata references
        and a list of metadata found.

    NOTE(review): whether `a` itself is modified in place depends on
    `FuncADLNodeTransformer`; callers that still need the original tree
    should confirm, or pass in a copy.
    '''
    extractor = _extract_metadata()
    stripped = extractor.visit(a)
    return stripped, extractor.metadata
import ast -from typing import Any, Awaitable, Callable, List, Optional, Union, cast +from typing import Any, Awaitable, Callable, Dict, List, Optional, Union, cast from make_it_sync import make_sync @@ -106,6 +106,15 @@ def Where(self, filter: Union[str, ast.Lambda, Callable]) -> 'ObjectStream': return ObjectStream(function_call("Where", [self._q_ast, cast(ast.AST, parse_as_ast(filter))])) + def MetaData(self, metadata: Dict[str, Any]) -> 'ObjectStream': + '''Add metadata to the current object stream. The metadata is an arbitrary set of string + key-value pairs. The backend must be able to properly interpret the metadata. + + Returns: + ObjectStream: A new stream, of the same type and contents, but with metadata added. + ''' + return ObjectStream(function_call("MetaData", [self._q_ast, as_ast(metadata)])) + def AsPandasDF(self, columns=[]) -> 'ObjectStream': r""" Return a pandas stream that contains one item, an pandas `DataFrame`. diff --git a/tests/ast/test_function_simplifier.py b/tests/ast/test_function_simplifier.py index 2d8c775..0cac703 100644 --- a/tests/ast/test_function_simplifier.py +++ b/tests/ast/test_function_simplifier.py @@ -1,13 +1,11 @@ import ast -from typing import Tuple, cast from astunparse import unparse from func_adl.ast.function_simplifier import (FuncADLIndexError, - make_args_unique, simplify_chained_calls) from tests.util_debug_ast import normalize_ast -from .utils import reset_ast_counters # NOQA +from .utils import reset_ast_counters, util_run_parse # NOQA def util_process(ast_in, ast_out): @@ -31,18 +29,6 @@ def util_process(ast_in, ast_out): return a_updated_raw -############## -# Test lambda copier -def util_run_parse(a_text: str) -> Tuple[ast.Lambda, ast.Lambda]: - module = ast.parse(a_text) - assert isinstance(module, ast.Module) - s = cast(ast.Expr, module.body[0]) - a = s.value - assert isinstance(a, ast.Lambda) - new_a = make_args_unique(a) - return (a, new_a) - - def test_lambda_copy_simple(): a, new_a = 
def compare_metadata(with_metadata: str, without_metadata: str) -> List[Dict[str, str]]:
    '''
    Compares two AST expressions after first removing all metadata references from the
    first expression. Returns a list of dictionaries of the found metadata.

    Args:
        with_metadata (str): Source text that may contain `MetaData` calls.
        without_metadata (str): Source text the first expression should match once
            its `MetaData` calls have been removed.

    Returns:
        List[Dict[str, str]]: The metadata reported by `extract_metadata`.
    '''
    a_with = ast.parse(with_metadata)
    a_without = ast.parse(without_metadata)

    a_removed, metadata = extract_metadata(a_with)

    # Compare the dumped form so the check is structural rather than by identity.
    assert ast.dump(a_removed) == ast.dump(a_without)
    return metadata


def test_no_metadata():
    'Make sure expression with no metadata is not changed'
    meta = compare_metadata("Select(jets, lambda j: j*2)", "Select(jets, lambda j: j*2)")
    assert len(meta) == 0


def test_simple_metadata():
    'Make sure expression with metadata correctly cleaned up and removed'
    meta = compare_metadata("MetaData(Select(jets, lambda j: j*2), {'hi': 'there'})",
                            "Select(jets, lambda j: j*2)")
    assert len(meta) == 1
    assert meta[0] == {'hi': 'there'}


def test_two_metadata():
    'Make sure nested metadata is all removed and is reported outermost first'
    meta = compare_metadata(
        "MetaData(Select(MetaData(jets, {'fork': 'dude'}), lambda j: j*2), {'hi': 'there'})",
        "Select(jets, lambda j: j*2)")
    assert len(meta) == 2
    assert meta[0] == {'hi': 'there'}
    assert meta[1] == {'fork': 'dude'}
fs.argument_var_counter = 0 yield fs.argument_var_counter = 0 + + +############## +# Test lambda copier +def util_run_parse(a_text: str) -> Tuple[ast.Lambda, ast.Lambda]: + module = ast.parse(a_text) + assert isinstance(module, ast.Module) + s = cast(ast.Expr, module.body[0]) + a = s.value + assert isinstance(a, ast.Lambda) + new_a = make_args_unique(a) + return (a, new_a) diff --git a/tests/test_object_stream.py b/tests/test_object_stream.py index 2d32684..3188349 100644 --- a/tests/test_object_stream.py +++ b/tests/test_object_stream.py @@ -84,6 +84,16 @@ def test_simple_query_awkward(): assert isinstance(r, ast.AST) +def test_metadata(): + r = my_event() \ + .MetaData({'one': 'two', 'two': 'three'}) \ + .SelectMany("lambda e: e.jets()") \ + .Select("lambda j: j.pT()") \ + .AsROOTTTree("junk.root", "analysis", "jetPT") \ + .value() + assert isinstance(r, ast.AST) + + def test_nested_query_rendered_correctly(): r = my_event() \ .Where("lambda e: e.jets.Select(lambda j: j.pT()).Where(lambda j: j > 10).Count() > 0") \