Skip to content

Commit

Permalink
Add the functionality for removing empty fields from serialized datac…
Browse files Browse the repository at this point in the history
…lasses
  • Loading branch information
mszalkowski-ant committed Sep 11, 2024
1 parent 62be4b4 commit 59aeddc
Show file tree
Hide file tree
Showing 10 changed files with 424 additions and 165 deletions.
2 changes: 1 addition & 1 deletion tests/tests_build/test_hier_design.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def hier_design_path(hier_design_yaml: Path) -> Path:

@pytest.fixture
def hier_design(hier_design_yaml: Path) -> DesignDescription:
return DesignDescription.from_file(hier_design_yaml)
return DesignDescription.load(hier_design_yaml)


@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion tests/tests_kpm/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ def read_json_file(json_file_path: str) -> dict:


def read_yaml_file(yaml_file_path: str) -> DesignDescription:
return DesignDescription.from_file(yaml_file_path)
return DesignDescription.load(yaml_file_path)
139 changes: 137 additions & 2 deletions tests/tests_parse/test_common_serdes.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from typing import Any, List
from dataclasses import fields
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Union

import marshmallow_dataclass
import pytest

from topwrap.common_serdes import (
AnnotatedFlatTree,
FlatTree,
MarshmallowDataclassExtensions,
NestedDict,
annotate_flat_tree,
ext_field,
flatten_tree,
unflatten_annotated_tree,
)
Expand Down Expand Up @@ -247,10 +253,139 @@ def test_unflatten_flattened_identity(
== tree_samelength
)

def test_unflatten_annotated_unsorted_tree(self, annot_tree_unsorted_order):
def test_unflatten_annotated_unsorted_tree(
self, annot_tree_unsorted_order: AnnotatedFlatTree[str, Any]
):
assert unflatten_annotated_tree(
annot_tree_unsorted_order, ["direction", "type", "name", "width"], sort=True
) == {
"out": {"required": {"data_out": 16, "valid": 1}},
"in": {"optional": {"data_in": 32}},
}


@marshmallow_dataclass.dataclass
class DummyDataclass(MarshmallowDataclassExtensions):
str_f: str
int_f: int
deep_f: Dict[Any, Any] = ext_field(dict)


class TestDataclassExtensions:
def test_no_cleanup_regular(self):
@marshmallow_dataclass.dataclass
class TestDataclass(MarshmallowDataclassExtensions):
opt_field1: List[int] = ext_field(list, self_cleanup=False)
opt_field2: List[str] = ext_field(list, self_cleanup=False)
opt_field3: Dict[str, List[int]] = ext_field(dict, self_cleanup=False)

serial = TestDataclass(opt_field1=[1, 2, 3, 4]).to_dict()

assert "opt_field1" in serial and "opt_field2" in serial and "opt_field3" in serial

def test_cleanup_regular(self):
@marshmallow_dataclass.dataclass
class TestDataclass(MarshmallowDataclassExtensions):
req_field: List[int]
opt_field: List[str] = ext_field(list)
opt_field2: List[str] = ext_field(list)
opt_field3: List[str] = ext_field(list)
opt_field4: Dict[int, List[str]] = ext_field(dict)

serial = TestDataclass(req_field=[], opt_field2=[], opt_field3=["bar"]).to_dict()

assert "req_field" in serial and "opt_field3" in serial
assert (
"opt_field" not in serial and "opt_field2" not in serial and "opt_field4" not in serial
)

def test_cleanup_deep_cleanup(self):
@marshmallow_dataclass.dataclass
class TestDataclass(MarshmallowDataclassExtensions):
req_field_regular: Dict[str, List[int]]
req_field_list: List[Any] = ext_field(deep_cleanup=True)
req_field_deepclean: Dict[str, Dict[str, List[Any]]] = ext_field(deep_cleanup=True)
opt_field_deepclean: Dict[str, Dict[str, List[Any]]] = ext_field(
dict, deep_cleanup=True
)
opt_field_deep_seq: List[Any] = ext_field(list, deep_cleanup=True)

deep_dict = {
"shallow_empty": {},
"deep_full": {"empty": [], "not_empty": ["foo", {}]},
"deep_empty": {"empty": [{}], "emptier": [{"a": []}]},
}
serial = TestDataclass(
req_field_regular={"empty": []},
req_field_list=["asdf", {}, None],
req_field_deepclean=deep_dict,
opt_field_deepclean=deep_dict,
opt_field_deep_seq=["item", ["item", {}, ["item", [], {"item", tuple()}, set()]]],
).to_dict()

for fld in fields(TestDataclass):
assert fld.name in serial

