Skip to content

Commit 82fe541

Browse files
authored
Merge pull request #96 from lincc-frameworks/reduce-reimpl
Reimplementation of NestedFrame.reduce()
2 parents aadb12f + fc0b320 commit 82fe541

File tree

3 files changed

+44
-29
lines changed

3 files changed

+44
-29
lines changed

src/nested_pandas/nestedframe/core.py

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -413,35 +413,15 @@ def my_sum(col1, col2):
413413
if len(requested_columns) < len(args):
414414
extra_args = args[len(requested_columns) :]
415415

416-
# find targeted layers
417-
layers = np.unique([col[0] for col in requested_columns])
418-
419-
# build a flat dataframe with array columns to apply to the function
420-
apply_df = NestedFrame()
421-
for layer in layers:
416+
iterators = []
417+
for layer, col in requested_columns:
422418
if layer == "base":
423-
columns = [col[1] for col in requested_columns if col[0] == layer]
424-
apply_df = apply_df.join(self[columns], how="outer")
419+
iterators.append(self[col])
425420
else:
426-
# TODO: It should be faster to pass these columns to to_lists, but its 20x slower
427-
# columns = [col[1] for col in requested_columns if col[0] == layer]
428-
apply_df = apply_df.join(self[layer].nest.to_lists(), how="outer")
421+
iterators.append(self[layer].array.iter_field_lists(col))
429422

430-
# Translates the requested columns into the scalars or arrays we pass to func.
431-
def translate_cols(frame, layer, col):
432-
if layer == "base":
433-
# We pass the "base" column as a scalar
434-
return frame[col]
435-
return np.asarray(frame[col])
436-
437-
# send arrays along to the apply call
438-
result = apply_df.apply(
439-
lambda x: func(
440-
*[translate_cols(x, layer, col) for layer, col in requested_columns], *extra_args, **kwargs
441-
),
442-
axis=1, # to apply func on each row of our nested frame)
443-
)
444-
return result
423+
results = [func(*cols, *extra_args, **kwargs) for cols in zip(*iterators)]
424+
return NestedFrame(results, index=self.index)
445425

446426
def to_parquet(self, path, by_layer=False, **kwargs) -> None:
447427
"""Creates parquet file(s) with the data of a NestedFrame, either

src/nested_pandas/series/ext_array.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
# typing.Self and "|" union syntax don't exist in Python 3.9
3636
from __future__ import annotations
3737

38-
from collections.abc import Iterable, Iterator, Sequence
38+
from collections.abc import Generator, Iterable, Iterator, Sequence
3939
from typing import Any, Callable, cast
4040

4141
import numpy as np
@@ -648,8 +648,27 @@ def num_chunks(self) -> int:
648648
"""Number of chunks in underlying pyarrow.ChunkedArray"""
649649
return self._chunked_array.num_chunks
650650

651+
def iter_field_lists(self, field: str) -> Generator[np.ndarray, None, None]:
652+
"""Iterate over single field nested lists, as numpy arrays
653+
654+
Parameters
655+
----------
656+
field : str
657+
The name of the field to iterate over.
658+
659+
Yields
660+
------
661+
np.ndarray
662+
The numpy array view over a list scalar.
663+
"""
664+
for chunk in self._chunked_array.iterchunks():
665+
struct_array: pa.StructArray = cast(pa.StructArray, chunk)
666+
list_array: pa.ListArray = cast(pa.ListArray, struct_array.field(field))
667+
for list_scalar in list_array:
668+
yield np.asarray(list_scalar.values)
669+
651670
def view_fields(self, fields: str | list[str]) -> Self: # type: ignore[name-defined] # noqa: F821
652-
"""Get a view of the series with only the specified fields
671+
"""Get a view of the extension array with only the specified fields
653672
654673
Parameters
655674
----------
@@ -659,7 +678,7 @@ def view_fields(self, fields: str | list[str]) -> Self: # type: ignore[name-def
659678
Returns
660679
-------
661680
NestedExtensionArray
662-
The view of the series with only the specified fields.
681+
The view of the array with only the specified fields.
663682
"""
664683
if isinstance(fields, str):
665684
fields = [fields]

tests/nested_pandas/series/test_ext_array.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1255,6 +1255,22 @@ def test_num_chunks():
12551255
assert ext_array.num_chunks == 7
12561256

12571257

1258+
def test_iter_field_lists():
1259+
"""Test .iter_field_lists() yields the correct field lists"""
1260+
a = [[1, 2, 3], [1, 2, 3, 4]]
1261+
b = [np.array(["a", "b", "c"]), np.array(["x", "y", "z", "w"])]
1262+
struct_array = pa.StructArray.from_arrays(
1263+
arrays=[a, b],
1264+
names=["a", "b"],
1265+
)
1266+
ext_array = NestedExtensionArray(struct_array)
1267+
1268+
for actual, desired in zip(ext_array.iter_field_lists("a"), a):
1269+
assert_array_equal(actual, desired)
1270+
for actual, desired in zip(ext_array.iter_field_lists("b"), b):
1271+
assert_array_equal(actual, desired)
1272+
1273+
12581274
def test_view_fields_with_single_field():
12591275
"""Tests ext_array.view("field")"""
12601276
arrays = [

0 commit comments

Comments
 (0)