Skip to content

Commit a86a532

Browse files
committed
Use the Pandas expr tree for preflighting.
Requires `extract_nest_names` to be a method on `NestedFrame` so that the evaluation context is available at parsing time, since the Pandas Expr parsing does some eager evaluation. Resolves #174 .
1 parent 402ab66 commit a86a532

File tree

4 files changed

+151
-76
lines changed

4 files changed

+151
-76
lines changed

src/nested_pandas/nestedframe/core.py

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
from pandas._libs import lib
1111
from pandas._typing import Any, AnyAll, Axis, IndexLabel
1212
from pandas.api.extensions import no_default
13+
from pandas.core.computation import ops
14+
from pandas.core.computation.eval import Expr, ensure_scope
1315
from pandas.core.computation.expr import PARSERS, PandasExprVisitor
16+
from pandas.core.computation.parsing import clean_column_name
1417

15-
from nested_pandas.nestedframe.utils import extract_nest_names
1618
from nested_pandas.series.dtype import NestedDtype
1719
from nested_pandas.series.packer import pack, pack_lists, pack_sorted_df_into_struct
1820

@@ -79,6 +81,22 @@ class _NestResolver(dict):
7981
def __init__(self, outer: NestedFrame):
8082
self._outer = outer
8183
super().__init__()
84+
# Pre-load the field resolvers for all columns which are known at present.
85+
for column in outer.nested_columns:
86+
self._initialize_field_resolver(column, outer)
87+
88+
def _initialize_field_resolver(self, column: str, outer: NestedFrame):
89+
"""
90+
Initialize a resolver for the given nested column, and also an alias
91+
for it, in the case of column names that have spaces or are otherwise
92+
not identifier-like.
93+
"""
94+
super().__setitem__(column, _NestedFieldResolver(column, outer))
95+
clean_id = clean_column_name(column)
96+
# And once more for the cleaned name, if it's different.
97+
# This allows us to capture references to it from the Pandas evaluator.
98+
if clean_id != column:
99+
super().__setitem__(clean_id, _NestedFieldResolver(column, outer))
82100

83101
def __contains__(self, item):
84102
top_nest = item if "." not in item else item.split(".")[0].strip()
@@ -89,7 +107,7 @@ def __getitem__(self, item):
89107
if not super().__contains__(top_nest):
90108
if top_nest not in self._outer.nested_columns:
91109
raise KeyError(f"Unknown nest {top_nest}")
92-
super().__setitem__(top_nest, _NestedFieldResolver(top_nest, self._outer))
110+
self._initialize_field_resolver(top_nest, self._outer)
93111
return super().__getitem__(top_nest)
94112

95113
def __setitem__(self, item, _):
@@ -133,6 +151,48 @@ def __getattr__(self, item_name: str):
133151
raise AttributeError(f"No attribute {item_name}")
134152

135153