assert serial["req_field_regular"]["empty"] == []
assert serial["req_field_list"] == ["asdf"]
assert serial["opt_field_deep_seq"] == ["item", ["item", ["item", {"item", tuple()}]]]
for fld in ("req_field_deepclean", "opt_field_deepclean"):
assert "shallow_empty" not in serial[fld]
assert "deep_empty" not in serial[fld]
assert serial[fld]["deep_full"] == {"not_empty": ["foo"]}

def test_noops(self):
@marshmallow_dataclass.dataclass
class Inner:
fld: List[Any] = ext_field(deep_cleanup=True)

@marshmallow_dataclass.dataclass
class TestDataclass(MarshmallowDataclassExtensions):
req_self: Dict[Any, Any] = ext_field(self_cleanup=True)
req_nested: Inner = ext_field(deep_cleanup=True)

serial = TestDataclass(req_self={}, req_nested=Inner(fld=[])).to_dict()

assert serial["req_nested"] == {"fld": []}
assert serial["req_self"] == {}

def test_cleanup_no_simple_type_erasure(self):
@marshmallow_dataclass.dataclass
class TestDataclass(MarshmallowDataclassExtensions):
opt_int_field: int = ext_field(0)
opt_str_field: str = ext_field("")
opt_float_field: float = ext_field(0.0)
opt_bool_field: bool = ext_field(False)
nested: Dict[str, Union[int, str, float, bool]] = ext_field(dict, deep_cleanup=True)

nested = {"int": 0, "str": "", "float": 0.0, "bool": False}
serial = TestDataclass(nested=nested).to_dict()

for field in fields(TestDataclass):
assert field.name in serial

for field in nested:
assert field in serial["nested"]

@pytest.fixture
def dummy_instance(self):
return DummyDataclass(str_f="a", int_f=3, deep_f={"bar": "a"})

def test_dict_methods(self, dummy_instance: DummyDataclass):
data = {"str_f": "a", "int_f": 3, "deep_f": {"bar": "a"}}

assert DummyDataclass.from_dict(data) == dummy_instance
assert dummy_instance.to_dict() == data

def test_yaml_methods(self, dummy_instance: DummyDataclass):
data = "{deep_f: {bar: a}, int_f: 3, str_f: a}\n"

assert DummyDataclass.from_yaml(data) == dummy_instance
assert dummy_instance.to_yaml(default_flow_style=True) == data

def test_file_methods(self, dummy_instance: DummyDataclass):
with TemporaryDirectory() as tmpdir:
path = Path(tmpdir) / "out.yaml"
dummy_instance.save(path)

assert DummyDataclass.load(path) == dummy_instance
12 changes: 6 additions & 6 deletions tests/tests_parse/test_ip_desc.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_invalid_interface_type(self, invalid_interface_type_core):
with pytest.raises(
ValidationError, match="'Invalid interface type: IDONTEXIST'"
) and pytest.raises(ValidationError, match="'Must be one of: master,"):
IPCoreDescription.Schema().load(invalid_interface_type_core)
IPCoreDescription.from_dict(invalid_interface_type_core)

def test_invalid_interface_compliance(
self, invalid_interface_compliance_core, force_compliance
Expand All @@ -116,15 +116,15 @@ def test_invalid_interface_compliance(
) and pytest.raises(
ValidationError, match='Unknown out port "TBUBU", not present in interface "AXI4Stream"'
):
IPCoreDescription.Schema().load(invalid_interface_compliance_core)
IPCoreDescription.from_dict(invalid_interface_compliance_core)

def test_optional_signal_missing_compliance(
self, optional_missing_interface_compliance_core, force_compliance
):
IPCoreDescription.Schema().load(optional_missing_interface_compliance_core)
IPCoreDescription.from_dict(optional_missing_interface_compliance_core)

def test_interface_compliance_off(self, invalid_interface_compliance_core):
IPCoreDescription.Schema().load(invalid_interface_compliance_core)
IPCoreDescription.from_dict(invalid_interface_compliance_core)

def test_builtins_presence(self):
for ip in (
Expand Down Expand Up @@ -251,12 +251,12 @@ def deep_normalize(err):
return norm

try:
IPCoreDescription.Schema().load(completely_invalid_core)
IPCoreDescription.from_dict(completely_invalid_core)
except ValidationError as e:
assert DeepDiff(deep_normalize(e), deep_normalize(EXPECTED)) == {}

def test_valid_syntax(self, completely_valid_core, force_compliance):
ip: IPCoreDescription = IPCoreDescription.Schema().load(completely_valid_core)
ip: IPCoreDescription = IPCoreDescription.from_dict(completely_valid_core)

assert ip == IPCoreDescription(
name="correct_core",
Expand Down
Loading

0 comments on commit 59aeddc

Please sign in to comment.