Skip to content

Commit

Permalink
Merge pull request #63 from lincc-frameworks/fix-nest-assign
Browse files Browse the repository at this point in the history
Fix .nest[...] = pd.Series
  • Loading branch information
hombit authored May 9, 2024
2 parents be2b7a6 + 4375c59 commit a196bd6
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 3 deletions.
13 changes: 10 additions & 3 deletions src/nested_pandas/series/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def to_flat(self, fields: list[str] | None = None) -> pd.DataFrame:
for field in fields:
list_array = cast(pa.ListArray, struct_array.field(field))
if index is None:
index = np.repeat(self._series.index.values, np.diff(list_array.offsets))
index = self.get_flat_index()
flat_series[field] = pd.Series(
list_array.flatten(),
index=pd.Series(index, name=self._series.index.name),
Expand Down Expand Up @@ -178,6 +178,13 @@ def query_flat(self, query: str) -> pd.Series:
return pd.Series([], dtype=self._series.dtype)
return pack_sorted_df_into_struct(flat)

def get_flat_index(self) -> pd.Index:
    """Index of the flat arrays

    Each label of the nested series' index is repeated once per element
    of the corresponding nested list, so the result lines up with the
    flattened field arrays.

    Returns
    -------
    pd.Index
        Index of the same type and with the same name(s) as the index of
        the nested series; its length equals the total number of
        flat elements.
    """
    # Length of every nested list is the difference of adjacent offsets.
    list_lengths = np.diff(self._series.array.list_offsets)
    # Repeat the Index itself rather than its `.values`: going through a
    # bare NumPy array drops the time zone of a DatetimeIndex, the level
    # names of a MultiIndex and extension dtypes, which in turn makes
    # `Index.equals` comparisons against user-provided indexes fail.
    return self._series.index.repeat(list_lengths)

def get_flat_series(self, field: str) -> pd.Series:
"""Get the flat-array field as a Series
Expand All @@ -200,7 +207,7 @@ def get_flat_series(self, field: str) -> pd.Series:
return pd.Series(
flat_array,
dtype=pd.ArrowDtype(flat_array.type),
index=np.repeat(self._series.index.values, np.diff(self._series.array.list_offsets)),
index=self.get_flat_index(),
name=field,
copy=False,
)
Expand Down Expand Up @@ -252,7 +259,7 @@ def __setitem__(self, key: str, value: ArrayLike) -> None:
self.set_flat_field(key, value)
return

if isinstance(value, pd.Series) and not np.array_equal(self._series.index.values, value.index.values):
if isinstance(value, pd.Series) and not self.get_flat_index().equals(value.index):
raise ValueError("Cannot set field with a Series of different index")

pa_array = pa.array(value, from_pandas=True)
Expand Down
35 changes: 35 additions & 0 deletions tests/nested_pandas/series/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,41 @@ def test___setitem__():
)


def test___setitem___with_series_with_index():
    """Test that .nest["field"] = pd.Series(...) works for a Series carrying the flat index."""
    # Two nested rows with three elements each, fields "a" and "b".
    field_a = pa.array([np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 1.0])])
    field_b = pa.array([-np.array([4.0, 5.0, 6.0]), -np.array([3.0, 4.0, 5.0])])
    struct_array = pa.StructArray.from_arrays(arrays=[field_a, field_b], names=["a", "b"])
    nested = pd.Series(struct_array, dtype=NestedDtype(struct_array.type), index=[0, 1])

    # Replacement values indexed by the repeated (flat) index of `nested`.
    replacement = pd.Series(
        data=["a", "b", "c", "d", "e", "f"],
        index=[0, 0, 0, 1, 1, 1],
        name="a",
        dtype=pd.ArrowDtype(pa.string()),
    )
    expected_lists = pd.Series(
        data=[np.array(["a", "b", "c"]), np.array(["d", "e", "f"])],
        dtype=pd.ArrowDtype(pa.list_(pa.string())),
        index=[0, 1],
        name="a",
    )

    nested.nest["a"] = replacement

    # The flat view must round-trip the assigned series ...
    assert_series_equal(nested.nest["a"], replacement)
    # ... and the list view must hold the values grouped per row.
    assert_series_equal(nested.nest.get_list_series("a"), expected_lists)


def test___setitem___raises_for_wrong_length():
"""Test that the .nest["field"] = ... raises for a wrong length."""
struct_array = pa.StructArray.from_arrays(
Expand Down

0 comments on commit a196bd6

Please sign in to comment.