From 5080420b45ce079b4b782a050ca5e600530516a1 Mon Sep 17 00:00:00 2001 From: ritchie Date: Sat, 21 Dec 2024 14:31:13 +0100 Subject: [PATCH 1/2] fix: Fix decimal series dispatch --- py-polars/polars/series/series.py | 28 ++++++++++++++++--- .../tests/unit/datatypes/test_decimal.py | 23 +++++++++++++++ 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 787af72a92bc..13a4596d5852 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -1035,13 +1035,15 @@ def __add__(self, other: Expr) -> Expr: ... @overload def __add__(self, other: Any) -> Self: ... - def __add__(self, other: Any) -> Self | DataFrame | Expr: + def __add__(self, other: Any) -> Series | DataFrame | Expr: if isinstance(other, str): other = Series("", [other]) elif isinstance(other, pl.DataFrame): return other + self elif isinstance(other, pl.Expr): return F.lit(self) + other + if self.dtype.is_decimal() and isinstance(other, (float, int)): + return self.to_frame().select(F.col(self.name) + other).to_series() return self._arithmetic(other, "add", "add_<>") @overload @@ -1050,9 +1052,11 @@ def __sub__(self, other: Expr) -> Expr: ... @overload def __sub__(self, other: Any) -> Self: ... - def __sub__(self, other: Any) -> Self | Expr: + def __sub__(self, other: Any) -> Series | Expr: if isinstance(other, pl.Expr): return F.lit(self) - other + if self.dtype.is_decimal() and isinstance(other, (float, int)): + return self.to_frame().select(F.col(self.name) - other).to_series() return self._arithmetic(other, "sub", "sub_<>") def _recursive_cast_to_dtype(self, leaf_dtype: PolarsDataType) -> Series: @@ -1083,6 +1087,8 @@ def __truediv__(self, other: Any) -> Series | Expr: if self.dtype.is_temporal(): msg = "first cast to integer before dividing datelike dtypes" raise TypeError(msg) + if self.dtype.is_decimal() and isinstance(other, (float, int)): + return self.to_frame().select(F.col(self.name) / other).to_series() self = ( self @@ -1111,6 +1117,8 @@ def __floordiv__(self, other: Any) -> Series | Expr: if self.dtype.is_temporal(): msg = "first cast to integer before dividing datelike dtypes" raise TypeError(msg) + if self.dtype.is_decimal() and isinstance(other, (float, int)): + return self.to_frame().select(F.col(self.name) // other).to_series() if not isinstance(other, pl.Expr): other = F.lit(other) @@ -1134,6 +1142,8 @@ def __mul__(self, other: Any) -> Series | DataFrame | Expr: if self.dtype.is_temporal(): msg = "first cast to integer before multiplying datelike dtypes" raise TypeError(msg) + if self.dtype.is_decimal() and isinstance(other, (float, int)): + return self.to_frame().select(F.col(self.name) * other).to_series() elif isinstance(other, pl.DataFrame): return other * self else: @@ -1151,6 +1161,8 @@ def __mod__(self, other: Any) -> Series | Expr: if self.dtype.is_temporal(): msg = "first cast to integer before applying modulo on datelike dtypes" raise TypeError(msg) + if self.dtype.is_decimal() and isinstance(other, (float, int)): + return self.to_frame().select(F.col(self.name) % other).to_series() return self._arithmetic(other, "rem", "rem_<>") def __rmod__(self, other: Any) -> Series: @@ -1160,11 +1172,15 @@ def __rmod__(self, other: Any) -> Series: return self._arithmetic(other, "rem", "rem_<>_rhs") def __radd__(self, other: Any) -> Series: - if isinstance(other, str): - return (other + self.to_frame()).to_series() + if isinstance(other, str) or ( + isinstance(other, (int, float)) and self.dtype.is_decimal() + ): + return self.to_frame().select(other + F.col(self.name)).to_series() return self._arithmetic(other, "add", "add_<>_rhs") def __rsub__(self, other: Any) -> Series: + if isinstance(other, (int, float)) and self.dtype.is_decimal(): + return self.to_frame().select(other - F.col(self.name)).to_series() return self._arithmetic(other, "sub", "sub_<>_rhs") def __rtruediv__(self, other: Any) -> Series: @@ -1173,6 +1189,8 @@ def __rtruediv__(self, other: Any) -> Series: raise TypeError(msg) if self.dtype.is_float(): self.__rfloordiv__(other) + if isinstance(other, (int, float)) and self.dtype.is_decimal(): + return self.to_frame().select(other / F.col(self.name)).to_series() if isinstance(other, int): other = float(other) @@ -1188,6 +1206,8 @@ def __rmul__(self, other: Any) -> Series: if self.dtype.is_temporal(): msg = "first cast to integer before multiplying datelike dtypes" raise TypeError(msg) + if isinstance(other, (int, float)) and self.dtype.is_decimal(): + return self.to_frame().select(other * F.col(self.name)).to_series() return self._arithmetic(other, "mul", "mul_<>") def __pow__(self, exponent: int | float | Series) -> Series: diff --git a/py-polars/tests/unit/datatypes/test_decimal.py b/py-polars/tests/unit/datatypes/test_decimal.py index 7f259708d649..94576b08869b 100644 --- a/py-polars/tests/unit/datatypes/test_decimal.py +++ b/py-polars/tests/unit/datatypes/test_decimal.py @@ -553,3 +553,26 @@ def test_decimal_arithmetic_schema() -> None: assert q1.collect_schema() == q1.collect().schema q1 = q.select(pl.col.x + pl.col.x) assert q1.collect_schema() == q1.collect().schema + + +def test_decimal_arithmetic_schema_float_20369() -> None: + s = pl.Series("x", [1.0], dtype=pl.Decimal(15, 2)) + assert_series_equal((s - 1.0), pl.Series("x", [0.0], dtype=pl.Decimal(None, 2))) + assert_series_equal( + (3.0 - s), pl.Series("literal", [2.0], dtype=pl.Decimal(None, 2)) + ) + assert_series_equal( + (3.0 / s), pl.Series("literal", [3.0], dtype=pl.Decimal(None, 6)) + ) + assert_series_equal( + (s / 3.0), pl.Series("x", [0.333333], dtype=pl.Decimal(None, 6)) + ) + + assert_series_equal((s + 1.0), pl.Series("x", [2.0], dtype=pl.Decimal(None, 2))) + assert_series_equal( + (1.0 + s), pl.Series("literal", [2.0], dtype=pl.Decimal(None, 2)) + ) + assert_series_equal((s * 1.0), pl.Series("x", [1.0], dtype=pl.Decimal(None, 4))) + assert_series_equal( + (1.0 * s), pl.Series("literal", [1.0], dtype=pl.Decimal(None, 4)) + ) From af40df2d782e0af767e62e16ff0fd3fab3c19202 Mon Sep 17 00:00:00 2001 From: ritchie Date: Sat, 21 Dec 2024 14:51:58 +0100 Subject: [PATCH 2/2] fix test --- py-polars/tests/unit/dataframe/test_df.py | 1 + py-polars/tests/unit/series/test_series.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index 565fa8466e56..bfb58cdb1bd2 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -2038,6 +2038,7 @@ def test_add_string() -> None: expected = pl.DataFrame( {"a": ["hello hi", "hello there"], "b": ["hello hello", "hello world"]} ) + print(expected) assert_frame_equal(("hello " + df), expected) diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index 4c7302be9f15..5be458c1a2a0 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -371,10 +371,11 @@ def test_categorical_agg(s: pl.Series, min: str | None, max: str | None) -> None def test_add_string() -> None: s = pl.Series(["hello", "weird"]) result = s + " world" + print(result) assert_series_equal(result, pl.Series(["hello world", "weird world"])) result = "pfx:" + s - assert_series_equal(result, pl.Series(["pfx:hello", "pfx:weird"])) + assert_series_equal(result, pl.Series("literal", ["pfx:hello", "pfx:weird"])) @pytest.mark.parametrize(