From ba5c41f527bb7fb045bed8842786af2d17ec75e9 Mon Sep 17 00:00:00 2001 From: mdaniowi Date: Thu, 1 Aug 2024 09:30:14 +0100 Subject: [PATCH] strings attr test added to test_attr.py --- tests/custom_op/test_attr.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/tests/custom_op/test_attr.py b/tests/custom_op/test_attr.py index 9db644d7..6e2527ac 100644 --- a/tests/custom_op/test_attr.py +++ b/tests/custom_op/test_attr.py @@ -37,7 +37,11 @@ class AttrTestOp(CustomOp): def get_nodeattr_types(self): - return {"tensor_attr": ("t", True, np.asarray([]))} + my_attrs = { + "tensor_attr": ("t", True, np.asarray([])), + "strings_attr": ("strings", True, [""]) + } + return my_attrs def make_shape_compatible_op(self, model): param_tensor = self.get_nodeattr("tensor_attr") @@ -70,6 +74,7 @@ def test_attr(): strarr = np.array2string(w, separator=", ") w_str = strarr.replace("[", "{").replace("]", "}").replace(" ", "") tensor_attr_str = f"int8{wshp_str} {w_str}" + strings_attr = ["a", "bc", "def"] input = f""" < @@ -86,9 +91,18 @@ def test_attr(): model = oprs.parse_model(input) model = ModelWrapper(model) inst = getCustomOp(model.graph.node[0]) + w_prod = inst.get_nodeattr("tensor_attr") assert (w_prod == w).all() w = w - 1 inst.set_nodeattr("tensor_attr", w) w_prod = inst.get_nodeattr("tensor_attr") assert (w_prod == w).all() + + inst.set_nodeattr("strings_attr", strings_attr) + strings_attr_prod = inst.get_nodeattr("strings_attr") + assert strings_attr_prod == strings_attr + strings_attr_prod[0] = "test" + inst.set_nodeattr("strings_attr", strings_attr_prod) + assert inst.get_nodeattr("strings_attr") == ["test"] + strings_attr[1:] +