diff --git a/dask_expr/_expr.py b/dask_expr/_expr.py index 5d6e79ae..b95534ad 100644 --- a/dask_expr/_expr.py +++ b/dask_expr/_expr.py @@ -1183,6 +1183,9 @@ class Elemwise(Blockwise): def _simplify_up(self, parent, dependents): if self._filter_passthrough and isinstance(parent, Filter): + if self._name != parent.frame._name: + # We can't push the filter through the filter condition + return parents = [x() for x in dependents[self._name] if x() is not None] if not all(isinstance(p, Filter) for p in parents): return diff --git a/dask_expr/tests/test_collection.py b/dask_expr/tests/test_collection.py index 85500046..341939d8 100644 --- a/dask_expr/tests/test_collection.py +++ b/dask_expr/tests/test_collection.py @@ -2293,3 +2293,11 @@ def test_axes(df, pdf): [assert_eq(d, p) for d, p in zip(df.axes, pdf.axes)] assert len(df.x.axes) == len(pdf.x.axes) assert_eq(df.x.axes[0], pdf.x.axes[0]) + + +def test_filter_optimize_condition(): + pdf = pd.DataFrame({"a": [1, 2, 3, 4], "b": [True, False, True, False]}) + df = from_pandas(pdf, npartitions=2) + result = df[df.b.fillna(True)] + expected = pdf[pdf.b.fillna(True)] + assert_eq(result, expected)