diff --git a/src/tests/test_polars.py b/src/tests/test_polars.py index 820ca8725..0d7911adf 100644 --- a/src/tests/test_polars.py +++ b/src/tests/test_polars.py @@ -7,7 +7,7 @@ from zoneinfo import ZoneInfo from hypothesis import given -from hypothesis.strategies import fixed_dictionaries, floats, none, sampled_from +from hypothesis.strategies import dates, fixed_dictionaries, floats, none, sampled_from from polars import ( Boolean, DataFrame, @@ -58,6 +58,7 @@ collect_series, columns_to_dict, convert_time_zone, + dataclass_to_row, drop_null_struct_series, ensure_expr_or_series, floor_datetime, @@ -548,6 +549,27 @@ def test_dataframe_nested_twice(self) -> None: assert_frame_equal(result, expected) +class TestDataClassToRow: + @given(data=fixed_dictionaries({"a": int64s() | none(), "b": floats() | none()})) + def test_basic_types(self, *, data: StrMapping) -> None: + @dataclass(kw_only=True) + class Row: + a: int | None = None + b: float | None = None + + df = dataclass_to_row(Row(**data)) + check_polars_dataframe(df, height=1, schema_list={"a": Int64, "b": Float64}) + + @given(data=fixed_dictionaries({"date": dates()})) + def test_date(self, *, data: StrMapping) -> None: + @dataclass(kw_only=True) + class Row: + date: dt.date | None = None + + df = dataclass_to_row(Row(**data)) + check_polars_dataframe(df, height=1, schema_list={"date": Date}) + + class TestDatetimeUTC: @mark.parametrize( ("dtype", "time_zone"), diff --git a/src/utilities/polars.py b/src/utilities/polars.py index 4e4115daf..c0c7c9a9f 100644 --- a/src/utilities/polars.py +++ b/src/utilities/polars.py @@ -506,6 +506,11 @@ def __str__(self) -> str: return f"DataFrame must be unique on {self.key!r}\n\n{self.df}" +def dataclass_to_row(obj: Dataclass, /) -> DataFrame: + """Convert a dataclass into a 1-row DataFrame.""" + return DataFrame([obj], orient="row") + + def drop_null_struct_series(series: Series, /) -> Series: """Drop nulls in a struct-dtype Series as per the <= 1.1 definition.""" try: