Skip to content

Commit

Permalink
BUG: Copy attrs on pd.merge()
Browse files Browse the repository at this point in the history
This uses the same logic as `pd.concat()`: Copy `attrs` only if all
input `attrs` are identical.

I've refactored the handling in __finalize__ from special-casing based on th the method name (previously only "concat") to handling "other" parameters
that have an `input_objs` attribute. This is a more scalable architecture compared to hard-coding method names in __finalize__.

Tests added for `concat()` and `merge()`.

Closes #60351.
  • Loading branch information
timhoffm committed Nov 19, 2024
1 parent 6a7685f commit 7815e73
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 14 deletions.
4 changes: 2 additions & 2 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6053,8 +6053,8 @@ def __finalize__(self, other, method: str | None = None, **kwargs) -> Self:
assert isinstance(name, str)
object.__setattr__(self, name, getattr(other, name, None))

if method == "concat":
objs = other.objs
elif hasattr(other, "input_objs"):
objs = other.input_objs
# propagate attrs only if all concat arguments have the same attrs
if all(bool(obj.attrs) for obj in objs):
# all concatenate arguments have non-empty attrs
Expand Down
8 changes: 5 additions & 3 deletions pandas/core/reshape/concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def _get_result(
result = sample._constructor_from_mgr(mgr, axes=mgr.axes)
result._name = name
return result.__finalize__(
types.SimpleNamespace(objs=objs), method="concat"
types.SimpleNamespace(input_objs=objs), method="concat"
)

# combine as columns in a frame
Expand All @@ -566,7 +566,9 @@ def _get_result(
)
df = cons(data, index=index, copy=False)
df.columns = columns
return df.__finalize__(types.SimpleNamespace(objs=objs), method="concat")
return df.__finalize__(
types.SimpleNamespace(input_objs=objs), method="concat"
)

# combine block managers
else:
Expand Down Expand Up @@ -605,7 +607,7 @@ def _get_result(
)

out = sample._constructor_from_mgr(new_data, axes=new_data.axes)
return out.__finalize__(types.SimpleNamespace(objs=objs), method="concat")
return out.__finalize__(types.SimpleNamespace(input_objs=objs), method="concat")


def new_axes(
Expand Down
10 changes: 8 additions & 2 deletions pandas/core/reshape/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
import datetime
from functools import partial
import types
from typing import (
TYPE_CHECKING,
Literal,
Expand Down Expand Up @@ -1106,7 +1107,10 @@ def get_result(self) -> DataFrame:
join_index, left_indexer, right_indexer = self._get_join_info()

result = self._reindex_and_concat(join_index, left_indexer, right_indexer)
result = result.__finalize__(self, method=self._merge_type)
result = result.__finalize__(
types.SimpleNamespace(input_objs=[self.left, self.right]),
method=self._merge_type,
)

if self.indicator:
result = self._indicator_post_merge(result)
Expand All @@ -1115,7 +1119,9 @@ def get_result(self) -> DataFrame:

self._maybe_restore_index_levels(result)

return result.__finalize__(self, method="merge")
return result.__finalize__(
types.SimpleNamespace(input_objs=[self.left, self.right]), method="merge"
)

@final
@cache_readonly
Expand Down
26 changes: 25 additions & 1 deletion pandas/tests/frame/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def test_attrs(self):
result = df.rename(columns=str)
assert result.attrs == {"version": 1}

def test_attrs_deepcopy(self):
def test_attrs_is_deepcopy(self):
df = DataFrame({"A": [2, 3]})
assert df.attrs == {}
df.attrs["tags"] = {"spam", "ham"}
Expand All @@ -324,6 +324,30 @@ def test_attrs_deepcopy(self):
assert result.attrs == df.attrs
assert result.attrs["tags"] is not df.attrs["tags"]

def test_attrs_concat(self):
# concat propagates attrs if all input attrs are equal
df1 = DataFrame({"A": [2, 3]})
df1.attrs = {"a": 1, "b": 2}
df2 = DataFrame({"A": [4, 5]})
df2.attrs = df1.attrs.copy()
df3 = DataFrame({"A": [6, 7]})
df3.attrs = df1.attrs.copy()
assert pd.concat([df1, df2, df3]).attrs == df1.attrs
# concat does not propagate attrs if input attrs are different
df2.attrs = {"c": 3}
assert pd.concat([df1, df2, df3]).attrs == {}

def test_attrs_merge(self):
# merge propagates attrs if all input attrs are equal
df1 = DataFrame({"key": ["a", "b"], "val1": [1, 2]})
df1.attrs = {"a": 1, "b": 2}
df2 = DataFrame({"key": ["a", "b"], "val2": [3, 4]})
df2.attrs = df1.attrs.copy()
assert pd.merge(df1, df2).attrs == df1.attrs
# merge does not propagate attrs if input attrs are different
df2.attrs = {"c": 3}
assert pd.merge(df1, df2).attrs == {}

@pytest.mark.parametrize("allows_duplicate_labels", [True, False, None])
def test_set_flags(
self,
Expand Down
3 changes: 0 additions & 3 deletions pandas/tests/generic/test_duplicate_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,6 @@ def test_concat(self, objs, kwargs):
allows_duplicate_labels=False
),
False,
marks=not_implemented,
),
# false true false
pytest.param(
Expand All @@ -173,7 +172,6 @@ def test_concat(self, objs, kwargs):
),
pd.DataFrame({"B": [0, 1]}, index=["a", "d"]),
False,
marks=not_implemented,
),
# true true true
(
Expand Down Expand Up @@ -296,7 +294,6 @@ def test_concat_raises(self):
with pytest.raises(pd.errors.DuplicateLabelError, match=msg):
pd.concat(objs, axis=1)

@not_implemented
def test_merge_raises(self):
a = pd.DataFrame({"A": [0, 1, 2]}, index=["a", "b", "c"]).set_flags(
allows_duplicate_labels=False
Expand Down
8 changes: 6 additions & 2 deletions pandas/tests/generic/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,16 @@ def test_metadata_propagation_indiv(self, monkeypatch):
def finalize(self, other, method=None, **kwargs):
for name in self._metadata:
if method == "merge":
left, right = other.left, other.right
left, right = other.input_objs
value = getattr(left, name, "") + "|" + getattr(right, name, "")
object.__setattr__(self, name, value)
elif method == "concat":
value = "+".join(
[getattr(o, name) for o in other.objs if getattr(o, name, None)]
[
getattr(o, name)
for o in other.input_objs
if getattr(o, name, None)
]
)
object.__setattr__(self, name, value)
else:
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/generic/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def finalize(self, other, method=None, **kwargs):
value = "+".join(
[
getattr(obj, name)
for obj in other.objs
for obj in other.input_objs
if getattr(obj, name, None)
]
)
Expand Down

0 comments on commit 7815e73

Please sign in to comment.