Skip to content

Commit 495b047

Browse files
authored
Merge pull request #205 from lincc-frameworks/sort_values_nested
Wrapper for DataFrame.sort_values
2 parents 3d3cab4 + 21096fa commit 495b047

File tree

2 files changed

+166
-0
lines changed

2 files changed

+166
-0
lines changed

src/nested_pandas/nestedframe/core.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,118 @@ def dropna(
916916
return None
917917
return new_df
918918

919+
def sort_values(
    self,
    by,
    *,
    axis=0,
    ascending=True,
    inplace=False,
    kind="quicksort",
    na_position="last",
    ignore_index=False,
    key=None,
):
    """
    Sort by the values along either axis.

    All ``by`` columns must belong to the same layer: either the base
    frame or a single nested dataframe. Sorting on nested columns reorders
    the rows inside each nested dataframe; the base rows keep their order.

    Parameters
    ----------
    by : str or list of str
        Name or list of names to sort by.

        Access nested columns using `nested_df.nested_col` (where
        `nested_df` refers to a particular nested dataframe and
        `nested_col` is a column of that nested dataframe).
    axis : {0 or 'index', 1 or 'columns'}, default 0
        Axis to be sorted.
    ascending : bool or sequence of bool, default True
        Sort ascending vs. descending. Specify a sequence for multiple
        sort orders. If this is a sequence of bools, it must match the
        length of `by`.
    inplace : bool, default False
        If True, perform operation in-place.
    kind : {'quicksort', 'mergesort', 'heapsort', 'stable'}, default 'quicksort'
        Choice of sorting algorithm. See also numpy.sort for more
        information. mergesort and stable are the only stable algorithms.
        For DataFrames, this option is only applied when sorting on a
        single column or label.
    na_position : {'first', 'last'}, default 'last'
        Puts NaNs at the beginning if first; last puts NaNs at the end.
    ignore_index : bool, default False
        If True, the resulting axis will be labeled 0, 1, …, n - 1.
        Always False when applied to nested layers.
    key : callable, optional
        Apply the key function to the values before sorting. When sorting
        a nested layer the key is forwarded together with the grouping
        index that is prepended to `by`.

    Returns
    -------
    DataFrame or None
        DataFrame with sorted values if inplace=False, None otherwise.

    Raises
    ------
    ValueError
        If the `by` columns refer to more than one layer.
    """
    if isinstance(by, str):
        by = [by]

    # Resolve the layer each "by" column lives in. None (rather than a
    # string sentinel) marks the base layer so that a nested column that
    # happens to be called "base" cannot be confused with it.
    layers = {col.split(".")[0] if self._is_known_hierarchical_column(col) else None for col in by}

    # Ensure one target layer, preventing multi-layer operations
    if len(layers) > 1:
        raise ValueError("sort_values cannot target multiple structs/layers, sort each layer separately")

    # An empty "by" resolves to the base layer and is left to pandas
    target = next(iter(layers), None)

    # Apply pandas sort_values
    if target is None:
        return super().sort_values(
            by=by,
            axis=axis,
            ascending=ascending,
            inplace=inplace,
            kind=kind,
            na_position=na_position,
            ignore_index=ignore_index,
            key=key,
        )

    # target is a nested column: sort its flat representation
    target_flat = self[target].nest.to_flat()
    target_flat = target_flat.set_index(self[target].array.get_list_index())

    if target_flat.index.name is None:  # set name if not present
        target_flat.index.name = "index"
    # Index must always be the first sort key for nested columns, so that
    # the rows of each nested dataframe stay together
    nested_by = [target_flat.index.name] + [col.split(".")[-1] for col in by]

    # Augment the ascending kwarg to include the index. Scalars are
    # broadcast over "by"; any other sequence (list, tuple, ndarray) is
    # expanded so that its length keeps matching nested_by.
    if np.ndim(ascending) == 0:
        nested_ascending = [True] + [ascending] * len(by)
    else:
        nested_ascending = [True, *ascending]

    target_flat = target_flat.sort_values(
        by=nested_by,
        axis=axis,
        ascending=nested_ascending,
        kind=kind,
        na_position=na_position,
        ignore_index=False,
        key=key,
        inplace=False,
    )

    # Could be optimized, as number of rows doesn't change
    new_df = self._set_filtered_flat_df(nest_name=target, flat_df=target_flat)

    if inplace:
        self._update_inplace(new_df)
        return None
    return new_df
1030+
9191031
def reduce(self, func, *args, **kwargs) -> NestedFrame: # type: ignore[override]
9201032
"""
9211033
Takes a function and applies it to each top-level row of the NestedFrame.

tests/nested_pandas/nestedframe/test_nestedframe.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -807,6 +807,60 @@ def test_dropna_errors():
807807
base.dropna(on_nested="nested", subset=["b"])
808808

809809

810+
def test_sort_values():
    """Test that sort_values works on all layers"""

    base = NestedFrame(data={"a": [1, 2, 3], "b": [2, 3, 6]}, index=[0, 1, 2])

    nested = pd.DataFrame(
        data={"c": [0, 2, 4, 1, 4, 3, 1, 4, 1], "d": [5, 4, 7, 5, 3, 1, 9, 3, 4]},
        index=[0, 0, 0, 1, 1, 1, 2, 2, 2],
    )

    base = base.add_nested(nested, "nested")

    # Test basic functionality
    sv_base = base.sort_values("b")
    assert list(sv_base.index) == [0, 1, 2]

    # "b" is already ascending, so also sort descending to check that the
    # base rows are really reordered
    sv_base = base.sort_values("b", ascending=False)
    assert list(sv_base.index) == [2, 1, 0]

    # Test on nested column: every nested dataframe is sorted on its own
    sv_base = base.sort_values(["nested.d"])
    assert list(sv_base.iloc[0]["nested"]["d"]) == [4, 5, 7]
    assert list(sv_base.iloc[1]["nested"]["d"]) == [1, 3, 5]
    assert list(sv_base.iloc[2]["nested"]["d"]) == [3, 4, 9]
    # inplace=False must leave the original untouched
    assert list(base.iloc[0]["nested"]["d"]) == [5, 4, 7]

    # Test multi-layer error trigger
    with pytest.raises(ValueError):
        base.sort_values(["a", "nested.c"])

    # Test inplace=True: returns None and modifies the frame itself
    result = base.sort_values("nested.d", inplace=True)
    assert result is None
    assert list(base.iloc[0]["nested"]["d"]) == [4, 5, 7]
837+
838+
839+
def test_sort_values_ascension():
    """Test that sort_values works with various ascending settings"""

    base = NestedFrame(data={"a": [1, 2, 3], "b": [2, 3, 6]}, index=[0, 1, 2])

    nested = pd.DataFrame(
        data={"c": [0, 2, 4, 1, 4, 3, 1, 4, 1], "d": [5, 4, 7, 5, 3, 1, 9, 3, 4]},
        index=[0, 0, 0, 1, 1, 1, 2, 2, 2],
    )

    base = base.add_nested(nested, "nested")

    # Test ascending=False
    sv_base = base.sort_values("nested.d", ascending=False)
    assert list(sv_base.iloc[0]["nested"]["d"]) == [7, 5, 4]

    # Test list ascending
    sv_base = base.sort_values("nested.d", ascending=[False])
    assert list(sv_base.iloc[0]["nested"]["d"]) == [7, 5, 4]

    # Test multi-by multi-ascending
    sv_base = base.sort_values(["nested.d", "nested.c"], ascending=[False, True])
    assert list(sv_base.iloc[0]["nested"]["d"]) == [7, 5, 4]

    # "d" has no ties within a row, so the case above never reaches the
    # secondary key. Row 2 has a tie in "c" (1, 4, 1): sort on "c" first and
    # check that "d" breaks the tie in its own (descending) direction.
    sv_base = base.sort_values(["nested.c", "nested.d"], ascending=[True, False])
    assert list(sv_base.iloc[2]["nested"]["c"]) == [1, 1, 4]
    assert list(sv_base.iloc[2]["nested"]["d"]) == [9, 4, 3]
862+
863+
810864
def test_reduce():
811865
"""Tests that we can call reduce on a NestedFrame with a custom function."""
812866
nf = NestedFrame(

0 commit comments

Comments
 (0)