Skip to content

Commit

Permalink
fix: improve validate columns (#829)
Browse files Browse the repository at this point in the history
* improve validate columns

* skipif

* skipif
  • Loading branch information
MarcoGorelli authored Aug 20, 2024
1 parent 61c9b9b commit 39c1787
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 32 deletions.
21 changes: 11 additions & 10 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import collections
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
Expand All @@ -26,6 +25,7 @@
from narwhals.utils import parse_columns_to_drop

if TYPE_CHECKING:
import pandas as pd
from typing_extensions import Self

from narwhals._pandas_like.group_by import PandasLikeGroupBy
Expand Down Expand Up @@ -73,15 +73,16 @@ def __native_namespace__(self) -> Any:
def __len__(self) -> int:
return len(self._native_frame)

def _validate_columns(self, columns: Sequence[str]) -> None:
if len(columns) != len(set(columns)):
counter = collections.Counter(columns)
for col, count in counter.items():
if count > 1:
msg = f"Expected unique column names, got {col!r} {count} time(s)"
raise ValueError(msg)
msg = "Please report a bug" # pragma: no cover
raise AssertionError(msg)
def _validate_columns(self, columns: pd.Index) -> None:
try:
len_unique_columns = len(columns.drop_duplicates())
except Exception: # noqa: BLE001 # pragma: no cover
msg = f"Expected hashable (e.g. str or int) column names, got: {columns}"
raise ValueError(msg) from None

if len(columns) != len_unique_columns:
msg = f"Expected unique column names, got: {columns}"
raise ValueError(msg)

def _from_native_frame(self, df: Any) -> Self:
return self.__class__(
Expand Down
28 changes: 28 additions & 0 deletions tests/frame/schema_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,31 @@ def test_unknown_dtype_polars() -> None:

def test_hash() -> None:
assert nw.Int64() in {nw.Int64, nw.Int32}


@pytest.mark.parametrize(
("method", "expected"),
[
("names", ["a", "b", "c"]),
("dtypes", [nw.Int64(), nw.Float32(), nw.String()]),
("len", 3),
],
)
def test_schema_object(method: str, expected: Any) -> None:
data = {"a": nw.Int64(), "b": nw.Float32(), "c": nw.String()}
schema = nw.Schema(data)
assert getattr(schema, method)() == expected


@pytest.mark.skipif(
parse_version(pd.__version__) < (2,),
reason="Before 2.0, pandas would raise on `drop_duplicates`",
)
def test_from_non_hashable_column_name() -> None:
# This is technically super-illegal
# BUT, it shows up in a scikit-learn test, so...
df = pd.DataFrame([[1, 2], [3, 4]], columns=["pizza", ["a", "b"]])

df = nw.from_native(df, eager_only=True)
assert df.columns == ["pizza", ["a", "b"]]
assert df["pizza"].dtype == nw.Int64
22 changes: 0 additions & 22 deletions tests/test_schema.py

This file was deleted.

0 comments on commit 39c1787

Please sign in to comment.