diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index 2cae743f83..1cb0a068cb 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -151,55 +151,64 @@ def __and__(self: Self, other: ArrowExpr | bool | Any) -> Self: return reuse_series_implementation(self, "__and__", other=other) def __rand__(self: Self, other: ArrowExpr | bool | Any) -> Self: - return reuse_series_implementation(self, "__rand__", other=other) + other = self.__narwhals_namespace__().lit(other, dtype=None) + return other.__and__(self) # type: ignore[return-value] def __or__(self: Self, other: ArrowExpr | bool | Any) -> Self: return reuse_series_implementation(self, "__or__", other=other) def __ror__(self: Self, other: ArrowExpr | bool | Any) -> Self: - return reuse_series_implementation(self, "__ror__", other=other) + other = self.__narwhals_namespace__().lit(other, dtype=None) + return other.__or__(self) # type: ignore[return-value] def __add__(self: Self, other: ArrowExpr | Any) -> Self: return reuse_series_implementation(self, "__add__", other) def __radd__(self: Self, other: ArrowExpr | Any) -> Self: - return reuse_series_implementation(self, "__radd__", other) + other = self.__narwhals_namespace__().lit(other, dtype=None) + return other.__add__(self) # type: ignore[return-value] def __sub__(self: Self, other: ArrowExpr | Any) -> Self: return reuse_series_implementation(self, "__sub__", other) def __rsub__(self: Self, other: ArrowExpr | Any) -> Self: - return reuse_series_implementation(self, "__rsub__", other) + other = self.__narwhals_namespace__().lit(other, dtype=None) + return other.__sub__(self) # type: ignore[return-value] def __mul__(self: Self, other: ArrowExpr | Any) -> Self: return reuse_series_implementation(self, "__mul__", other) def __rmul__(self: Self, other: ArrowExpr | Any) -> Self: - return reuse_series_implementation(self, "__rmul__", other) + other = self.__narwhals_namespace__().lit(other, dtype=None) + return other.__mul__(self) # type: ignore[return-value] def __pow__(self: Self, other: ArrowExpr | Any) -> Self: return reuse_series_implementation(self, "__pow__", other) def __rpow__(self: Self, other: ArrowExpr | Any) -> Self: - return reuse_series_implementation(self, "__rpow__", other) + other = self.__narwhals_namespace__().lit(other, dtype=None) + return other.__pow__(self) # type: ignore[return-value] def __floordiv__(self: Self, other: ArrowExpr | Any) -> Self: return reuse_series_implementation(self, "__floordiv__", other) def __rfloordiv__(self: Self, other: ArrowExpr | Any) -> Self: - return reuse_series_implementation(self, "__rfloordiv__", other) + other = self.__narwhals_namespace__().lit(other, dtype=None) + return other.__floordiv__(self) # type: ignore[return-value] def __truediv__(self: Self, other: ArrowExpr | Any) -> Self: return reuse_series_implementation(self, "__truediv__", other) def __rtruediv__(self: Self, other: ArrowExpr | Any) -> Self: - return reuse_series_implementation(self, "__rtruediv__", other) + other = self.__narwhals_namespace__().lit(other, dtype=None) + return other.__truediv__(self) # type: ignore[return-value] def __mod__(self: Self, other: ArrowExpr | Any) -> Self: return reuse_series_implementation(self, "__mod__", other) def __rmod__(self: Self, other: ArrowExpr | Any) -> Self: - return reuse_series_implementation(self, "__rmod__", other) + other = self.__narwhals_namespace__().lit(other, dtype=None) + return other.__mod__(self) # type: ignore[return-value] def __invert__(self: Self) -> Self: return reuse_series_implementation(self, "__invert__") diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index a20df3c41c..07fa7deb24 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -170,7 +170,7 @@ def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries: depth=0, function_name="lit", root_names=None, - output_names=[_lit_arrow_series.__name__], + output_names=["literal"], backend_version=self._backend_version, dtypes=self._dtypes, ) diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index a2caf77f41..ebfa5d6e58 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -96,78 +96,68 @@ def __len__(self: Self) -> int: def __eq__(self: Self, other: object) -> Self: # type: ignore[override] import pyarrow.compute as pc - ser = self._native_series - other = validate_column_comparand(other) + ser, other = validate_column_comparand(self, other, self._backend_version) return self._from_native_series(pc.equal(ser, other)) def __ne__(self: Self, other: object) -> Self: # type: ignore[override] import pyarrow.compute as pc # ignore-banned-import() - ser = self._native_series - other = validate_column_comparand(other) + ser, other = validate_column_comparand(self, other, self._backend_version) return self._from_native_series(pc.not_equal(ser, other)) def __ge__(self: Self, other: Any) -> Self: import pyarrow.compute as pc # ignore-banned-import() - ser = self._native_series - other = validate_column_comparand(other) + ser, other = validate_column_comparand(self, other, self._backend_version) return self._from_native_series(pc.greater_equal(ser, other)) def __gt__(self: Self, other: Any) -> Self: import pyarrow.compute as pc # ignore-banned-import() - ser = self._native_series - other = validate_column_comparand(other) + ser, other = validate_column_comparand(self, other, self._backend_version) return self._from_native_series(pc.greater(ser, other)) def __le__(self: Self, other: Any) -> Self: import pyarrow.compute as pc # ignore-banned-import() - ser = self._native_series - other = validate_column_comparand(other) + ser, other = validate_column_comparand(self, other, self._backend_version) return self._from_native_series(pc.less_equal(ser, other)) def __lt__(self: Self, other: Any) -> Self: import pyarrow.compute as pc # ignore-banned-import() - ser = self._native_series - other = validate_column_comparand(other) + ser, other = validate_column_comparand(self, other, self._backend_version) return self._from_native_series(pc.less(ser, other)) def __and__(self: Self, other: Any) -> Self: import pyarrow.compute as pc # ignore-banned-import() - ser = self._native_series - other = validate_column_comparand(other) + ser, other = validate_column_comparand(self, other, self._backend_version) return self._from_native_series(pc.and_kleene(ser, other)) def __rand__(self: Self, other: Any) -> Self: import pyarrow.compute as pc # ignore-banned-import() - ser = self._native_series - other = validate_column_comparand(other) + ser, other = validate_column_comparand(self, other, self._backend_version) return self._from_native_series(pc.and_kleene(other, ser)) def __or__(self: Self, other: Any) -> Self: import pyarrow.compute as pc # ignore-banned-import() - ser = self._native_series - other = validate_column_comparand(other) + ser, other = validate_column_comparand(self, other, self._backend_version) return self._from_native_series(pc.or_kleene(ser, other)) def __ror__(self: Self, other: Any) -> Self: import pyarrow.compute as pc # ignore-banned-import() - ser = self._native_series - other = validate_column_comparand(other) + ser, other = validate_column_comparand(self, other, self._backend_version) return self._from_native_series(pc.or_kleene(other, ser)) def __add__(self: Self, other: Any) -> Self: import pyarrow.compute as pc # ignore-banned-import() - other = validate_column_comparand(other) - return self._from_native_series(pc.add(self._native_series, other)) + ser, other = validate_column_comparand(self, other, self._backend_version) + return self._from_native_series(pc.add(ser, other)) def __radd__(self: Self, other: Any) -> Self: return self + other # type: ignore[no-any-return] @@ -175,8 +165,8 @@ def __radd__(self: Self, other: Any) -> Self: def __sub__(self: Self, other: Any) -> Self: import pyarrow.compute as pc # ignore-banned-import() - other = validate_column_comparand(other) - return self._from_native_series(pc.subtract(self._native_series, other)) + ser, other = validate_column_comparand(self, other, self._backend_version) + return self._from_native_series(pc.subtract(ser, other)) def __rsub__(self: Self, other: Any) -> Self: return (self - other) * (-1) # type: ignore[no-any-return] @@ -184,8 +174,8 @@ def __rsub__(self: Self, other: Any) -> Self: def __mul__(self: Self, other: Any) -> Self: import pyarrow.compute as pc # ignore-banned-import() - other = validate_column_comparand(other) - return self._from_native_series(pc.multiply(self._native_series, other)) + ser, other = validate_column_comparand(self, other, self._backend_version) + return self._from_native_series(pc.multiply(ser, other)) def __rmul__(self: Self, other: Any) -> Self: return self * other # type: ignore[no-any-return] @@ -193,33 +183,28 @@ def __rmul__(self: Self, other: Any) -> Self: def __pow__(self: Self, other: Any) -> Self: import pyarrow.compute as pc # ignore-banned-import() - ser = self._native_series - other = validate_column_comparand(other) + ser, other = validate_column_comparand(self, other, self._backend_version) return self._from_native_series(pc.power(ser, other)) def __rpow__(self: Self, other: Any) -> Self: import pyarrow.compute as pc # ignore-banned-import() - ser = self._native_series - other = validate_column_comparand(other) + ser, other = validate_column_comparand(self, other, self._backend_version) return self._from_native_series(pc.power(other, ser)) def __floordiv__(self: Self, other: Any) -> Self: - ser = self._native_series - other = validate_column_comparand(other) + ser, other = validate_column_comparand(self, other, self._backend_version) return self._from_native_series(floordiv_compat(ser, other)) def __rfloordiv__(self: Self, other: Any) -> Self: - ser = self._native_series - other = validate_column_comparand(other) + ser, other = validate_column_comparand(self, other, self._backend_version) return self._from_native_series(floordiv_compat(other, ser)) def __truediv__(self: Self, other: Any) -> Self: import pyarrow as pa # ignore-banned-import() import pyarrow.compute as pc # ignore-banned-import() - ser = self._native_series - other = validate_column_comparand(other) + ser, other = validate_column_comparand(self, other, self._backend_version) if not isinstance(other, (pa.Array, pa.ChunkedArray)): # scalar other = pa.scalar(other) @@ -229,8 +214,7 @@ def __rtruediv__(self: Self, other: Any) -> Self: import pyarrow as pa # ignore-banned-import() import pyarrow.compute as pc # ignore-banned-import() - ser = self._native_series - other = validate_column_comparand(other) + ser, other = validate_column_comparand(self, other, self._backend_version) if not isinstance(other, (pa.Array, pa.ChunkedArray)): # scalar other = pa.scalar(other) @@ -239,18 +223,16 @@ def __rtruediv__(self: Self, other: Any) -> Self: def __mod__(self: Self, other: Any) -> Self: import pyarrow.compute as pc # ignore-banned-import() - ser = self._native_series - other = validate_column_comparand(other) floor_div = (self // other)._native_series + ser, other = validate_column_comparand(self, other, self._backend_version) res = pc.subtract(ser, pc.multiply(floor_div, other)) return self._from_native_series(res) def __rmod__(self: Self, other: Any) -> Self: import pyarrow.compute as pc # ignore-banned-import() - ser = self._native_series - other = validate_column_comparand(other) floor_div = (other // self)._native_series + ser, other = validate_column_comparand(self, other, self._backend_version) res = pc.subtract(other, pc.multiply(floor_div, ser)) return self._from_native_series(res) @@ -264,8 +246,10 @@ def len(self: Self) -> int: def filter(self: Self, other: Any) -> Self: if not (isinstance(other, list) and all(isinstance(x, bool) for x in other)): - other = validate_column_comparand(other) - return self._from_native_series(self._native_series.filter(other)) + ser, other = validate_column_comparand(self, other, self._backend_version) + else: + ser = self._native_series + return self._from_native_series(ser.filter(other)) def mean(self: Self) -> int: import pyarrow.compute as pc # ignore-banned-import() @@ -382,16 +366,17 @@ def scatter(self: Self, indices: int | Sequence[int], values: Any) -> Self: import pyarrow as pa # ignore-banned-import import pyarrow.compute as pc # ignore-banned-import - ca = self._native_series - mask = np.zeros(len(ca), dtype=bool) + mask = np.zeros(self.len(), dtype=bool) mask[indices] = True if isinstance(values, self.__class__): - values = validate_column_comparand(values) + ser, values = validate_column_comparand(self, values, self._backend_version) + else: + ser = self._native_series if isinstance(values, pa.ChunkedArray): values = values.combine_chunks() if not isinstance(values, pa.Array): values = pa.array(values) - result = pc.replace_with_mask(ca, mask, values.take(indices)) + result = pc.replace_with_mask(ser, mask, values.take(indices)) return self._from_native_series(result) def to_list(self: Self) -> list[Any]: diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 4641bbb1f4..f8c25433c3 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -128,7 +128,9 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> pa.D raise AssertionError(msg) -def validate_column_comparand(other: Any) -> Any: +def validate_column_comparand( + lhs: ArrowSeries, rhs: Any, backend_version: tuple[int, ...] +) -> tuple[pa.ChunkedArray, Any]: """Validate RHS of binary operation. If the comparison isn't supported, return `NotImplemented` so that the @@ -140,27 +142,47 @@ def validate_column_comparand(other: Any) -> Any: from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.series import ArrowSeries - if isinstance(other, list): - if len(other) > 1: - if hasattr(other[0], "__narwhals_expr__") or hasattr( - other[0], "__narwhals_series__" + # If `rhs` is the output of an expression evaluation, then it is + # a list of Series. So, we verify that that list is of length-1, + # and take the first (and only) element. + if isinstance(rhs, list): + if len(rhs) > 1: + if hasattr(rhs[0], "__narwhals_expr__") or hasattr( + rhs[0], "__narwhals_series__" ): # e.g. `plx.all() + plx.all()` msg = "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) are not supported in this context" raise ValueError(msg) - msg = ( - f"Expected scalar value, Series, or Expr, got list of : {type(other[0])}" - ) + msg = f"Expected scalar value, Series, or Expr, got list of : {type(rhs[0])}" raise ValueError(msg) - other = other[0] - if isinstance(other, ArrowDataFrame): + rhs = rhs[0] + + if isinstance(rhs, ArrowDataFrame): return NotImplemented - if isinstance(other, ArrowSeries): - if len(other) == 1: + + if isinstance(rhs, ArrowSeries): + if len(rhs) == 1: # broadcast - return other[0] - return other._native_series - return other + return lhs._native_series, rhs[0] + if len(lhs) == 1: + # broadcast + import numpy as np # ignore-banned-import + import pyarrow as pa # ignore-banned-import + + fill_value = lhs[0] + if backend_version < (13,) and hasattr(fill_value, "as_py"): + fill_value = fill_value.as_py() + left_result = pa.chunked_array( + [ + pa.array( + np.full(shape=rhs.len(), fill_value=fill_value), + type=lhs._native_series.type, + ) + ] + ) + return left_result, rhs._native_series + return lhs._native_series, rhs._native_series + return lhs._native_series, rhs def validate_dataframe_comparand( @@ -179,7 +201,7 @@ def validate_dataframe_comparand( import pyarrow as pa # ignore-banned-import value = other._native_series[0] - if backend_version < (13,) and hasattr(value, "as_py"): # pragma: no cover + if backend_version < (13,) and hasattr(value, "as_py"): value = value.as_py() return pa.array(np.full(shape=length, fill_value=value)) return other._native_series @@ -321,7 +343,7 @@ def broadcast_series(series: list[ArrowSeries]) -> list[Any]: s_native = s._native_series if is_max_length_gt_1 and length == 1: value = s_native[0] - if s._backend_version < (13,) and hasattr(value, "as_py"): # pragma: no cover + if s._backend_version < (13,) and hasattr(value, "as_py"): value = value.as_py() reshaped.append(pa.array([value] * max_length, type=s_native.type)) else: diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 58e73792ac..d4e07c2c62 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -199,7 +199,7 @@ def __radd__(self, other: Any) -> Self: "__radd__", other, returns_scalar=False, - ) + ).alias("literal") def __sub__(self, other: Any) -> Self: return self._from_call( @@ -215,7 +215,7 @@ def __rsub__(self, other: Any) -> Self: "__rsub__", other, returns_scalar=False, - ) + ).alias("literal") def __mul__(self, other: Any) -> Self: return self._from_call( @@ -231,7 +231,7 @@ def __rmul__(self, other: Any) -> Self: "__rmul__", other, returns_scalar=False, - ) + ).alias("literal") def __truediv__(self, other: Any) -> Self: return self._from_call( @@ -247,7 +247,7 @@ def __rtruediv__(self, other: Any) -> Self: "__rtruediv__", other, returns_scalar=False, - ) + ).alias("literal") def __floordiv__(self, other: Any) -> Self: return self._from_call( @@ -263,7 +263,7 @@ def __rfloordiv__(self, other: Any) -> Self: "__rfloordiv__", other, returns_scalar=False, - ) + ).alias("literal") def __pow__(self, other: Any) -> Self: return self._from_call( @@ -279,7 +279,7 @@ def __rpow__(self, other: Any) -> Self: "__rpow__", other, returns_scalar=False, - ) + ).alias("literal") def __mod__(self, other: Any) -> Self: return self._from_call( @@ -295,7 +295,7 @@ def __rmod__(self, other: Any) -> Self: "__rmod__", other, returns_scalar=False, - ) + ).alias("literal") def __eq__(self, other: DaskExpr) -> Self: # type: ignore[override] return self._from_call( @@ -353,13 +353,13 @@ def __and__(self, other: DaskExpr) -> Self: returns_scalar=False, ) - def __rand__(self, other: DaskExpr) -> Self: # pragma: no cover + def __rand__(self, other: DaskExpr) -> Self: return self._from_call( lambda _input, other: _input.__rand__(other), "__rand__", other, returns_scalar=False, - ) + ).alias("literal") def __or__(self, other: DaskExpr) -> Self: return self._from_call( @@ -369,13 +369,13 @@ def __or__(self, other: DaskExpr) -> Self: returns_scalar=False, ) - def __ror__(self, other: DaskExpr) -> Self: # pragma: no cover + def __ror__(self, other: DaskExpr) -> Self: return self._from_call( lambda _input, other: _input.__ror__(other), "__ror__", other, returns_scalar=False, - ) + ).alias("literal") def __invert__(self: Self) -> Self: return self._from_call( diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 0505637637..e8384165cf 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -235,7 +235,11 @@ def func(df: CompliantDataFrame) -> list[CompliantSeries]: if expr._output_names is not None and ( [s.name for s in out] != expr._output_names ): # pragma: no cover - msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" + msg = ( + f"Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues\n" + f"Expression output names: {expr._output_names}\n" + f"Series names: {[s.name for s in out]}" + ) raise AssertionError(msg) return out diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index 182ea980fc..80648a90f8 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -159,55 +159,64 @@ def __and__(self, other: PandasLikeExpr | bool | Any) -> Self: return reuse_series_implementation(self, "__and__", other=other) def __rand__(self, other: Any) -> Self: - return reuse_series_implementation(self, "__rand__", other=other) + other = self.__narwhals_namespace__().lit(other, dtype=None) + return other.__and__(self) # type: ignore[no-any-return] def __or__(self, other: PandasLikeExpr | bool | Any) -> Self: return reuse_series_implementation(self, "__or__", other=other) def __ror__(self, other: Any) -> Self: - return reuse_series_implementation(self, "__ror__", other=other) + other = self.__narwhals_namespace__().lit(other, dtype=None) + return other.__or__(self) # type: ignore[no-any-return] def __add__(self, other: PandasLikeExpr | Any) -> Self: return reuse_series_implementation(self, "__add__", other=other) def __radd__(self, other: Any) -> Self: - return reuse_series_implementation(self, "__radd__", other=other) + other = self.__narwhals_namespace__().lit(other, dtype=None) + return other.__add__(self) # type: ignore[no-any-return] def __sub__(self, other: PandasLikeExpr | Any) -> Self: return reuse_series_implementation(self, "__sub__", other=other) def __rsub__(self, other: Any) -> Self: - return reuse_series_implementation(self, "__rsub__", other=other) + other = self.__narwhals_namespace__().lit(other, dtype=None) + return other.__sub__(self) # type: ignore[no-any-return] def __mul__(self, other: PandasLikeExpr | Any) -> Self: return reuse_series_implementation(self, "__mul__", other=other) def __rmul__(self, other: Any) -> Self: - return reuse_series_implementation(self, "__rmul__", other=other) + other = self.__narwhals_namespace__().lit(other, dtype=None) + return other.__mul__(self) # type: ignore[no-any-return] def __truediv__(self, other: PandasLikeExpr | Any) -> Self: return reuse_series_implementation(self, "__truediv__", other=other) def __rtruediv__(self, other: Any) -> Self: - return reuse_series_implementation(self, "__rtruediv__", other=other) + other = self.__narwhals_namespace__().lit(other, dtype=None) + return other.__truediv__(self) # type: ignore[no-any-return] def __floordiv__(self, other: PandasLikeExpr | Any) -> Self: return reuse_series_implementation(self, "__floordiv__", other=other) def __rfloordiv__(self, other: Any) -> Self: - return reuse_series_implementation(self, "__rfloordiv__", other=other) + other = self.__narwhals_namespace__().lit(other, dtype=None) + return other.__floordiv__(self) # type: ignore[no-any-return] def __pow__(self, other: PandasLikeExpr | Any) -> Self: return reuse_series_implementation(self, "__pow__", other=other) def __rpow__(self, other: Any) -> Self: - return reuse_series_implementation(self, "__rpow__", other=other) + other = self.__narwhals_namespace__().lit(other, dtype=None) + return other.__pow__(self) # type: ignore[no-any-return] def __mod__(self, other: PandasLikeExpr | Any) -> Self: return reuse_series_implementation(self, "__mod__", other=other) def __rmod__(self, other: Any) -> Self: - return reuse_series_implementation(self, "__rmod__", other=other) + other = self.__narwhals_namespace__().lit(other, dtype=None) + return other.__mod__(self) # type: ignore[no-any-return] # Unary diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index b001f19db3..60f5fa24f4 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -155,7 +155,7 @@ def _lit_pandas_series(df: PandasLikeDataFrame) -> PandasLikeSeries: depth=0, function_name="lit", root_names=None, - output_names=[_lit_pandas_series.__name__], + output_names=["literal"], implementation=self._implementation, backend_version=self._backend_version, dtypes=self._dtypes, @@ -501,8 +501,9 @@ def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]: ) value_series = cast(PandasLikeSeries, value_series) - value_series_native = value_series._native_series - condition_native = validate_column_comparand(value_series_native.index, condition) + value_series_native, condition_native = validate_column_comparand( + value_series, condition + ) if self._otherwise_value is None: return [ diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 426e828b11..ef8e784cff 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -218,9 +218,9 @@ def scatter(self, indices: int | Sequence[int], values: Any) -> Self: if isinstance(values, self.__class__): # .copy() is necessary in some pre-2.2 versions of pandas to avoid # `values` also getting modified (!) - values = validate_column_comparand(self._native_series.index, values).copy() + _, values = validate_column_comparand(self, values) values = set_axis( - values, + values.copy(), self._native_series.index[indices], implementation=self._implementation, backend_version=self._backend_version, @@ -298,135 +298,135 @@ def arg_true(self) -> PandasLikeSeries: def filter(self, other: Any) -> PandasLikeSeries: ser = self._native_series if not (isinstance(other, list) and all(isinstance(x, bool) for x in other)): - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series(ser.loc[other].rename(ser.name, copy=False)) def __eq__(self, other: object) -> PandasLikeSeries: # type: ignore[override] ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series(ser.__eq__(other).rename(ser.name, copy=False)) def __ne__(self, other: object) -> PandasLikeSeries: # type: ignore[override] ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series(ser.__ne__(other).rename(ser.name, copy=False)) def __ge__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series(ser.__ge__(other).rename(ser.name, copy=False)) def __gt__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series(ser.__gt__(other).rename(ser.name, copy=False)) def __le__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series(ser.__le__(other).rename(ser.name, copy=False)) def __lt__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series(ser.__lt__(other).rename(ser.name, copy=False)) def __and__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series(ser.__and__(other).rename(ser.name, copy=False)) def __rand__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) - return self._from_native_series(ser.__rand__(other).rename(ser.name, copy=False)) + ser, other = validate_column_comparand(self, other) + return self._from_native_series(ser.__and__(other).rename(ser.name, copy=False)) def __or__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series(ser.__or__(other).rename(ser.name, copy=False)) def __ror__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) - return self._from_native_series(ser.__ror__(other).rename(ser.name, copy=False)) + ser, other = validate_column_comparand(self, other) + return self._from_native_series(ser.__or__(other).rename(ser.name, copy=False)) def __add__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series(ser.__add__(other).rename(ser.name, copy=False)) def __radd__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series(ser.__radd__(other).rename(ser.name, copy=False)) def __sub__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series(ser.__sub__(other).rename(ser.name, copy=False)) def __rsub__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series(ser.__rsub__(other).rename(ser.name, copy=False)) def __mul__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series(ser.__mul__(other).rename(ser.name, copy=False)) def __rmul__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series(ser.__rmul__(other).rename(ser.name, copy=False)) def __truediv__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series( ser.__truediv__(other).rename(ser.name, copy=False) ) def __rtruediv__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series( ser.__rtruediv__(other).rename(ser.name, copy=False) ) def __floordiv__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series( ser.__floordiv__(other).rename(ser.name, copy=False) ) def __rfloordiv__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series( ser.__rfloordiv__(other).rename(ser.name, copy=False) ) def __pow__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series(ser.__pow__(other).rename(ser.name, copy=False)) def __rpow__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series(ser.__rpow__(other).rename(ser.name, copy=False)) def __mod__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series(ser.__mod__(other).rename(ser.name, copy=False)) def __rmod__(self, other: Any) -> PandasLikeSeries: ser = self._native_series - other = validate_column_comparand(self._native_series.index, other) + ser, other = validate_column_comparand(self, other) return self._from_native_series(ser.__rmod__(other).rename(ser.name, copy=False)) # Unary @@ -738,9 +738,8 @@ def quantile( return self._native_series.quantile(q=quantile, interpolation=interpolation) def zip_with(self: Self, mask: Any, other: Any) -> PandasLikeSeries: - ser = self._native_series - mask = validate_column_comparand(ser.index, mask) - other = validate_column_comparand(ser.index, other) + ser, mask = validate_column_comparand(self, mask) + _, other = validate_column_comparand(self, other) res = ser.where(mask, other) return self._from_native_series(res) diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 44bfc34b7a..79be11aaa4 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -79,7 +79,7 @@ PATTERN_PA_DURATION = re.compile(PA_DURATION_RGX, re.VERBOSE) -def validate_column_comparand(index: Any, other: Any) -> Any: +def validate_column_comparand(lhs: PandasLikeSeries, rhs: Any) -> tuple[pd.Series, Any]: """Validate RHS of binary operation. If the comparison isn't supported, return `NotImplemented` so that the @@ -91,35 +91,56 @@ def validate_column_comparand(index: Any, other: Any) -> Any: from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.series import PandasLikeSeries - if isinstance(other, list): - if len(other) > 1: - if hasattr(other[0], "__narwhals_expr__") or hasattr( - other[0], "__narwhals_series__" + # If `rhs` is the output of an expression evaluation, then it is + # a list of Series. So, we verify that that list is of length-1, + # and take the first (and only) element. + if isinstance(rhs, list): + if len(rhs) > 1: + if hasattr(rhs[0], "__narwhals_expr__") or hasattr( + rhs[0], "__narwhals_series__" ): # e.g. `plx.all() + plx.all()` msg = "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) are not supported in this context" raise ValueError(msg) - msg = ( - f"Expected scalar value, Series, or Expr, got list of : {type(other[0])}" - ) + msg = f"Expected scalar value, Series, or Expr, got list of : {type(rhs[0])}" raise ValueError(msg) - other = other[0] - if isinstance(other, PandasLikeDataFrame): + rhs = rhs[0] + + lhs_index = lhs._native_series.index + + if isinstance(rhs, PandasLikeDataFrame): return NotImplemented - if isinstance(other, PandasLikeSeries): - if other.len() == 1: + + if isinstance(rhs, PandasLikeSeries): + rhs_index = rhs._native_series.index + if rhs.len() == 1: # broadcast - s = other._native_series - return s.__class__(s.iloc[0], index=index, dtype=s.dtype) - if other._native_series.index is not index: - return set_axis( - other._native_series, - index, - implementation=other._implementation, - backend_version=other._backend_version, + s = rhs._native_series + return ( + lhs._native_series, + s.__class__(s.iloc[0], index=lhs_index, dtype=s.dtype), ) - return other._native_series - return other + if lhs.len() == 1: + # broadcast + s = lhs._native_series + return ( + s.__class__(s.iloc[0], index=rhs_index, dtype=s.dtype, name=s.name), + rhs._native_series, + ) + if rhs._native_series.index is not lhs_index: + return ( + lhs._native_series, + set_axis( + rhs._native_series, + lhs_index, + implementation=rhs._implementation, + backend_version=rhs._backend_version, + ), + ) + return (lhs._native_series, rhs._native_series) + + # `rhs` must be scalar, so just leave it as-is + return lhs._native_series, rhs def validate_dataframe_comparand(index: Any, other: Any) -> Any: diff --git a/narwhals/_polars/series.py b/narwhals/_polars/series.py index a0df340d2f..492913227e 100644 --- a/narwhals/_polars/series.py +++ b/narwhals/_polars/series.py @@ -192,9 +192,11 @@ def __pow__(self, other: PolarsSeries | Any) -> Self: ) def __rpow__(self, other: PolarsSeries | Any) -> Self: - return self._from_native_series( - self._native_series.__rpow__(extract_native(other)) - ) + result = self._native_series.__rpow__(extract_native(other)) + if self._backend_version < (16, 1): + # Explicitly set alias to work around https://github.com/pola-rs/polars/issues/20071 + result = result.alias(self.name) + return self._from_native_series(result) def __invert__(self) -> Self: return self._from_native_series(self._native_series.__invert__()) diff --git a/narwhals/series.py b/narwhals/series.py index 6d4c5e17ca..6c011610ed 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -2037,11 +2037,21 @@ def __and__(self, other: Any) -> Self: self._compliant_series.__and__(self._extract_native(other)) ) + def __rand__(self, other: Any) -> Self: + return self._from_compliant_series( + self._compliant_series.__rand__(self._extract_native(other)) + ) + def __or__(self, other: Any) -> Self: return self._from_compliant_series( self._compliant_series.__or__(self._extract_native(other)) ) + def __ror__(self, other: Any) -> Self: + return self._from_compliant_series( + self._compliant_series.__ror__(self._extract_native(other)) + ) + # unary def __invert__(self) -> Self: return self._from_compliant_series(self._compliant_series.__invert__()) diff --git a/pyproject.toml b/pyproject.toml index 3e43940c19..79ae858d3f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -148,8 +148,9 @@ exclude_also = [ "if (:?self._)?implementation is Implementation.MODIN", "if (:?self._)?implementation is Implementation.CUDF", 'request.applymarker\(pytest.mark.xfail\)', - 'if self._backend_version < ', - 'if "cudf" in str\(constructor', + 'if \w+._backend_version < ', + 'if backend_version <', + 'if "cudf" in str\(constructor' ] [tool.mypy] diff --git a/tests/expr_and_series/arithmetic_test.py b/tests/expr_and_series/arithmetic_test.py index 95172bd2cb..12f931baa6 100644 --- a/tests/expr_and_series/arithmetic_test.py +++ b/tests/expr_and_series/arithmetic_test.py @@ -74,8 +74,8 @@ def test_right_arithmetic_expr( data = {"a": [1, 2, 3]} df = nw.from_native(constructor(data)) - result = df.select(a=getattr(nw.col("a"), attr)(rhs)) - assert_equal_data(result, {"a": expected}) + result = df.select(getattr(nw.col("a"), attr)(rhs)) + assert_equal_data(result, {"literal": expected}) @pytest.mark.parametrize( @@ -135,8 +135,9 @@ def test_right_arithmetic_series( data = {"a": [1, 2, 3]} df = nw.from_native(constructor_eager(data), eager_only=True) - result = df.select(a=getattr(df["a"], attr)(rhs)) - assert_equal_data(result, {"a": expected}) + result_series = getattr(df["a"], attr)(rhs) + assert result_series.name == "a" + assert_equal_data({"a": result_series}, {"a": expected}) def test_truediv_same_dims( @@ -218,3 +219,65 @@ def test_mod(left: int, right: int) -> None: nw.col("a") % right ) assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("attr", "lhs", "expected"), + [ + ("__add__", nw.lit(1), [2, 3, 5]), + ("__sub__", nw.lit(1), [0, -1, -3]), + ("__mul__", nw.lit(2), [2, 4, 8]), + ("__truediv__", nw.lit(2.0), [2, 1, 0.5]), + ("__truediv__", nw.lit(1), [1, 0.5, 0.25]), + ("__floordiv__", nw.lit(2), [2, 1, 0]), + ("__mod__", nw.lit(3), [0, 1, 3]), + ("__pow__", nw.lit(2), [2, 4, 16]), + ], +) +def test_arithmetic_expr_left_literal( + attr: str, + lhs: Any, + expected: list[Any], + constructor: Constructor, + request: pytest.FixtureRequest, +) -> None: + if attr == "__mod__" and any( + x in str(constructor) for x in ["pandas_pyarrow", "modin"] + ): + request.applymarker(pytest.mark.xfail) + + data = {"a": [1.0, 2, 4]} + df = nw.from_native(constructor(data)) + result = df.select(getattr(lhs, attr)(nw.col("a"))) + assert_equal_data(result, {"literal": expected}) + + +@pytest.mark.parametrize( + ("attr", "lhs", "expected"), + [ + ("__add__", nw.lit(1), [2, 3, 5]), + ("__sub__", nw.lit(1), [0, -1, -3]), + ("__mul__", nw.lit(2), [2, 4, 8]), + ("__truediv__", nw.lit(2.0), [2, 1, 0.5]), + ("__truediv__", nw.lit(1), [1, 0.5, 0.25]), + ("__floordiv__", nw.lit(2), [2, 1, 0]), + ("__mod__", nw.lit(3), [0, 1, 3]), + ("__pow__", nw.lit(2), [2, 4, 16]), + ], +) +def test_arithmetic_series_left_literal( + attr: str, + lhs: Any, + expected: list[Any], + constructor_eager: ConstructorEager, + request: pytest.FixtureRequest, +) -> None: + if attr == "__mod__" and any( + x in str(constructor_eager) for x in ["pandas_pyarrow", "modin"] + ): + request.applymarker(pytest.mark.xfail) + + data = {"a": [1.0, 2, 4]} + df = nw.from_native(constructor_eager(data)) + result = df.select(getattr(lhs, attr)(nw.col("a"))) + assert_equal_data(result, {"literal": expected}) diff --git a/tests/expr_and_series/operators_test.py b/tests/expr_and_series/operators_test.py index 5506e6a8d7..ff01747a60 100644 --- a/tests/expr_and_series/operators_test.py +++ b/tests/expr_and_series/operators_test.py @@ -65,6 +65,25 @@ def test_logic_operators_expr( assert_equal_data(result, {"a": expected}) +@pytest.mark.parametrize( + ("operator", "expected"), + [ + ("__and__", [False, False, False, False]), + ("__rand__", [False, False, False, False]), + ("__or__", [True, True, False, False]), + ("__ror__", [True, True, False, False]), + ], +) +def test_logic_operators_expr_scalar( + constructor: Constructor, operator: str, expected: list[bool] +) -> None: + data = {"a": [True, True, False, False]} + df = nw.from_native(constructor(data)) + + result = df.select(a=getattr(nw.col("a"), operator)(False)) # noqa: FBT003 + assert_equal_data(result, {"a": expected}) + + @pytest.mark.parametrize( ("operator", "expected"), [ @@ -110,7 +129,9 @@ def test_comparand_operators_series( ("operator", "expected"), [ ("__and__", [True, False, False, False]), + ("__rand__", [True, False, False, False]), ("__or__", [True, True, True, False]), + ("__ror__", [True, True, True, False]), ], ) def test_logic_operators_series( diff --git a/tests/frame/lit_test.py b/tests/frame/lit_test.py index 2b0332a5e5..f51bd5c76a 100644 --- a/tests/frame/lit_test.py +++ b/tests/frame/lit_test.py @@ -62,3 +62,33 @@ def test_lit_out_name(constructor: Constructor) -> None: "literal": [2, 2, 2], } assert_equal_data(result, expected) + + +@pytest.mark.parametrize( + ("col_name", "expr", "expected_result"), + [ + ("left_lit", nw.lit(1) + nw.col("a"), [2, 4, 3]), + ("right_lit", nw.col("a") + nw.lit(1), [2, 4, 3]), + ("left_lit_with_agg", nw.lit(1) + nw.col("a").mean(), [3]), + ("right_lit_with_agg", nw.col("a").mean() - nw.lit(1), [1]), + ("left_scalar", 1 + nw.col("a"), [2, 4, 3]), + ("right_scalar", nw.col("a") + 1, [2, 4, 3]), + ("left_scalar_with_agg", 1 + nw.col("a").mean(), [3]), + ("right_scalar_with_agg", nw.col("a").mean() - 1, [1]), + ], +) +def test_lit_operation( + constructor: Constructor, + col_name: str, + expr: nw.Expr, + expected_result: list[int], + request: pytest.FixtureRequest, +) -> None: + if "dask_lazy_p2" in str(constructor) and "lit_with_agg" in col_name: + request.applymarker(pytest.mark.xfail) + data = {"a": [1, 3, 2]} + df_raw = constructor(data) + df = nw.from_native(df_raw).lazy() + result = df.select(expr.alias(col_name)) + expected = {col_name: expected_result} + assert_equal_data(result, expected) diff --git a/tests/utils.py b/tests/utils.py index 90143959dc..f9b493add2 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -62,7 +62,7 @@ def assert_equal_data(result: Any, expected: dict[str, Any]) -> None: result = result.collect() if hasattr(result, "columns"): for key in result.columns: - assert key in expected + assert key in expected, (key, expected) result = {key: _to_comparable_list(result[key]) for key in expected} for key in expected: result_key = result[key]