Skip to content

Commit

Permalink
hi
Browse files Browse the repository at this point in the history
  • Loading branch information
dycw committed Sep 18, 2024
1 parent 8f61921 commit 3087f01
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
24 changes: 23 additions & 1 deletion src/tests/test_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from zoneinfo import ZoneInfo

from hypothesis import given
from hypothesis.strategies import fixed_dictionaries, floats, none, sampled_from
from hypothesis.strategies import dates, fixed_dictionaries, floats, none, sampled_from
from polars import (
Boolean,
DataFrame,
Expand Down Expand Up @@ -58,6 +58,7 @@
collect_series,
columns_to_dict,
convert_time_zone,
dataclass_to_row,
drop_null_struct_series,
ensure_expr_or_series,
floor_datetime,
Expand Down Expand Up @@ -548,6 +549,27 @@ def test_dataframe_nested_twice(self) -> None:
assert_frame_equal(result, expected)


class TestDataClassToRow:
@given(data=fixed_dictionaries({"a": int64s() | none(), "b": floats() | none()}))
def test_basic_types(self, *, data: StrMapping) -> None:
@dataclass(kw_only=True)
class Row:
a: int | None = None
b: float | None = None

df = dataclass_to_row(Row(**data))
check_polars_dataframe(df, height=1, schema_list={"a": Int64, "b": Float64})

@given(data=fixed_dictionaries({"date": dates()}))
def test_date(self, *, data: StrMapping) -> None:
@dataclass(kw_only=True)
class Row:
date: dt.date | None = None

df = dataclass_to_row(Row(**data))
check_polars_dataframe(df, height=1, schema_list={"date": Date})


class TestDatetimeUTC:
@mark.parametrize(
("dtype", "time_zone"),
Expand Down
5 changes: 5 additions & 0 deletions src/utilities/polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,11 @@ def __str__(self) -> str:
return f"DataFrame must be unique on {self.key!r}\n\n{self.df}"


def dataclass_to_row(obj: Dataclass, /) -> DataFrame:
"""Convert a dataclass into a 1-row DataFrame."""
return DataFrame([obj], orient="row")


def drop_null_struct_series(series: Series, /) -> Series:
"""Drop nulls in a struct-dtype Series as per the <= 1.1 definition."""
try:
Expand Down

0 comments on commit 3087f01

Please sign in to comment.