From a53c3e300c7994e61179e7e7cde0ee9ec38941c1 Mon Sep 17 00:00:00 2001 From: Bernhard Vogginger Date: Thu, 15 Aug 2024 16:59:17 +0200 Subject: [PATCH] Add metadata to missing nodes Also fix metadata read write tests --- nir/ir/conv.py | 1 + nir/ir/graph.py | 2 ++ nir/ir/linear.py | 2 ++ nir/ir/pooling.py | 1 + pyproject.toml | 2 +- tests/test_readwrite.py | 20 +++++++++++--------- 6 files changed, 18 insertions(+), 10 deletions(-) diff --git a/nir/ir/conv.py b/nir/ir/conv.py index ce7c452..3ba3726 100644 --- a/nir/ir/conv.py +++ b/nir/ir/conv.py @@ -106,6 +106,7 @@ class Conv2d(NIRNode): dilation: Union[int, Tuple[int, int]] # Dilation groups: int # Groups bias: np.ndarray # Bias C_out + metadata: Dict[str, Any] = field(default_factory=dict) def __post_init__(self): if isinstance(self.padding, str) and self.padding not in ["same", "valid"]: diff --git a/nir/ir/graph.py b/nir/ir/graph.py index f4315c8..47eba09 100644 --- a/nir/ir/graph.py +++ b/nir/ir/graph.py @@ -452,6 +452,7 @@ class Input(NIRNode): # Shape of incoming data (overrrides input_type from # NIRNode to allow for non-keyword (positional) initialization) input_type: Types + metadata: Dict[str, Any] = field(default_factory=dict) def __post_init__(self): self.input_type = parse_shape_argument(self.input_type, "input") @@ -479,6 +480,7 @@ class Output(NIRNode): # Type of incoming data (overrrides input_type from # NIRNode to allow for non-keyword (positional) initialization) output_type: Types + metadata: Dict[str, Any] = field(default_factory=dict) def __post_init__(self): self.output_type = parse_shape_argument(self.output_type, "output") diff --git a/nir/ir/linear.py b/nir/ir/linear.py index ccd45bd..1353f62 100644 --- a/nir/ir/linear.py +++ b/nir/ir/linear.py @@ -46,6 +46,7 @@ class Linear(NIRNode): """ weight: np.ndarray # Weight term + metadata: Dict[str, Any] = field(default_factory=dict) def __post_init__(self): assert len(self.weight.shape) >= 2, "Weight must be at least 2D" @@ -69,6 +70,7 @@ class Scale(NIRNode): """ scale: np.ndarray # Scaling factor + metadata: Dict[str, Any] = field(default_factory=dict) def __post_init__(self): self.input_type = {"input": np.array(self.scale.shape)} diff --git a/nir/ir/pooling.py b/nir/ir/pooling.py index 2f2ae47..4797320 100644 --- a/nir/ir/pooling.py +++ b/nir/ir/pooling.py @@ -29,6 +29,7 @@ class AvgPool2d(NIRNode): kernel_size: np.ndarray # (Height, Width) stride: np.ndarray # (Height, width) padding: np.ndarray # (Height, width) + metadata: Dict[str, Any] = field(default_factory=dict) def __post_init__(self): self.input_type = {"input": None} diff --git a/pyproject.toml b/pyproject.toml index e1865a5..90be725 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,5 +47,5 @@ find={include = ["nir*"]} [tool.ruff] line-length = 100 -lint.per-file-ignores = {"docs/conf.py" = ["E402"]} +per-file-ignores = {"docs/conf.py" = ["E402"]} exclude = ["paper/"] diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index 9410ee9..ca68e28 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -1,11 +1,11 @@ -import tempfile import inspect import sys +import tempfile import numpy as np import nir -from tests import mock_affine, mock_conv +from tests import mock_affine, mock_conv, mock_linear ALL_NODES = [] for name, obj in inspect.getmembers(sys.modules["nir.ir"]): @@ -47,7 +47,7 @@ def factory_test_graph(ir: nir.NIRGraph): assert_equivalence(ir, ir2) -def factory_test_metadata(node): +def factory_test_metadata(ir: nir.NIRGraph): def compare_dicts(d1, d2): for k, v in d1.items(): if isinstance(v, np.ndarray): @@ -58,12 +58,14 @@ def compare_dicts(d1, d2): assert v == d2[k] metadata = {"some": "metadata", "with": 2, "data": np.array([1, 2, 3])} - node.metadata = metadata - compare_dicts(node.metadata, metadata) + for node in ir.nodes.values(): + node.metadata = metadata + compare_dicts(node.metadata, metadata) tmp = tempfile.mktemp() - nir.write(tmp, node) - node2 = nir.read(tmp) - compare_dicts(node2.metadata, metadata) + nir.write(tmp, ir) + ir2 = nir.read(tmp) + for node in ir2.nodes.values(): + compare_dicts(node.metadata, metadata) def test_simple(): @@ -146,7 +148,7 @@ def test_linear(): tau = np.array([1, 1, 1]) r = np.array([1, 1, 1]) v_leak = np.array([1, 1, 1]) - ir = nir.NIRGraph.from_list(mock_affine(2, 2), nir.LI(tau, r, v_leak)) + ir = nir.NIRGraph.from_list(mock_linear(2, 2), nir.LI(tau, r, v_leak)) factory_test_graph(ir) factory_test_metadata(ir)