Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: address lit broadcasting and output name of right arithmetic ops #1424

Merged
merged 21 commits into from
Nov 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,55 +151,64 @@ def __and__(self: Self, other: ArrowExpr | bool | Any) -> Self:
return reuse_series_implementation(self, "__and__", other=other)

def __rand__(self: Self, other: ArrowExpr | bool | Any) -> Self:
return reuse_series_implementation(self, "__rand__", other=other)
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__and__(self) # type: ignore[return-value]

def __or__(self: Self, other: ArrowExpr | bool | Any) -> Self:
return reuse_series_implementation(self, "__or__", other=other)

def __ror__(self: Self, other: ArrowExpr | bool | Any) -> Self:
return reuse_series_implementation(self, "__ror__", other=other)
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__or__(self) # type: ignore[return-value]

def __add__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__add__", other)

def __radd__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__radd__", other)
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__add__(self) # type: ignore[return-value]

def __sub__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__sub__", other)

def __rsub__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__rsub__", other)
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__sub__(self) # type: ignore[return-value]

def __mul__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__mul__", other)

def __rmul__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__rmul__", other)
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__mul__(self) # type: ignore[return-value]

def __pow__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__pow__", other)

def __rpow__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__rpow__", other)
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__pow__(self) # type: ignore[return-value]

def __floordiv__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__floordiv__", other)

def __rfloordiv__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__rfloordiv__", other)
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__floordiv__(self) # type: ignore[return-value]

def __truediv__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__truediv__", other)

def __rtruediv__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__rtruediv__", other)
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__truediv__(self) # type: ignore[return-value]

def __mod__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__mod__", other)

def __rmod__(self: Self, other: ArrowExpr | Any) -> Self:
return reuse_series_implementation(self, "__rmod__", other)
other = self.__narwhals_namespace__().lit(other, dtype=None)
return other.__mod__(self) # type: ignore[return-value]

def __invert__(self: Self) -> Self:
return reuse_series_implementation(self, "__invert__")
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
depth=0,
function_name="lit",
root_names=None,
output_names=[_lit_arrow_series.__name__],
output_names=["literal"],
backend_version=self._backend_version,
dtypes=self._dtypes,
)
Expand Down
81 changes: 33 additions & 48 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,130 +96,115 @@ def __len__(self: Self) -> int:
def __eq__(self: Self, other: object) -> Self: # type: ignore[override]
import pyarrow.compute as pc

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.equal(ser, other))

def __ne__(self: Self, other: object) -> Self: # type: ignore[override]
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.not_equal(ser, other))

def __ge__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.greater_equal(ser, other))

def __gt__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.greater(ser, other))

def __le__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.less_equal(ser, other))

def __lt__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.less(ser, other))

def __and__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.and_kleene(ser, other))

def __rand__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.and_kleene(other, ser))

def __or__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.or_kleene(ser, other))

def __ror__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.or_kleene(other, ser))

def __add__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

other = validate_column_comparand(other)
return self._from_native_series(pc.add(self._native_series, other))
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.add(ser, other))

def __radd__(self: Self, other: Any) -> Self:
return self + other # type: ignore[no-any-return]

def __sub__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

other = validate_column_comparand(other)
return self._from_native_series(pc.subtract(self._native_series, other))
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.subtract(ser, other))

def __rsub__(self: Self, other: Any) -> Self:
return (self - other) * (-1) # type: ignore[no-any-return]

def __mul__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

other = validate_column_comparand(other)
return self._from_native_series(pc.multiply(self._native_series, other))
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.multiply(ser, other))

def __rmul__(self: Self, other: Any) -> Self:
return self * other # type: ignore[no-any-return]

def __pow__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.power(ser, other))

def __rpow__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(pc.power(other, ser))

def __floordiv__(self: Self, other: Any) -> Self:
ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(floordiv_compat(ser, other))

def __rfloordiv__(self: Self, other: Any) -> Self:
ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
return self._from_native_series(floordiv_compat(other, ser))

