Skip to content

Commit

Permalink
Add a function to allow MetaData over the wire (#78)
Browse files Browse the repository at this point in the history
Provides language constructs for adding MetaData to the query
  • Loading branch information
gordonwatts authored Sep 16, 2021
1 parent 0535266 commit b9df64d
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 16 deletions.
6 changes: 6 additions & 0 deletions func_adl/ast/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from .meta_data import extract_metadata # NOQA
from .call_stack import argument_stack, stack_frame # NOQA
from .func_adl_ast_utils import FuncADLNodeVisitor, function_call # NOQA
from .aggregate_shortcuts import aggregate_node_transformer # NOQA
from .func_adl_ast_utils import change_extension_functions_to_calls # NOQA
from .function_simplifier import simplify_chained_calls # NOQA
56 changes: 56 additions & 0 deletions func_adl/ast/meta_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@

import ast
from func_adl.ast.func_adl_ast_utils import FuncADLNodeTransformer
from typing import Dict, List, Tuple


class _extract_metadata(FuncADLNodeTransformer):
'''Extract all the metadata from an expression, and remove the
metadata nodes. Assume the metadata can all be bubbled to the top and
has equal precedence.
'''
def __init__(self):
super().__init__()
self._metadata = []

@property
def metadata(self) -> List[Dict[str, str]]:
'''Returns the metadata found while scanning expressions
in the order it was encountered.
Returns:
List[Dict[str, str]]: List of all metadata found.
'''
return self._metadata

def visit_Call(self, node: ast.Call):
'''Detect a MetaData call, and remove it, storing the
information.
Args:
node (ast.Call): The call node to process.
Returns:
ast.AST: The ast without the call node (if need be).
'''
if isinstance(node.func, ast.Name) and node.func.id == 'MetaData':
self._metadata.append(ast.literal_eval(node.args[1]))
return self.visit(node.args[0])
return super().visit_Call(node)
return super().visit_Call(node)


def extract_metadata(a: ast.AST) -> Tuple[ast.AST, List[Dict[str, str]]]:
'''Returns the expresion with extracted metadata and the metadata, in order
from the outter most to the inner most `MetaData` expressions.
Args:
a (ast.AST): The AST potentially containing metadata definitions
Returns:
Tuple[ast.AST, List[Dict[str, str]]]: a new AST without the metadata references
and a list of metadata found.
'''
e = _extract_metadata()
a_new = e.visit(a)
return a_new, e.metadata
11 changes: 10 additions & 1 deletion func_adl/object_stream.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# An Object stream represents a stream of objects, floats, integers, etc.
import ast
from typing import Any, Awaitable, Callable, List, Optional, Union, cast
from typing import Any, Awaitable, Callable, Dict, List, Optional, Union, cast

from make_it_sync import make_sync

Expand Down Expand Up @@ -106,6 +106,15 @@ def Where(self, filter: Union[str, ast.Lambda, Callable]) -> 'ObjectStream':
return ObjectStream(function_call("Where",
[self._q_ast, cast(ast.AST, parse_as_ast(filter))]))

def MetaData(self, metadata: Dict[str, Any]) -> 'ObjectStream':
'''Add metadata to the current object stream. The metadata is an arbitrary set of string
key-value pairs. The backend must be able to properly interpret the metadata.
Returns:
ObjectStream: A new stream, of the same type and contents, but with metadata added.
'''
return ObjectStream(function_call("MetaData", [self._q_ast, as_ast(metadata)]))

def AsPandasDF(self, columns=[]) -> 'ObjectStream':
r"""
Return a pandas stream that contains one item, an pandas `DataFrame`.
Expand Down
16 changes: 1 addition & 15 deletions tests/ast/test_function_simplifier.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import ast
from typing import Tuple, cast

from astunparse import unparse
from func_adl.ast.function_simplifier import (FuncADLIndexError,
make_args_unique,
simplify_chained_calls)
from tests.util_debug_ast import normalize_ast

from .utils import reset_ast_counters # NOQA
from .utils import reset_ast_counters, util_run_parse # NOQA


def util_process(ast_in, ast_out):
Expand All @@ -31,18 +29,6 @@ def util_process(ast_in, ast_out):
return a_updated_raw


##############
# Test lambda copier
def util_run_parse(a_text: str) -> Tuple[ast.Lambda, ast.Lambda]:
module = ast.parse(a_text)
assert isinstance(module, ast.Module)
s = cast(ast.Expr, module.body[0])
a = s.value
assert isinstance(a, ast.Lambda)
new_a = make_args_unique(a)
return (a, new_a)


def test_lambda_copy_simple():
a, new_a = util_run_parse('lambda a: a')
assert unparse(new_a).strip() == "(lambda arg_0: arg_0)"
Expand Down
43 changes: 43 additions & 0 deletions tests/ast/test_meta_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@


import ast
from func_adl.ast.meta_data import extract_metadata
from typing import Dict, List


def compare_metadata(with_metadata: str, without_metadata: str) -> List[Dict[str, str]]:
'''
Compares two AST expressions after first removing all metadata references from the
first expression. Returns a list of dictionaries of the found metadata
'''
a_with = ast.parse(with_metadata)
a_without = ast.parse(without_metadata)

a_removed, metadata = extract_metadata(a_with)

assert ast.dump(a_removed) == ast.dump(a_without)
return metadata


def test_no_metadata():
'Make sure expression with no metadata is not changed'
meta = compare_metadata("Select(jets, lambda j: j*2)", "Select(jets, lambda j: j*2)")
assert len(meta) == 0


def test_simple_metadata():
'Make sure expression with metadata correctly cleaned up and removed'
meta = compare_metadata("MetaData(Select(jets, lambda j: j*2), {'hi': 'there'})",
"Select(jets, lambda j: j*2)")
assert len(meta) == 1
assert meta[0] == {'hi': 'there'}


def test_two_metadata():
'Make sure expression with no metadata is not changed'
meta = compare_metadata(
"MetaData(Select(MetaData(jets, {'fork': 'dude'}), lambda j: j*2), {'hi': 'there'})",
"Select(jets, lambda j: j*2)")
assert len(meta) == 2
assert meta[0] == {'hi': 'there'}
assert meta[1] == {'fork': 'dude'}
16 changes: 16 additions & 0 deletions tests/ast/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import ast
from typing import Tuple, cast

import pytest
from func_adl.ast.function_simplifier import make_args_unique


@pytest.fixture(autouse=True)
Expand All @@ -7,3 +11,15 @@ def reset_ast_counters():
fs.argument_var_counter = 0
yield
fs.argument_var_counter = 0


##############
# Test lambda copier
def util_run_parse(a_text: str) -> Tuple[ast.Lambda, ast.Lambda]:
module = ast.parse(a_text)
assert isinstance(module, ast.Module)
s = cast(ast.Expr, module.body[0])
a = s.value
assert isinstance(a, ast.Lambda)
new_a = make_args_unique(a)
return (a, new_a)
10 changes: 10 additions & 0 deletions tests/test_object_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,16 @@ def test_simple_query_awkward():
assert isinstance(r, ast.AST)


def test_metadata():
r = my_event() \
.MetaData({'one': 'two', 'two': 'three'}) \
.SelectMany("lambda e: e.jets()") \
.Select("lambda j: j.pT()") \
.AsROOTTTree("junk.root", "analysis", "jetPT") \
.value()
assert isinstance(r, ast.AST)


def test_nested_query_rendered_correctly():
r = my_event() \
.Where("lambda e: e.jets.Select(lambda j: j.pT()).Where(lambda j: j > 10).Count() > 0") \
Expand Down

0 comments on commit b9df64d

Please sign in to comment.