154+
def _subexprs_by_nest(parents: list, node) -> dict[str, list]:
155+
"""
156+
Given an expression which contains references to both base and nested
157+
columns, return a dictionary of the sub-expressions that should be
158+
evaluated independently, keyed by nesting context.
159+
160+
The key of the dictionary is the name of the nested column, and will
161+
be a blank string in the case of base columns. The value is a list
162+
of the parent nodes that lead to sub-expressions that can be evaluated
163+
successfully.
164+
165+
While this is not in use today for automatically splitting expressions,
166+
it can be used to detect whether an expression is suitably structured
167+
for evaluation: the returned dictionary should have a single key.
168+
"""
169+
if isinstance(node, ops.Term) and not isinstance(node, ops.Constant):
170+
if isinstance(node.value, _SeriesFromNest):
171+
return {node.value.nest_name: parents}
172+
return {getattr(node, "upper_name", ""): parents}
173+
if not isinstance(node, ops.Op):
174+
return {}
175+
sources = [getattr(node, "lhs", None), getattr(node, "rhs", None)]
176+
result: dict[str, list] = {}
177+
for source in sources:
178+
child = _subexprs_by_nest(parents, source)
179+
for k, v in child.items():
180+
result.setdefault(k, []).append(v)
181+
# After a complete traversal across sources, check for any necessary splits.
182+
# If it's homogenous, move the split-node up the tree.
183+
if len(result) == 1:
184+
# Let the record of each parent node drift up the tree,
185+
# and merge the subtrees into a single node, since by definition,
186+
# this node is homogeneous over all of its children, and can
187+
# be evaluated in a single step.
188+
result = {k: [node] for k in result}
189+
# If the result is either empty or has more than one key, leave the result
190+
# alone. Each key represents a different nest (with a blank string for the base),
191+
# and the value is the highest point in the expression tree where the expression
192+
# was still within a single nest.
193+
return result
194+
195+
136196
class NestedFrame(pd.DataFrame):
137197
"""A Pandas Dataframe extension with support for nested structure.
138198
@@ -457,6 +517,39 @@ def eval(self, expr: str, *, inplace: bool = False, **kwargs) -> Any | None:
457517
kwargs["parser"] = "nested-pandas"
458518
return super().eval(expr, **kwargs)
459519

520+
def extract_nest_names(
521+
self,
522+
expr: str,
523+
local_dict=None,
524+
global_dict=None,
525+
resolvers=(),
526+
level: int = 0,
527+
target=None,
528+
**kwargs,
529+
) -> set[str]:
530+
"""
531+
Given a string expression, parse it and visit the resulting expression tree,
532+
surfacing the nesting types. The purpose is to identify expressions that attempt
533+
to mix base and nested columns, or columns from two different nests.
534+
"""
535+
index_resolvers = self._get_index_resolvers()
536+
column_resolvers = self._get_cleaned_column_resolvers()
537+
resolvers = resolvers + (_NestResolver(self), column_resolvers, index_resolvers)
538+
# Parser needs to be the "nested-pandas" parser.
539+
# We also need the same variable context that eval() will have, so that
540+
# backtick-quoted names are substituted as expected.
541+
env = ensure_scope(
542+
level + 1,
543+
global_dict=global_dict,
544+
local_dict=local_dict,
545+
resolvers=resolvers,
546+
target=target,
547+
)
548+
parsed_expr = Expr(expr, parser="nested-pandas", env=env)
549+
expr_tree = parsed_expr.terms
550+
separable = _subexprs_by_nest([], expr_tree)
551+
return set(separable.keys())
552+
460553
def query(self, expr: str, *, inplace: bool = False, **kwargs) -> NestedFrame | None:
461554
"""
462555
Query the columns of a NestedFrame with a boolean expression. Specified
@@ -514,7 +607,7 @@ def query(self, expr: str, *, inplace: bool = False, **kwargs) -> NestedFrame |
514607
# At present, the query expression must be either entirely within a
515608
# single nest, or have nothing but base columns. Mixed structures are not
516609
# supported, so preflight the expression.
517-
nest_names = extract_nest_names(expr)
610+
nest_names = self.extract_nest_names(expr, **kwargs)
518611
if len(nest_names) > 1:
519612
raise ValueError("Queries cannot target multiple structs/layers, write a separate query for each")
520613
result = self.eval(expr, **kwargs)

src/nested_pandas/nestedframe/utils.py

Lines changed: 0 additions & 61 deletions
This file was deleted.

tests/nested_pandas/nestedframe/test_nestedframe.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,25 @@ def test_query():
594594
assert base["nested.d"].shape == (2,)
595595

596596

597+
def test_query_on_non_identifier_columns():
598+
"""
599+
Column names very often follow the same rules as Python identifiers, but
600+
they are not required to. Test that query() can handle such names.
601+
"""
602+
# Taken from GH#174
603+
nf = NestedFrame(data={"dog": [1, 2, 3], "good dog": [2, 4, 6]}, index=[0, 1, 2])
604+
nested = pd.DataFrame(
605+
data={"a": [0, 2, 4, 1, 4, 3, 1, 4, 1], "b": [5, 4, 7, 5, 3, 1, 9, 3, 4]},
606+
index=[0, 0, 0, 1, 1, 1, 2, 2, 2],
607+
)
608+
nf = nf.add_nested(nested, "bad dog")
609+
nf2 = nf.query("`good dog` > 3")
610+
assert nf.shape == (3, 3)
611+
assert nf2.shape == (2, 3)
612+
nf3 = nf.query("`bad dog`.a > 2")
613+
assert nf3["bad dog"].nest["a"].size == 4
614+
615+
597616
def test_dropna():
598617
"""Test that dropna works on all layers"""
599618

