Skip to content

Commit

Permalink
Add append_dataclass (#745)
Browse files Browse the repository at this point in the history
  • Loading branch information
dycw authored Sep 18, 2024
1 parent 50f3772 commit 1399499
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 3 deletions.
45 changes: 44 additions & 1 deletion src/tests/test_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from zoneinfo import ZoneInfo

from hypothesis import given
from hypothesis.strategies import sampled_from
from hypothesis.strategies import fixed_dictionaries, floats, none, sampled_from
from polars import (
Boolean,
DataFrame,
Expand All @@ -29,7 +29,9 @@
from polars.testing import assert_frame_equal, assert_series_equal
from pytest import mark, param, raises

from utilities.hypothesis import int64s, text_ascii
from utilities.polars import (
AppendDataClassError,
CheckPolarsDataFrameError,
CheckZonedDTypeOrSeriesError,
ColumnsToDictError,
Expand All @@ -49,6 +51,7 @@
_check_polars_dataframe_schema_set,
_check_polars_dataframe_schema_subset,
_yield_struct_series_element_remove_nulls,
append_dataclass,
ceil_datetime,
check_polars_dataframe,
check_zoned_dtype_or_series,
Expand All @@ -70,6 +73,7 @@
yield_struct_series_elements,
zoned_datetime,
)
from utilities.types import StrMapping
from utilities.zoneinfo import (
UTC,
HongKong,
Expand All @@ -80,6 +84,45 @@
)


class TestAppendDataClass:
    """Property-based tests for `append_dataclass`."""

    @given(data=fixed_dictionaries({"a": int64s() | none(), "b": floats() | none()}))
    def test_main(self, *, data: StrMapping) -> None:
        """A dataclass whose fields are all DataFrame columns appends as one row."""

        @dataclass(kw_only=True)
        class Row:
            a: int | None = None
            b: float | None = None

        # Column "c" has no counterpart on `Row`; the final check still expects
        # the complete three-column schema on the result.
        expected_schema = {"a": Int64, "b": Float64, "c": Utf8}
        empty = DataFrame([], schema=expected_schema, orient="row")
        appended = append_dataclass(empty, Row(**data))
        check_polars_dataframe(appended, height=1, schema_list=expected_schema)

    @given(
        data=fixed_dictionaries({
            "a": int64s() | none(),
            "b": floats() | none(),
            "c": text_ascii() | none(),
        })
    )
    def test_error(self, *, data: StrMapping) -> None:
        """A dataclass field that is not a DataFrame column raises the custom error."""

        @dataclass(kw_only=True)
        class Row:
            a: int | None = None
            b: float | None = None
            c: str | None = None  # deliberately absent from the frame below

        frame = DataFrame([], schema={"a": Int64, "b": Float64}, orient="row")
        instance = Row(**data)
        pattern = "Dataclass fields .* must be a subset of DataFrame columns .*; dataclass had extra items .*"
        with raises(AppendDataClassError, match=pattern):
            _ = append_dataclass(frame, instance)


class TestCeilDatetime:
start: ClassVar[dt.datetime] = dt.datetime(2000, 1, 1, 0, 0, tzinfo=UTC)
end: ClassVar[dt.datetime] = dt.datetime(2000, 1, 1, 0, 3, tzinfo=UTC)
Expand Down
2 changes: 1 addition & 1 deletion src/utilities/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from __future__ import annotations

__version__ = "0.57.1"
__version__ = "0.57.2"
33 changes: 32 additions & 1 deletion src/utilities/polars.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import datetime as dt
import reprlib
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from collections.abc import Set as AbstractSet
from contextlib import suppress
Expand All @@ -12,6 +13,7 @@
from typing import (
TYPE_CHECKING,
Any,
Generic,
Literal,
Never,
TypeGuard,
Expand All @@ -36,6 +38,7 @@
Utf8,
all_horizontal,
col,
concat,
lit,
struct,
when,
Expand All @@ -54,15 +57,17 @@
from polars.testing import assert_frame_equal
from typing_extensions import override

from utilities.dataclasses import Dataclass, is_dataclass_class
from utilities.dataclasses import Dataclass, is_dataclass_class, yield_field_names
from utilities.errors import redirect_error
from utilities.iterables import (
CheckIterablesEqualError,
CheckMappingsEqualError,
CheckSubSetError,
CheckSuperMappingError,
MaybeIterable,
check_iterables_equal,
check_mappings_equal,
check_subset,
check_supermapping,
is_iterable_not_str,
one,
Expand All @@ -84,13 +89,38 @@
from zoneinfo import ZoneInfo


# Generic type variable; used as the element type of the sets carried by
# `AppendDataClassError`.
_T = TypeVar("_T")
# Pre-built polars `Datetime` dtypes, each bound to a fixed time zone.
DatetimeHongKong = Datetime(time_zone="Asia/Hong_Kong")
DatetimeTokyo = Datetime(time_zone="Asia/Tokyo")
DatetimeUSCentral = Datetime(time_zone="US/Central")
DatetimeUSEastern = Datetime(time_zone="US/Eastern")
DatetimeUTC = Datetime(time_zone="UTC")


def append_dataclass(df: DataFrame, obj: Dataclass, /) -> DataFrame:
    """Return a new DataFrame consisting of `df` plus one row built from `obj`.

    The dataclass's field names are first validated against `df.columns`;
    fields may cover only part of the columns, but may not introduce new ones.

    Args:
        df: The DataFrame to extend.
        obj: The dataclass instance supplying the new row.

    Returns:
        The diagonal concatenation of `df` and a single-row DataFrame
        constructed from `obj`.

    Raises:
        AppendDataClassError: If `check_subset` reports field names that are
            not among `df.columns`.
    """
    field_names = yield_field_names(obj)
    try:
        check_subset(field_names, df.columns)
    except CheckSubSetError as error:
        # Re-raise as the domain-specific error; `from None` hides the
        # lower-level subset error from the traceback.
        raise AppendDataClassError(
            left=error.left, right=error.right, extra=error.extra
        ) from None
    # A diagonal concat is used so `obj` need not supply every column of `df`.
    new_row = DataFrame([obj], orient="row")
    frames = [df, new_row]
    return concat(frames, how="diagonal")


@dataclass(kw_only=True)
class AppendDataClassError(Exception, Generic[_T]):
    """Raised by `append_dataclass` when the dataclass has fields outside the DataFrame's columns."""

    # The dataclass field names that were checked.
    left: AbstractSet[_T]
    # The DataFrame column names they were checked against.
    right: AbstractSet[_T]
    # The field names found in `left` but not in `right`.
    extra: AbstractSet[_T]

    @override
    def __str__(self) -> str:
        # `reprlib.repr` abbreviates large collections so the message stays short.
        fields, columns, surplus = map(
            reprlib.repr, (self.left, self.right, self.extra)
        )
        return f"Dataclass fields {fields} must be a subset of DataFrame columns {columns}; dataclass had extra items {surplus}"


@overload
def ceil_datetime(column: Expr | str, every: Expr | str, /) -> Expr: ...
@overload
Expand Down Expand Up @@ -863,6 +893,7 @@ def zoned_datetime(
"IsNullStructSeriesError",
"SetFirstRowAsColumnsError",
"YieldStructSeriesElementsError",
"append_dataclass",
"ceil_datetime",
"check_polars_dataframe",
"check_zoned_dtype_or_series",
Expand Down

0 comments on commit 1399499

Please sign in to comment.