def __truediv__(self: Self, other: Any) -> Self:
import pyarrow as pa # ignore-banned-import()
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
if not isinstance(other, (pa.Array, pa.ChunkedArray)):
# scalar
other = pa.scalar(other)
Expand All @@ -229,8 +214,7 @@ def __rtruediv__(self: Self, other: Any) -> Self:
import pyarrow as pa # ignore-banned-import()
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
ser, other = validate_column_comparand(self, other, self._backend_version)
if not isinstance(other, (pa.Array, pa.ChunkedArray)):
# scalar
other = pa.scalar(other)
Expand All @@ -239,18 +223,16 @@ def __rtruediv__(self: Self, other: Any) -> Self:
def __mod__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
floor_div = (self // other)._native_series
ser, other = validate_column_comparand(self, other, self._backend_version)
res = pc.subtract(ser, pc.multiply(floor_div, other))
return self._from_native_series(res)

def __rmod__(self: Self, other: Any) -> Self:
import pyarrow.compute as pc # ignore-banned-import()

ser = self._native_series
other = validate_column_comparand(other)
floor_div = (other // self)._native_series
ser, other = validate_column_comparand(self, other, self._backend_version)
res = pc.subtract(other, pc.multiply(floor_div, ser))
return self._from_native_series(res)

Expand All @@ -264,8 +246,10 @@ def len(self: Self) -> int:

def filter(self: Self, other: Any) -> Self:
if not (isinstance(other, list) and all(isinstance(x, bool) for x in other)):
other = validate_column_comparand(other)
return self._from_native_series(self._native_series.filter(other))
ser, other = validate_column_comparand(self, other, self._backend_version)
else:
ser = self._native_series
return self._from_native_series(ser.filter(other))

def mean(self: Self) -> int:
import pyarrow.compute as pc # ignore-banned-import()
Expand Down Expand Up @@ -382,16 +366,17 @@ def scatter(self: Self, indices: int | Sequence[int], values: Any) -> Self:
import pyarrow as pa # ignore-banned-import
import pyarrow.compute as pc # ignore-banned-import

ca = self._native_series
mask = np.zeros(len(ca), dtype=bool)
mask = np.zeros(self.len(), dtype=bool)
mask[indices] = True
if isinstance(values, self.__class__):
values = validate_column_comparand(values)
ser, values = validate_column_comparand(self, values, self._backend_version)
else:
ser = self._native_series
if isinstance(values, pa.ChunkedArray):
values = values.combine_chunks()
if not isinstance(values, pa.Array):
values = pa.array(values)
result = pc.replace_with_mask(ca, mask, values.take(indices))
result = pc.replace_with_mask(ser, mask, values.take(indices))
return self._from_native_series(result)

def to_list(self: Self) -> list[Any]:
Expand Down
56 changes: 39 additions & 17 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> pa.D
raise AssertionError(msg)


def validate_column_comparand(other: Any) -> Any:
def validate_column_comparand(
lhs: ArrowSeries, rhs: Any, backend_version: tuple[int, ...]
) -> tuple[pa.ChunkedArray, Any]:
"""Validate RHS of binary operation.

If the comparison isn't supported, return `NotImplemented` so that the
Expand All @@ -140,27 +142,47 @@ def validate_column_comparand(other: Any) -> Any:
from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.series import ArrowSeries

if isinstance(other, list):
if len(other) > 1:
if hasattr(other[0], "__narwhals_expr__") or hasattr(
other[0], "__narwhals_series__"
# If `rhs` is the output of an expression evaluation, then it is
# a list of Series. So, we verify that the list has length 1
# and take its first (and only) element.
if isinstance(rhs, list):
if len(rhs) > 1:
if hasattr(rhs[0], "__narwhals_expr__") or hasattr(
rhs[0], "__narwhals_series__"
):
# e.g. `plx.all() + plx.all()`
msg = "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) are not supported in this context"
raise ValueError(msg)
msg = (
f"Expected scalar value, Series, or Expr, got list of : {type(other[0])}"
)
msg = f"Expected scalar value, Series, or Expr, got list of : {type(rhs[0])}"
raise ValueError(msg)
other = other[0]
if isinstance(other, ArrowDataFrame):
rhs = rhs[0]

if isinstance(rhs, ArrowDataFrame):
return NotImplemented
if isinstance(other, ArrowSeries):
if len(other) == 1:

if isinstance(rhs, ArrowSeries):
if len(rhs) == 1:
# broadcast
return other[0]
return other._native_series
return other
return lhs._native_series, rhs[0]
if len(lhs) == 1:
# broadcast
import numpy as np # ignore-banned-import
import pyarrow as pa # ignore-banned-import

fill_value = lhs[0]
if backend_version < (13,) and hasattr(fill_value, "as_py"):
fill_value = fill_value.as_py()
left_result = pa.chunked_array(
[
pa.array(
np.full(shape=rhs.len(), fill_value=fill_value),
type=lhs._native_series.type,
)
]
)
return left_result, rhs._native_series
return lhs._native_series, rhs._native_series
return lhs._native_series, rhs


def validate_dataframe_comparand(
Expand All @@ -179,7 +201,7 @@ def validate_dataframe_comparand(
import pyarrow as pa # ignore-banned-import

value = other._native_series[0]
if backend_version < (13,) and hasattr(value, "as_py"): # pragma: no cover
if backend_version < (13,) and hasattr(value, "as_py"):
value = value.as_py()
return pa.array(np.full(shape=length, fill_value=value))
return other._native_series
Expand Down Expand Up @@ -321,7 +343,7 @@ def broadcast_series(series: list[ArrowSeries]) -> list[Any]:
s_native = s._native_series
if is_max_length_gt_1 and length == 1:
value = s_native[0]
if s._backend_version < (13,) and hasattr(value, "as_py"): # pragma: no cover
if s._backend_version < (13,) and hasattr(value, "as_py"):
value = value.as_py()
reshaped.append(pa.array([value] * max_length, type=s_native.type))
else:
Expand Down
Loading
Loading