diff --git a/src/tests/test_polars.py b/src/tests/test_polars.py index 495defb7c..68d5bb9b7 100644 --- a/src/tests/test_polars.py +++ b/src/tests/test_polars.py @@ -93,10 +93,30 @@ class TestAppendDataClass: + @given( + data=fixed_dictionaries({ + "a": int64s() | none(), + "b": floats() | none(), + "c": text_ascii() | none(), + }) + ) + def test_columns_and_fields_equal(self, *, data: StrMapping) -> None: + df = DataFrame(schema={"a": Int64, "b": Float64, "c": Utf8}) + + @dataclass(kw_only=True) + class Row: + a: int | None = None + b: float | None = None + c: str | None = None + + row = Row(**data) + result = append_dataclass(df, row) + height = 0 if (row.a is None) and (row.b is None) and (row.c is None) else 1 + check_polars_dataframe(result, height=height, schema_list=df.schema) + @given(data=fixed_dictionaries({"a": int64s() | none(), "b": floats() | none()})) - def test_main(self, *, data: StrMapping) -> None: - schema = {"a": Int64, "b": Float64, "c": Utf8} - df = DataFrame([], schema=schema, orient="row") + def test_extra_column(self, *, data: StrMapping) -> None: + df = DataFrame(schema={"a": Int64, "b": Float64, "c": Utf8}) @dataclass(kw_only=True) class Row: @@ -105,12 +125,27 @@ class Row: row = Row(**data) result = append_dataclass(df, row) - check_polars_dataframe(result, height=1, schema_list=schema) + height = 0 if (row.a is None) and (row.b is None) else 1 + check_polars_dataframe(result, height=height, schema_list=df.schema) + + @given(data=fixed_dictionaries({"a": int64s() | none(), "b": floats() | none()})) + def test_extra_field_but_none(self, *, data: StrMapping) -> None: + df = DataFrame(schema={"a": Int64, "b": Float64}) + + @dataclass(kw_only=True) + class Row: + a: int | None = None + b: float | None = None + c: str | None = None + + row = Row(**data) + result = append_dataclass(df, row) + height = 0 if (row.a is None) and (row.b is None) else 1 + check_polars_dataframe(result, height=height, schema_list=df.schema) @given(data=fixed_dictionaries({"datetime": zoned_datetimes()})) def test_zoned_datetime(self, *, data: StrMapping) -> None: - schema = {"datetime": DatetimeUTC} - df = DataFrame([], schema=schema, orient="row") + df = DataFrame(schema={"datetime": DatetimeUTC}) @dataclass(kw_only=True) class Row: @@ -118,23 +153,23 @@ class Row: row = Row(**data) result = append_dataclass(df, row) - check_polars_dataframe(result, height=1, schema_list=schema) + check_polars_dataframe(result, height=1, schema_list=df.schema) @given( data=fixed_dictionaries({ "a": int64s() | none(), "b": floats() | none(), - "c": text_ascii() | none(), + "c": text_ascii(), }) ) def test_error(self, *, data: StrMapping) -> None: - df = DataFrame([], schema={"a": Int64, "b": Float64}, orient="row") + df = DataFrame(schema={"a": Int64, "b": Float64}) @dataclass(kw_only=True) class Row: a: int | None = None b: float | None = None - c: str | None = None + c: str row = Row(**data) with raises( diff --git a/src/utilities/__init__.py b/src/utilities/__init__.py index c28d54135..c3dc96e8e 100644 --- a/src/utilities/__init__.py +++ b/src/utilities/__init__.py @@ -1,3 +1,3 @@ from __future__ import annotations -__version__ = "0.57.3" +__version__ = "0.57.4" diff --git a/src/utilities/polars.py b/src/utilities/polars.py index 87b7a410c..09de91281 100644 --- a/src/utilities/polars.py +++ b/src/utilities/polars.py @@ -5,7 +5,7 @@ from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from collections.abc import Set as AbstractSet from contextlib import suppress -from dataclasses import dataclass +from dataclasses import asdict, dataclass from datetime import timezone from enum import Enum from functools import partial, reduce @@ -57,7 +57,7 @@ from polars.testing import assert_frame_equal from typing_extensions import override -from utilities.dataclasses import Dataclass, is_dataclass_class, yield_field_names +from utilities.dataclasses import Dataclass, is_dataclass_class from utilities.errors import redirect_error from utilities.iterables import ( CheckIterablesEqualError, @@ -100,14 +100,15 @@ def append_dataclass(df: DataFrame, obj: Dataclass, /) -> DataFrame: """Append a dataclass object to a DataFrame.""" - fields = yield_field_names(obj) + non_null_fields = {k: v for k, v in asdict(obj).items() if v is not None} try: - check_subset(fields, df.columns) + check_subset(non_null_fields, df.columns) except CheckSubSetError as error: raise AppendDataClassError( left=error.left, right=error.right, extra=error.extra ) from None - row = dataclass_to_row(obj) + row_cols = set(df.columns) & set(non_null_fields) + row = dataclass_to_row(obj).select(*row_cols) return concat([df, row], how="diagonal")