|
13 | 13 |
|
14 | 14 |
|
15 | 15 | def _check_columns(df: DataFrameType, columns: ColumnsDef, strict: bool) -> None:
|
| 16 | + missing_columns = [] |
| 17 | + dtype_mismatches = [] |
| 18 | + |
16 | 19 | if isinstance(columns, list):
|
17 | 20 | for column in columns:
|
18 |
| - assert column in df.columns, f"Column {column} missing from DataFrame. Got {_describe_pd(df)}" |
| 21 | + if column not in df.columns: |
| 22 | + missing_columns.append(column) |
19 | 23 | if isinstance(columns, dict):
|
20 | 24 | for column, dtype in columns.items():
|
21 |
| - assert column in df.columns, f"Column {column} missing from DataFrame. Got {_describe_pd(df)}" |
22 |
| - assert df[column].dtype == dtype, ( |
23 |
| - f"Column {column} has wrong dtype. Was {df[column].dtype}, expected {dtype}" |
24 |
| - ) |
25 |
| - if strict: |
26 |
| - assert len(df.columns) == len(columns), ( |
27 |
| - f"DataFrame contained unexpected column(s): {', '.join(set(df.columns) - set(columns))}" |
| 25 | + if column not in df.columns: |
| 26 | + missing_columns.append(column) |
| 27 | + elif df[column].dtype != dtype: |
| 28 | + dtype_mismatches.append((column, df[column].dtype, dtype)) |
| 29 | + |
| 30 | + if missing_columns: |
| 31 | + raise AssertionError(f"Missing columns: {missing_columns}. Got {_describe_pd(df)}") |
| 32 | + |
| 33 | + if dtype_mismatches: |
| 34 | + mismatches = ", ".join( |
| 35 | + [f"Column {col} has wrong dtype. Was {was}, expected {expected}" for col, was, expected in dtype_mismatches] |
28 | 36 | )
|
| 37 | + raise AssertionError(mismatches) |
| 38 | + |
| 39 | + if strict: |
| 40 | + extra_columns = set(df.columns) - set(columns) |
| 41 | + if extra_columns: |
| 42 | + raise AssertionError(f"DataFrame contained unexpected column(s): {', '.join(extra_columns)}") |
29 | 43 |
|
30 | 44 |
|
31 | 45 | def df_out(columns: Optional[ColumnsDef] = None, strict: bool = False) -> Callable:
|
|
0 commit comments