diff --git a/src/tests/test_polars.py b/src/tests/test_polars.py index 357cb91e3..820ca8725 100644 --- a/src/tests/test_polars.py +++ b/src/tests/test_polars.py @@ -7,7 +7,7 @@ from zoneinfo import ZoneInfo from hypothesis import given -from hypothesis.strategies import sampled_from +from hypothesis.strategies import fixed_dictionaries, floats, none, sampled_from from polars import ( Boolean, DataFrame, @@ -29,7 +29,9 @@ from polars.testing import assert_frame_equal, assert_series_equal from pytest import mark, param, raises +from utilities.hypothesis import int64s, text_ascii from utilities.polars import ( + AppendDataClassError, CheckPolarsDataFrameError, CheckZonedDTypeOrSeriesError, ColumnsToDictError, @@ -49,6 +51,7 @@ _check_polars_dataframe_schema_set, _check_polars_dataframe_schema_subset, _yield_struct_series_element_remove_nulls, + append_dataclass, ceil_datetime, check_polars_dataframe, check_zoned_dtype_or_series, @@ -70,6 +73,7 @@ yield_struct_series_elements, zoned_datetime, ) +from utilities.types import StrMapping from utilities.zoneinfo import ( UTC, HongKong, @@ -80,6 +84,45 @@ ) +class TestAppendDataClass: + @given(data=fixed_dictionaries({"a": int64s() | none(), "b": floats() | none()})) + def test_main(self, *, data: StrMapping) -> None: + schema = {"a": Int64, "b": Float64, "c": Utf8} + df = DataFrame([], schema=schema, orient="row") + + @dataclass(kw_only=True) + class Row: + a: int | None = None + b: float | None = None + + row = Row(**data) + result = append_dataclass(df, row) + check_polars_dataframe(result, height=1, schema_list=schema) + + @given( + data=fixed_dictionaries({ + "a": int64s() | none(), + "b": floats() | none(), + "c": text_ascii() | none(), + }) + ) + def test_error(self, *, data: StrMapping) -> None: + df = DataFrame([], schema={"a": Int64, "b": Float64}, orient="row") + + @dataclass(kw_only=True) + class Row: + a: int | None = None + b: float | None = None + c: str | None = None + + row = Row(**data) + with raises( + AppendDataClassError, + match="Dataclass fields .* must be a subset of DataFrame columns .*; dataclass had extra items .*", + ): + _ = append_dataclass(df, row) + + class TestCeilDatetime: start: ClassVar[dt.datetime] = dt.datetime(2000, 1, 1, 0, 0, tzinfo=UTC) end: ClassVar[dt.datetime] = dt.datetime(2000, 1, 1, 0, 3, tzinfo=UTC) diff --git a/src/utilities/__init__.py b/src/utilities/__init__.py index d8f91625e..b5c2e7e31 100644 --- a/src/utilities/__init__.py +++ b/src/utilities/__init__.py @@ -1,3 +1,3 @@ from __future__ import annotations -__version__ = "0.57.1" +__version__ = "0.57.2" diff --git a/src/utilities/polars.py b/src/utilities/polars.py index f7cc0e4ab..4e4115daf 100644 --- a/src/utilities/polars.py +++ b/src/utilities/polars.py @@ -1,6 +1,7 @@ from __future__ import annotations import datetime as dt +import reprlib from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from collections.abc import Set as AbstractSet from contextlib import suppress @@ -12,6 +13,7 @@ from typing import ( TYPE_CHECKING, Any, + Generic, Literal, Never, TypeGuard, @@ -36,6 +38,7 @@ Utf8, all_horizontal, col, + concat, lit, struct, when, @@ -54,15 +57,17 @@ from polars.testing import assert_frame_equal from typing_extensions import override -from utilities.dataclasses import Dataclass, is_dataclass_class +from utilities.dataclasses import Dataclass, is_dataclass_class, yield_field_names from utilities.errors import redirect_error from utilities.iterables import ( CheckIterablesEqualError, CheckMappingsEqualError, + CheckSubSetError, CheckSuperMappingError, MaybeIterable, check_iterables_equal, check_mappings_equal, + check_subset, check_supermapping, is_iterable_not_str, one, @@ -84,6 +89,7 @@ from zoneinfo import ZoneInfo +_T = TypeVar("_T") DatetimeHongKong = Datetime(time_zone="Asia/Hong_Kong") DatetimeTokyo = Datetime(time_zone="Asia/Tokyo") DatetimeUSCentral = Datetime(time_zone="US/Central") @@ -91,6 +97,30 @@ DatetimeUTC = Datetime(time_zone="UTC") +def append_dataclass(df: DataFrame, obj: Dataclass, /) -> DataFrame: + """Append a dataclass object to a DataFrame.""" + fields = yield_field_names(obj) + try: + check_subset(fields, df.columns) + except CheckSubSetError as error: + raise AppendDataClassError( + left=error.left, right=error.right, extra=error.extra + ) from None + row = DataFrame([obj], orient="row") + return concat([df, row], how="diagonal") + + +@dataclass(kw_only=True) +class AppendDataClassError(Exception, Generic[_T]): + left: AbstractSet[_T] + right: AbstractSet[_T] + extra: AbstractSet[_T] + + @override + def __str__(self) -> str: + return f"Dataclass fields {reprlib.repr(self.left)} must be a subset of DataFrame columns {reprlib.repr(self.right)}; dataclass had extra items {reprlib.repr(self.extra)}" + + @overload def ceil_datetime(column: Expr | str, every: Expr | str, /) -> Expr: ... @overload @@ -863,6 +893,7 @@ def zoned_datetime( "IsNullStructSeriesError", "SetFirstRowAsColumnsError", "YieldStructSeriesElementsError", + "append_dataclass", "ceil_datetime", "check_polars_dataframe", "check_zoned_dtype_or_series",