Skip to content

Commit 7a91bc5

Browse files
committed
should report all missing columns rightaway
1 parent 6f95a56 commit 7a91bc5

File tree

2 files changed

+34
-10
lines changed

2 files changed

+34
-10
lines changed

daffy/decorators.py

+22-8
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,33 @@
1313

1414

1515
def _check_columns(df: DataFrameType, columns: ColumnsDef, strict: bool) -> None:
16+
missing_columns = []
17+
dtype_mismatches = []
18+
1619
if isinstance(columns, list):
1720
for column in columns:
18-
assert column in df.columns, f"Column {column} missing from DataFrame. Got {_describe_pd(df)}"
21+
if column not in df.columns:
22+
missing_columns.append(column)
1923
if isinstance(columns, dict):
2024
for column, dtype in columns.items():
21-
assert column in df.columns, f"Column {column} missing from DataFrame. Got {_describe_pd(df)}"
22-
assert df[column].dtype == dtype, (
23-
f"Column {column} has wrong dtype. Was {df[column].dtype}, expected {dtype}"
24-
)
25-
if strict:
26-
assert len(df.columns) == len(columns), (
27-
f"DataFrame contained unexpected column(s): {', '.join(set(df.columns) - set(columns))}"
25+
if column not in df.columns:
26+
missing_columns.append(column)
27+
elif df[column].dtype != dtype:
28+
dtype_mismatches.append((column, df[column].dtype, dtype))
29+
30+
if missing_columns:
31+
raise AssertionError(f"Missing columns: {missing_columns}. Got {_describe_pd(df)}")
32+
33+
if dtype_mismatches:
34+
mismatches = ", ".join(
35+
[f"Column {col} has wrong dtype. Was {was}, expected {expected}" for col, was, expected in dtype_mismatches]
2836
)
37+
raise AssertionError(mismatches)
38+
39+
if strict:
40+
extra_columns = set(df.columns) - set(columns)
41+
if extra_columns:
42+
raise AssertionError(f"DataFrame contained unexpected column(s): {', '.join(extra_columns)}")
2943

3044

3145
def df_out(columns: Optional[ColumnsDef] = None, strict: bool = False) -> Callable:

tests/test_decorators.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def test_fn() -> DataFrameType:
107107
with pytest.raises(AssertionError) as excinfo:
108108
test_fn()
109109

110-
assert "Column FooColumn missing" in str(excinfo.value)
110+
assert "Missing columns: ['FooColumn']. Got columns: ['Brand', 'Price']" in str(excinfo.value)
111111

112112

113113
def test_wrong_input_type_unnamed() -> None:
@@ -255,7 +255,17 @@ def test_fn(my_input: Any) -> Any:
255255

256256
with pytest.raises(AssertionError) as excinfo:
257257
test_fn(df[["Brand"]])
258-
assert "Column Price missing" in str(excinfo.value)
258+
assert "Missing columns: ['Price']. Got columns: ['Brand']" in str(excinfo.value)
259+
260+
@pytest.mark.parametrize(("df"), [pd.DataFrame(cars), pl.DataFrame(cars)])
261+
def test_df_in_missing_multiple_columns(df: DataFrameType) -> None:
262+
@df_in(columns=["Brand", "Price", "Extra"])
263+
def test_fn(my_input: Any) -> Any:
264+
return my_input
265+
266+
with pytest.raises(AssertionError) as excinfo:
267+
test_fn(df[["Brand"]])
268+
assert "Missing columns: ['Price', 'Extra']. Got columns: ['Brand']" in str(excinfo.value)
259269

260270

261271
def test_df_out_with_df_modification(basic_pandas_df: pd.DataFrame, extended_pandas_df: pd.DataFrame) -> None:

0 commit comments

Comments
 (0)