tests/nested_pandas/utils/test_utils.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import pandas as pd
33
import pytest
44
from nested_pandas import NestedFrame
5-
from nested_pandas.nestedframe.utils import extract_nest_names
65
from nested_pandas.utils import count_nested
76

87

@@ -52,16 +51,41 @@ def test_check_expr_nesting():
5251
used to ensure that an expression-based query does not try to combine base and nested
5352
sub-expressions.
5453
"""
55-
assert extract_nest_names("a > 2 & nested.c > 1") == {"", "nested"}
56-
assert extract_nest_names("(nested.c > 1) and (nested.d>2)") == {"nested"}
57-
assert extract_nest_names("-1.52e-5 < abc < 35.2e2") == {""}
58-
assert extract_nest_names("(n.a > 1) and ((b + c) > (d - 1e-8)) or n.q > c") == {"n", ""}
54+
base = NestedFrame(data={"a": [1, 2, 3], "b": [2, np.nan, 6]}, index=[0, 1, 2])
55+
nested = pd.DataFrame(
56+
data={
57+
"c": [0, 2, 4, 1, np.nan, 3, 1, 4, 1],
58+
"d": [5, 4, 7, 5, 3, 1, 9, 3, 4],
59+
"label": ["b", "a", "b", "b", "a", "a", "b", "a", "b"],
60+
},
61+
index=[0, 0, 0, 1, 1, 1, 2, 2, 2],
62+
)
63+
b1 = base.add_nested(nested, "nested")
64+
assert b1.extract_nest_names("a > 2 & nested.c > 1") == {"", "nested"}
65+
assert b1.extract_nest_names("(nested.c > 1) and (nested.d>2)") == {"nested"}
66+
assert b1.extract_nest_names("-1.52e-5 < b < 35.2e2") == {""}
67+
68+
b2 = base.add_nested(nested.copy(), "n")
69+
assert b2.extract_nest_names("(n.c > 1) and ((b + a) > (b - 1e-8)) or n.d > a") == {"n", ""}
70+
71+
abc = pd.DataFrame(
72+
data={
73+
"c": [3, 1, 4, 1, 5, 9, 2, 6, 5],
74+
"d": [1, 4, 1, 2, 1, 3, 5, 6, 2],
75+
"g": ["a", "b", "c", "d", "e", "f", "g", "h", "i"],
76+
},
77+
index=[0, 0, 0, 1, 1, 1, 2, 2, 2],
78+
)
79+
b3 = base.add_nested(abc, "abc").add_nested(abc, "c")
80+
assert b3.extract_nest_names("abc.c > 2 & c.d < 5") == {"abc", "c"}
81+
82+
assert b3.extract_nest_names("(abc.d > 3) & (abc.c == [2, 5])") == {"abc"}
83+
assert b3.extract_nest_names("(abc.d > 3)&(abc.g == 'f')") == {"abc"}
84+
assert b3.extract_nest_names("(abc.d > 3) & (abc.g == 'f')") == {"abc"}
5985

60-
assert extract_nest_names("a.b > 2 & c.d < 5") == {"a", "c"}
86+
assert b1.extract_nest_names("a>3") == {""}
87+
assert b1.extract_nest_names("a > 3") == {""}
6188

62-
assert extract_nest_names("a>3") == {""}
63-
assert extract_nest_names("a > 3") == {""}
64-
assert extract_nest_names("test.a>5&b==2") == {"test", ""}
65-
assert extract_nest_names("test.a > 5 & b == 2") == {"test", ""}
66-
assert extract_nest_names("(a.b > 3)&(a.c == 'f')") == {"a"}
67-
assert extract_nest_names("(a.b > 3) & (a.c == 'f')") == {"a"}
89+
b4 = base.add_nested(nested, "test")
90+
assert b4.extract_nest_names("test.c>5&b==2") == {"test", ""}
91+
assert b4.extract_nest_names("test.c > 5 & b == 2") == {"test", ""}

0 commit comments

Comments
 (0)