From bfbc34d315eacbf44238268bc35507ae23526c36 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Sat, 30 Nov 2024 15:57:20 +0000 Subject: [PATCH] feat: consistently return Python scalars from Series reductions for PyArrow (#1471) --------- Co-authored-by: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> --- narwhals/__init__.py | 6 +- narwhals/_arrow/dataframe.py | 10 +++- narwhals/_arrow/expr.py | 2 + narwhals/_arrow/series.py | 87 +++++++++++++++++----------- narwhals/_dask/expr.py | 2 + narwhals/_expression_parsing.py | 11 +++- narwhals/dependencies.py | 34 +++++------ narwhals/selectors.py | 6 +- narwhals/series.py | 2 +- narwhals/stable/v1/__init__.py | 4 +- narwhals/stable/v1/_dtypes.py | 4 +- narwhals/stable/v1/dependencies.py | 34 +++++------ narwhals/stable/v1/dtypes.py | 4 +- narwhals/stable/v1/selectors.py | 6 +- narwhals/stable/v1/typing.py | 8 +-- narwhals/translate.py | 2 +- narwhals/typing.py | 8 +-- tests/translate/to_py_scalar_test.py | 2 + tests/utils.py | 3 +- 19 files changed, 137 insertions(+), 98 deletions(-) diff --git a/narwhals/__init__.py b/narwhals/__init__.py index 449b416fd..072ffcb63 100644 --- a/narwhals/__init__.py +++ b/narwhals/__init__.py @@ -86,10 +86,10 @@ "Field", "Float32", "Float64", + "Int8", "Int16", "Int32", "Int64", - "Int8", "LazyFrame", "List", "Object", @@ -97,10 +97,10 @@ "Series", "String", "Struct", + "UInt8", "UInt16", "UInt32", "UInt64", - "UInt8", "Unknown", "all", "all_horizontal", @@ -113,8 +113,8 @@ "exceptions", "from_arrow", "from_dict", - "from_numpy", "from_native", + "from_numpy", "generate_temporary_column_name", "get_level", "get_native_namespace", diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index ea80f752a..f9ca79893 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -541,6 +541,8 @@ def is_empty(self: Self) -> bool: return self.shape[0] == 0 def item(self: Self, row: int | None, column: int | str | None) -> Any: + from narwhals._arrow.series import maybe_extract_py_scalar + if row is None and column is None: if self.shape != (1, 1): msg = ( @@ -549,14 +551,18 @@ def item(self: Self, row: int | None, column: int | str | None) -> Any: f" frame has shape {self.shape!r}" ) raise ValueError(msg) - return self._native_frame[0][0] + return maybe_extract_py_scalar( + self._native_frame[0][0], return_py_scalar=True + ) elif row is None or column is None: msg = "cannot call `.item()` with only one of `row` or `column`" raise ValueError(msg) _col = self.columns.index(column) if isinstance(column, str) else column - return self._native_frame[_col][row] + return maybe_extract_py_scalar( + self._native_frame[_col][row], return_py_scalar=True + ) def rename(self: Self, mapping: dict[str, str]) -> Self: df = self._native_frame diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index 1cb0a068c..9d2b8ec61 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -25,6 +25,8 @@ class ArrowExpr: + _implementation: Implementation = Implementation.PYARROW + def __init__( self: Self, call: Callable[[ArrowDataFrame], list[ArrowSeries]], diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index ebfa5d6e5..5b2f8d0f5 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -14,7 +14,6 @@ from narwhals._arrow.utils import native_to_narwhals_dtype from narwhals._arrow.utils import parse_datetime_format from narwhals._arrow.utils import validate_column_comparand -from narwhals.translate import to_py_scalar from narwhals.utils import Implementation from narwhals.utils import generate_temporary_column_name @@ -32,6 +31,12 @@ from narwhals.typing import DTypes +def maybe_extract_py_scalar(value: Any, return_py_scalar: bool) -> Any: # noqa: FBT001 + if return_py_scalar: + return getattr(value, "as_py", lambda: value)() + return value + + class ArrowSeries: def __init__( self: Self, @@ -241,8 +246,8 @@ def __invert__(self: Self) -> Self: return self._from_native_series(pc.invert(self._native_series)) - def len(self: Self) -> int: - return len(self._native_series) + def len(self: Self, *, _return_py_scalar: bool = True) -> int: + return maybe_extract_py_scalar(len(self._native_series), _return_py_scalar) # type: ignore[no-any-return] def filter(self: Self, other: Any) -> Self: if not (isinstance(other, list) and all(isinstance(x, bool) for x in other)): @@ -251,12 +256,12 @@ def filter(self: Self, other: Any) -> Self: ser = self._native_series return self._from_native_series(ser.filter(other)) - def mean(self: Self) -> int: + def mean(self: Self, *, _return_py_scalar: bool = True) -> int: import pyarrow.compute as pc # ignore-banned-import() - return pc.mean(self._native_series) # type: ignore[no-any-return] + return maybe_extract_py_scalar(pc.mean(self._native_series), _return_py_scalar) # type: ignore[no-any-return] - def median(self: Self) -> int: + def median(self: Self, *, _return_py_scalar: bool = True) -> int: import pyarrow.compute as pc # ignore-banned-import() from narwhals.exceptions import InvalidOperationError @@ -265,22 +270,24 @@ def median(self: Self) -> int: msg = "`median` operation not supported for non-numeric input type." raise InvalidOperationError(msg) - return pc.approximate_median(self._native_series) # type: ignore[no-any-return] + return maybe_extract_py_scalar( # type: ignore[no-any-return] + pc.approximate_median(self._native_series), _return_py_scalar + ) - def min(self: Self) -> int: + def min(self: Self, *, _return_py_scalar: bool = True) -> int: import pyarrow.compute as pc # ignore-banned-import() - return pc.min(self._native_series) # type: ignore[no-any-return] + return maybe_extract_py_scalar(pc.min(self._native_series), _return_py_scalar) # type: ignore[no-any-return] - def max(self: Self) -> int: + def max(self: Self, *, _return_py_scalar: bool = True) -> int: import pyarrow.compute as pc # ignore-banned-import() - return pc.max(self._native_series) # type: ignore[no-any-return] + return maybe_extract_py_scalar(pc.max(self._native_series), _return_py_scalar) # type: ignore[no-any-return] - def sum(self: Self) -> int: + def sum(self: Self, *, _return_py_scalar: bool = True) -> int: import pyarrow.compute as pc # ignore-banned-import() - return pc.sum(self._native_series) # type: ignore[no-any-return] + return maybe_extract_py_scalar(pc.sum(self._native_series), _return_py_scalar) # type: ignore[no-any-return] def drop_nulls(self: Self) -> ArrowSeries: import pyarrow.compute as pc # ignore-banned-import() @@ -300,12 +307,14 @@ def shift(self: Self, n: int) -> Self: result = ca return self._from_native_series(result) - def std(self: Self, ddof: int) -> float: + def std(self: Self, ddof: int, *, _return_py_scalar: bool = True) -> float: import pyarrow.compute as pc # ignore-banned-import() - return pc.stddev(self._native_series, ddof=ddof) # type: ignore[no-any-return] + return maybe_extract_py_scalar( # type: ignore[no-any-return] + pc.stddev(self._native_series, ddof=ddof), _return_py_scalar + ) - def skew(self: Self) -> float | None: + def skew(self: Self, *, _return_py_scalar: bool = True) -> float | None: import pyarrow.compute as pc # ignore-banned-import() ser = self._native_series @@ -321,18 +330,22 @@ def skew(self: Self) -> float | None: m2 = pc.mean(pc.power(m, 2)) m3 = pc.mean(pc.power(m, 3)) # Biased population skewness - return pc.divide(m3, pc.power(m2, 1.5)) # type: ignore[no-any-return] + return maybe_extract_py_scalar( # type: ignore[no-any-return] + pc.divide(m3, pc.power(m2, 1.5)), _return_py_scalar + ) - def count(self: Self) -> int: + def count(self: Self, *, _return_py_scalar: bool = True) -> int: import pyarrow.compute as pc # ignore-banned-import() - return pc.count(self._native_series) # type: ignore[no-any-return] + return maybe_extract_py_scalar(pc.count(self._native_series), _return_py_scalar) # type: ignore[no-any-return] - def n_unique(self: Self) -> int: + def n_unique(self: Self, *, _return_py_scalar: bool = True) -> int: import pyarrow.compute as pc # ignore-banned-import() unique_values = pc.unique(self._native_series) - return pc.count(unique_values, mode="all") # type: ignore[no-any-return] + return maybe_extract_py_scalar( # type: ignore[no-any-return] + pc.count(unique_values, mode="all"), _return_py_scalar + ) def __native_namespace__(self: Self) -> ModuleType: if self._implementation is Implementation.PYARROW: @@ -430,15 +443,15 @@ def diff(self: Self) -> Self: pc.pairwise_diff(self._native_series.combine_chunks()) ) - def any(self: Self) -> bool: + def any(self: Self, *, _return_py_scalar: bool = True) -> bool: import pyarrow.compute as pc # ignore-banned-import() - return to_py_scalar(pc.any(self._native_series)) # type: ignore[no-any-return] + return maybe_extract_py_scalar(pc.any(self._native_series), _return_py_scalar) # type: ignore[no-any-return] - def all(self: Self) -> bool: + def all(self: Self, *, _return_py_scalar: bool = True) -> bool: import pyarrow.compute as pc # ignore-banned-import() - return to_py_scalar(pc.all(self._native_series)) # type: ignore[no-any-return] + return maybe_extract_py_scalar(pc.all(self._native_series), _return_py_scalar) # type: ignore[no-any-return] def is_between( self, lower_bound: Any, upper_bound: Any, closed: str = "both" @@ -480,8 +493,8 @@ def cast(self: Self, dtype: DType) -> Self: dtype = narwhals_to_native_dtype(dtype, self._dtypes) return self._from_native_series(pc.cast(ser, dtype)) - def null_count(self: Self) -> int: - return self._native_series.null_count # type: ignore[no-any-return] + def null_count(self: Self, *, _return_py_scalar: bool = True) -> int: + return maybe_extract_py_scalar(self._native_series.null_count, _return_py_scalar) # type: ignore[no-any-return] def head(self: Self, n: int) -> Self: ser = self._native_series @@ -527,8 +540,8 @@ def item(self: Self, index: int | None = None) -> Any: f" or an explicit index is provided (Series is of length {len(self)})" ) raise ValueError(msg) - return self._native_series[0] - return self._native_series[index] + return maybe_extract_py_scalar(self._native_series[0], return_py_scalar=True) + return maybe_extract_py_scalar(self._native_series[index], return_py_scalar=True) def value_counts( self: Self, @@ -718,7 +731,7 @@ def is_sorted(self: Self, *, descending: bool) -> bool: result = pc.all(pc.greater_equal(ser[:-1], ser[1:])) else: result = pc.all(pc.less_equal(ser[:-1], ser[1:])) - return to_py_scalar(result) # type: ignore[no-any-return] + return maybe_extract_py_scalar(result, return_py_scalar=True) # type: ignore[no-any-return] def unique(self: Self, *, maintain_order: bool) -> ArrowSeries: # The param `maintain_order` is only here for compatibility with the Polars API @@ -798,12 +811,15 @@ def quantile( self: Self, quantile: float, interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], + *, + _return_py_scalar: bool = True, ) -> Any: import pyarrow.compute as pc # ignore-banned-import() - return pc.quantile(self._native_series, q=quantile, interpolation=interpolation)[ - 0 - ] + return maybe_extract_py_scalar( + pc.quantile(self._native_series, q=quantile, interpolation=interpolation)[0], + _return_py_scalar, + ) def gather_every(self: Self, n: int, offset: int = 0) -> Self: return self._from_native_series(self._native_series[offset::n]) @@ -994,7 +1010,10 @@ def rolling_mean( return result def __iter__(self: Self) -> Iterator[Any]: - yield from self._native_series.__iter__() + yield from ( + maybe_extract_py_scalar(x, return_py_scalar=True) + for x in self._native_series.__iter__() + ) @property def shape(self: Self) -> tuple[int]: diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index d4e07c2c6..cc814920a 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -29,6 +29,8 @@ class DaskExpr: + _implementation: Implementation = Implementation.DASK + def __init__( self, call: Callable[[DaskLazyFrame], list[dask_expr.Series]], diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index e8384165c..87a91dfa9 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -14,6 +14,7 @@ from narwhals.dependencies import is_numpy_array from narwhals.exceptions import InvalidIntoExprError +from narwhals.utils import Implementation if TYPE_CHECKING: from narwhals._arrow.dataframe import ArrowDataFrame @@ -223,9 +224,17 @@ def func(df: CompliantDataFrame) -> list[CompliantSeries]: for arg_name, arg_value in kwargs.items() } + # For PyArrow.Series, we return Python Scalars (like Polars does) instead of PyArrow Scalars. + # However, when working with expressions, we keep everything PyArrow-native. + extra_kwargs = ( + {"_return_py_scalar": False} + if returns_scalar and expr._implementation is Implementation.PYARROW + else {} + ) + out: list[CompliantSeries] = [ plx._create_series_from_scalar( - getattr(series, attr)(*_args, **_kwargs), + getattr(series, attr)(*_args, **extra_kwargs, **_kwargs), reference_series=series, # type: ignore[arg-type] ) if returns_scalar diff --git a/narwhals/dependencies.py b/narwhals/dependencies.py index 009469536..d5f0a6c6f 100644 --- a/narwhals/dependencies.py +++ b/narwhals/dependencies.py @@ -302,30 +302,30 @@ def is_into_dataframe(native_dataframe: Any) -> bool: __all__ = [ - "get_polars", - "get_pandas", - "get_modin", "get_cudf", - "get_pyarrow", - "get_numpy", "get_ibis", + "get_modin", + "get_numpy", + "get_pandas", + "get_polars", + "get_pyarrow", + "is_cudf_dataframe", + "is_cudf_series", + "is_dask_dataframe", "is_ibis_table", + "is_into_dataframe", + "is_into_series", + "is_modin_dataframe", + "is_modin_series", + "is_numpy_array", "is_pandas_dataframe", - "is_pandas_series", "is_pandas_index", + "is_pandas_like_dataframe", + "is_pandas_like_series", + "is_pandas_series", "is_polars_dataframe", "is_polars_lazyframe", "is_polars_series", - "is_modin_dataframe", - "is_modin_series", - "is_cudf_dataframe", - "is_cudf_series", - "is_pyarrow_table", "is_pyarrow_chunked_array", - "is_numpy_array", - "is_dask_dataframe", - "is_pandas_like_dataframe", - "is_pandas_like_series", - "is_into_dataframe", - "is_into_series", + "is_pyarrow_table", ] diff --git a/narwhals/selectors.py b/narwhals/selectors.py index dbb99f82c..31a5f80e8 100644 --- a/narwhals/selectors.py +++ b/narwhals/selectors.py @@ -271,10 +271,10 @@ def all() -> Expr: __all__ = [ + "all", + "boolean", "by_dtype", + "categorical", "numeric", - "boolean", "string", - "categorical", - "all", ] diff --git a/narwhals/series.py b/narwhals/series.py index 6c011610e..8529ab9c2 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -656,7 +656,7 @@ def median(self) -> Any: >>> my_library_agnostic_function(s_pl) 5.0 >>> my_library_agnostic_function(s_pa) - + 5.0 """ return self._compliant_series.median() diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index 69fee279c..110192334 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -3403,10 +3403,10 @@ def from_numpy( "Field", "Float32", "Float64", + "Int8", "Int16", "Int32", "Int64", - "Int8", "LazyFrame", "List", "Object", @@ -3414,10 +3414,10 @@ def from_numpy( "Series", "String", "Struct", + "UInt8", "UInt16", "UInt32", "UInt64", - "UInt8", "Unknown", "all", "all_horizontal", diff --git a/narwhals/stable/v1/_dtypes.py b/narwhals/stable/v1/_dtypes.py index 19748dae4..f6608fb5c 100644 --- a/narwhals/stable/v1/_dtypes.py +++ b/narwhals/stable/v1/_dtypes.py @@ -63,11 +63,12 @@ def __hash__(self) -> int: "Array", "Boolean", "Categorical", + "DType", "Date", "Datetime", "Duration", - "DType", "Enum", + "Field", "Float32", "Float64", "Int8", @@ -78,7 +79,6 @@ def __hash__(self) -> int: "NumericType", "Object", "String", - "Field", "Struct", "UInt8", "UInt16", diff --git a/narwhals/stable/v1/dependencies.py b/narwhals/stable/v1/dependencies.py index 3c3fae32a..6a020622e 100644 --- a/narwhals/stable/v1/dependencies.py +++ b/narwhals/stable/v1/dependencies.py @@ -28,30 +28,30 @@ from narwhals.dependencies import is_pyarrow_table __all__ = [ - "get_polars", - "get_pandas", - "get_modin", "get_cudf", - "get_pyarrow", - "get_numpy", "get_ibis", + "get_modin", + "get_numpy", + "get_pandas", + "get_polars", + "get_pyarrow", + "is_cudf_dataframe", + "is_cudf_series", + "is_dask_dataframe", "is_ibis_table", + "is_into_dataframe", + "is_into_series", + "is_modin_dataframe", + "is_modin_series", + "is_numpy_array", "is_pandas_dataframe", - "is_pandas_series", "is_pandas_index", + "is_pandas_like_dataframe", + "is_pandas_like_series", + "is_pandas_series", "is_polars_dataframe", "is_polars_lazyframe", "is_polars_series", - "is_modin_dataframe", - "is_modin_series", - "is_cudf_dataframe", - "is_cudf_series", - "is_pyarrow_table", "is_pyarrow_chunked_array", - "is_numpy_array", - "is_dask_dataframe", - "is_pandas_like_dataframe", - "is_pandas_like_series", - "is_into_dataframe", - "is_into_series", + "is_pyarrow_table", ] diff --git a/narwhals/stable/v1/dtypes.py b/narwhals/stable/v1/dtypes.py index 37c3af0e8..930e8d7dd 100644 --- a/narwhals/stable/v1/dtypes.py +++ b/narwhals/stable/v1/dtypes.py @@ -30,14 +30,14 @@ "Array", "Boolean", "Categorical", + "DType", "Date", "Datetime", "Duration", - "DType", "Enum", + "Field", "Float32", "Float64", - "Field", "Int8", "Int16", "Int32", diff --git a/narwhals/stable/v1/selectors.py b/narwhals/stable/v1/selectors.py index 1b630d88a..0d82484e9 100644 --- a/narwhals/stable/v1/selectors.py +++ b/narwhals/stable/v1/selectors.py @@ -8,10 +8,10 @@ from narwhals.selectors import string __all__ = [ + "all", + "boolean", "by_dtype", + "categorical", "numeric", - "boolean", "string", - "categorical", - "all", ] diff --git a/narwhals/stable/v1/typing.py b/narwhals/stable/v1/typing.py index 2ad1835ed..93eae2264 100644 --- a/narwhals/stable/v1/typing.py +++ b/narwhals/stable/v1/typing.py @@ -202,14 +202,14 @@ class DTypes: __all__ = [ - "IntoExpr", + "DataFrameT", + "Frame", + "FrameT", "IntoDataFrame", "IntoDataFrameT", + "IntoExpr", "IntoFrame", "IntoFrameT", - "Frame", - "FrameT", - "DataFrameT", "IntoSeries", "IntoSeriesT", ] diff --git a/narwhals/translate.py b/narwhals/translate.py index 0ffe336dc..ef30cef44 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -964,7 +964,7 @@ def to_py_scalar(scalar_like: Any) -> Any: __all__ = [ "get_native_namespace", - "to_native", "narwhalify", + "to_native", "to_py_scalar", ] diff --git a/narwhals/typing.py b/narwhals/typing.py index d51fa3bb3..8121c33d2 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -201,14 +201,14 @@ class DTypes: __all__ = [ - "IntoExpr", + "DataFrameT", + "Frame", + "FrameT", "IntoDataFrame", "IntoDataFrameT", + "IntoExpr", "IntoFrame", "IntoFrameT", - "Frame", - "FrameT", - "DataFrameT", "IntoSeries", "IntoSeriesT", ] diff --git a/tests/translate/to_py_scalar_test.py b/tests/translate/to_py_scalar_test.py index e1c4bbdf0..ace5db7a6 100644 --- a/tests/translate/to_py_scalar_test.py +++ b/tests/translate/to_py_scalar_test.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd +import pyarrow as pa import pytest import narwhals.stable.v1 as nw @@ -17,6 +18,7 @@ ("input_value", "expected"), [ (1, 1), + (pa.scalar(1), 1), (np.int64(1), 1), (1.0, 1.0), (None, None), diff --git a/tests/utils.py b/tests/utils.py index f9b493add..95ccda816 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -10,7 +10,6 @@ import pandas as pd -import narwhals as nw from narwhals.typing import IntoDataFrame from narwhals.typing import IntoFrame from narwhals.utils import Implementation @@ -54,7 +53,7 @@ def _to_comparable_list(column_values: Any) -> Any: column_values = column_values.to_pandas() if hasattr(column_values, "to_list"): return column_values.to_list() - return [nw.to_py_scalar(v) for v in column_values] + return list(column_values) def assert_equal_data(result: Any, expected: dict[str, Any]) -> None: