From 39c178786c04f50f396f542a77f45a0bdbe23d4f Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Tue, 20 Aug 2024 21:27:34 +0100 Subject: [PATCH] fix: improve validate columns (#829) * improve validate columns * skipif * skipif --- narwhals/_pandas_like/dataframe.py | 21 +++++++++++---------- tests/frame/schema_test.py | 28 ++++++++++++++++++++++++++++ tests/test_schema.py | 22 ---------------------- 3 files changed, 39 insertions(+), 32 deletions(-) delete mode 100644 tests/test_schema.py diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 27a725100..4faf61a08 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -1,6 +1,5 @@ from __future__ import annotations -import collections from typing import TYPE_CHECKING from typing import Any from typing import Iterable @@ -26,6 +25,7 @@ from narwhals.utils import parse_columns_to_drop if TYPE_CHECKING: + import pandas as pd from typing_extensions import Self from narwhals._pandas_like.group_by import PandasLikeGroupBy @@ -73,15 +73,16 @@ def __native_namespace__(self) -> Any: def __len__(self) -> int: return len(self._native_frame) - def _validate_columns(self, columns: Sequence[str]) -> None: - if len(columns) != len(set(columns)): - counter = collections.Counter(columns) - for col, count in counter.items(): - if count > 1: - msg = f"Expected unique column names, got {col!r} {count} time(s)" - raise ValueError(msg) - msg = "Please report a bug" # pragma: no cover - raise AssertionError(msg) + def _validate_columns(self, columns: pd.Index) -> None: + try: + len_unique_columns = len(columns.drop_duplicates()) + except Exception: # noqa: BLE001 # pragma: no cover + msg = f"Expected hashable (e.g. str or int) column names, got: {columns}" + raise ValueError(msg) from None + + if len(columns) != len_unique_columns: + msg = f"Expected unique column names, got: {columns}" + raise ValueError(msg) def _from_native_frame(self, df: Any) -> Self: return self.__class__( diff --git a/tests/frame/schema_test.py b/tests/frame/schema_test.py index e1ba5afda..6e6b33aa1 100644 --- a/tests/frame/schema_test.py +++ b/tests/frame/schema_test.py @@ -170,3 +170,31 @@ def test_unknown_dtype_polars() -> None: def test_hash() -> None: assert nw.Int64() in {nw.Int64, nw.Int32} + + +@pytest.mark.parametrize( + ("method", "expected"), + [ + ("names", ["a", "b", "c"]), + ("dtypes", [nw.Int64(), nw.Float32(), nw.String()]), + ("len", 3), + ], +) +def test_schema_object(method: str, expected: Any) -> None: + data = {"a": nw.Int64(), "b": nw.Float32(), "c": nw.String()} + schema = nw.Schema(data) + assert getattr(schema, method)() == expected + + +@pytest.mark.skipif( + parse_version(pd.__version__) < (2,), + reason="Before 2.0, pandas would raise on `drop_duplicates`", +) +def test_from_non_hashable_column_name() -> None: + # This is technically super-illegal + # BUT, it shows up in a scikit-learn test, so... + df = pd.DataFrame([[1, 2], [3, 4]], columns=["pizza", ["a", "b"]]) + + df = nw.from_native(df, eager_only=True) + assert df.columns == ["pizza", ["a", "b"]] + assert df["pizza"].dtype == nw.Int64 diff --git a/tests/test_schema.py b/tests/test_schema.py deleted file mode 100644 index f85fdd816..000000000 --- a/tests/test_schema.py +++ /dev/null @@ -1,22 +0,0 @@ -from __future__ import annotations - -from typing import Any - -import pytest - -import narwhals.stable.v1 as nw - -data = {"a": nw.Int64(), "b": nw.Float32(), "c": nw.String()} - - -@pytest.mark.parametrize( - ("method", "expected"), - [ - ("names", ["a", "b", "c"]), - ("dtypes", [nw.Int64(), nw.Float32(), nw.String()]), - ("len", 3), - ], -) -def test_schema_object(method: str, expected: Any) -> None: - schema = nw.Schema(data) - assert getattr(schema, method)() == expected