Skip to content

Commit 8fca944

Browse files
committed
Fix .nest.get_*_series
1 parent 0d8c2f3 commit 8fca944

File tree

2 files changed

+72
-10
lines changed

2 files changed

+72
-10
lines changed

src/nested_pandas/series/accessor.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -251,14 +251,18 @@ def get_flat_series(self, field: str) -> pd.Series:
251251
The flat-array field.
252252
"""
253253

254-
# TODO: we should make proper missed values handling here
254+
flat_chunks = []
255+
for nested_chunk in self._series.array._chunked_array.iterchunks():
256+
struct_array = cast(pa.StructArray, nested_chunk)
257+
list_array = cast(pa.ListArray, struct_array.field(field))
258+
flat_array = list_array.flatten()
259+
flat_chunks.append(flat_array)
260+
261+
flat_chunked_array = pa.chunked_array(flat_chunks)
255262

256-
struct_array = cast(pa.StructArray, pa.array(self._series))
257-
list_array = cast(pa.ListArray, struct_array.field(field))
258-
flat_array = list_array.flatten()
259263
return pd.Series(
260-
flat_array,
261-
dtype=pd.ArrowDtype(flat_array.type),
264+
flat_chunked_array,
265+
dtype=pd.ArrowDtype(flat_chunked_array.type),
262266
index=self.get_flat_index(),
263267
name=field,
264268
copy=False,
@@ -277,11 +281,16 @@ def get_list_series(self, field: str) -> pd.Series:
277281
pd.Series
278282
The list-array field.
279283
"""
280-
struct_array = cast(pa.StructArray, pa.array(self._series))
281-
list_array = struct_array.field(field)
284+
285+
list_chunks = []
286+
for nested_chunk in self._series.array._chunked_array.iterchunks():
287+
struct_array = cast(pa.StructArray, nested_chunk)
288+
list_array = struct_array.field(field)
289+
list_chunks.append(list_array)
290+
list_chunked_array = pa.chunked_array(list_chunks)
282291
return pd.Series(
283-
list_array,
284-
dtype=pd.ArrowDtype(list_array.type),
292+
list_chunked_array,
293+
dtype=pd.ArrowDtype(list_chunked_array.type),
285294
index=self._series.index,
286295
name=field,
287296
copy=False,

tests/nested_pandas/series/test_accessor.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,32 @@ def test_get_list_series():
543543
)
544544

545545

546+
def test_get_list_series_multiple_chunks():
547+
"""Test that .nest.get_list_series() works when underlying array is chunked"""
548+
struct_array = pa.StructArray.from_arrays(
549+
arrays=[
550+
[np.array([1, 2, 3]), np.array([4, 5, 6])],
551+
[np.array([6, 4, 2]), np.array([1, 2, 3])],
552+
],
553+
names=["a", "b"],
554+
)
555+
chunked_array = pa.chunked_array([struct_array] * 3)
556+
series = pd.Series(chunked_array, dtype=NestedDtype(chunked_array.type), index=[5, 7, 9, 11, 13, 15])
557+
assert series.array.num_chunks == 3
558+
559+
lists = series.nest.get_list_series("a")
560+
561+
assert_series_equal(
562+
lists,
563+
pd.Series(
564+
data=[np.array([1, 2, 3]), np.array([4, 5, 6])] * 3,
565+
dtype=pd.ArrowDtype(pa.list_(pa.int64())),
566+
index=[5, 7, 9, 11, 13, 15],
567+
name="a",
568+
),
569+
)
570+
571+
546572
def test_get():
547573
"""Test .nest.get() which is implemented by the base class"""
548574
series = pack_seq(
@@ -588,6 +614,33 @@ def test___getitem___single_field():
588614
)
589615

590616

617+
def test___getitem___single_field_multiple_chunks():
618+
"""Reproduces issue 142
619+
620+
https://github.com/lincc-frameworks/nested-pandas/issues/142
621+
"""
622+
struct_array = pa.StructArray.from_arrays(
623+
arrays=[
624+
[np.array([1.0, 2.0, 3.0]), np.array([1.0, 2.0, 1.0])],
625+
[np.array([4.0, 5.0, 6.0]), np.array([3.0, 4.0, 5.0])],
626+
],
627+
names=["a", "b"],
628+
)
629+
chunked_array = pa.chunked_array([struct_array] * 3)
630+
series = pd.Series(chunked_array, dtype=NestedDtype(chunked_array.type), index=[0, 1, 2, 3, 4, 5])
631+
assert series.array.num_chunks == 3
632+
633+
assert_series_equal(
634+
series.nest["a"],
635+
pd.Series(
636+
np.array([1.0, 2.0, 3.0, 1.0, 2.0, 1.0] * 3),
637+
dtype=pd.ArrowDtype(pa.float64()),
638+
index=[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5],
639+
name="a",
640+
),
641+
)
642+
643+
591644
def test___getitem___multiple_fields():
592645
"""Test that the .nest[["b", "a"]] works for multiple fields."""
593646
arrays = [

0 commit comments

Comments
 (0)