Skip to content

Commit

Permalink
append_dataclass now allowed null-valued fields outside the `DataFr…
Browse files Browse the repository at this point in the history
…ame` (#747)
  • Loading branch information
dycw authored Sep 18, 2024
1 parent 7ebbf98 commit b42ec9e
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 16 deletions.
55 changes: 45 additions & 10 deletions src/tests/test_polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,30 @@


class TestAppendDataClass:
@given(
data=fixed_dictionaries({
"a": int64s() | none(),
"b": floats() | none(),
"c": text_ascii() | none(),
})
)
def test_columns_and_fields_equal(self, *, data: StrMapping) -> None:
df = DataFrame(schema={"a": Int64, "b": Float64, "c": Utf8})

@dataclass(kw_only=True)
class Row:
a: int | None = None
b: float | None = None
c: str | None = None

row = Row(**data)
result = append_dataclass(df, row)
height = 0 if (row.a is None) and (row.b is None) and (row.c is None) else 1
check_polars_dataframe(result, height=height, schema_list=df.schema)

@given(data=fixed_dictionaries({"a": int64s() | none(), "b": floats() | none()}))
def test_main(self, *, data: StrMapping) -> None:
schema = {"a": Int64, "b": Float64, "c": Utf8}
df = DataFrame([], schema=schema, orient="row")
def test_extra_column(self, *, data: StrMapping) -> None:
df = DataFrame(schema={"a": Int64, "b": Float64, "c": Utf8})

@dataclass(kw_only=True)
class Row:
Expand All @@ -105,36 +125,51 @@ class Row:

row = Row(**data)
result = append_dataclass(df, row)
check_polars_dataframe(result, height=1, schema_list=schema)
height = 0 if (row.a is None) and (row.b is None) else 1
check_polars_dataframe(result, height=height, schema_list=df.schema)

@given(data=fixed_dictionaries({"a": int64s() | none(), "b": floats() | none()}))
def test_extra_field_but_none(self, *, data: StrMapping) -> None:
df = DataFrame(schema={"a": Int64, "b": Float64})

@dataclass(kw_only=True)
class Row:
a: int | None = None
b: float | None = None
c: str | None = None

row = Row(**data)
result = append_dataclass(df, row)
height = 0 if (row.a is None) and (row.b is None) else 1
check_polars_dataframe(result, height=height, schema_list=df.schema)

@given(data=fixed_dictionaries({"datetime": zoned_datetimes()}))
def test_zoned_datetime(self, *, data: StrMapping) -> None:
schema = {"datetime": DatetimeUTC}
df = DataFrame([], schema=schema, orient="row")
df = DataFrame(schema={"datetime": DatetimeUTC})

@dataclass(kw_only=True)
class Row:
datetime: dt.datetime

row = Row(**data)
result = append_dataclass(df, row)
check_polars_dataframe(result, height=1, schema_list=schema)
check_polars_dataframe(result, height=1, schema_list=df.schema)

@given(
data=fixed_dictionaries({
"a": int64s() | none(),
"b": floats() | none(),
"c": text_ascii() | none(),
"c": text_ascii(),
})
)
def test_error(self, *, data: StrMapping) -> None:
df = DataFrame([], schema={"a": Int64, "b": Float64}, orient="row")
df = DataFrame(schema={"a": Int64, "b": Float64})

@dataclass(kw_only=True)
class Row:
a: int | None = None
b: float | None = None
c: str | None = None
c: str

row = Row(**data)
with raises(
Expand Down
2 changes: 1 addition & 1 deletion src/utilities/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from __future__ import annotations

__version__ = "0.57.3"
__version__ = "0.57.4"
11 changes: 6 additions & 5 deletions src/utilities/polars.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
from collections.abc import Set as AbstractSet
from contextlib import suppress
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from datetime import timezone
from enum import Enum
from functools import partial, reduce
Expand Down Expand Up @@ -57,7 +57,7 @@
from polars.testing import assert_frame_equal
from typing_extensions import override

from utilities.dataclasses import Dataclass, is_dataclass_class, yield_field_names
from utilities.dataclasses import Dataclass, is_dataclass_class
from utilities.errors import redirect_error
from utilities.iterables import (
CheckIterablesEqualError,
Expand Down Expand Up @@ -100,14 +100,15 @@

def append_dataclass(df: DataFrame, obj: Dataclass, /) -> DataFrame:
"""Append a dataclass object to a DataFrame."""
fields = yield_field_names(obj)
non_null_fields = {k: v for k, v in asdict(obj).items() if v is not None}
try:
check_subset(fields, df.columns)
check_subset(non_null_fields, df.columns)
except CheckSubSetError as error:
raise AppendDataClassError(
left=error.left, right=error.right, extra=error.extra
) from None
row = dataclass_to_row(obj)
row_cols = set(df.columns) & set(non_null_fields)
row = dataclass_to_row(obj).select(*row_cols)
return concat([df, row], how="diagonal")


Expand Down

0 comments on commit b42ec9e

Please sign in to comment.