From 171d038461179ccffeadd32359a2f62511b0570d Mon Sep 17 00:00:00 2001 From: jvreca Date: Fri, 13 Sep 2024 10:25:50 +0200 Subject: [PATCH 1/2] Allow for custom ops in registry. --- src/qonnx/custom_op/registry.py | 14 ++++++++++++++ src/qonnx/util/basic.py | 6 ++++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index 3540bb5a..65b8b2ba 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -30,6 +30,20 @@ from qonnx.util.basic import get_preferred_onnx_opset +_QONNX_DOMAINS = ["finn", "qonnx.custom_op", "onnx.brevitas"] + + +def register_custom_domain(domain: str): + _QONNX_DOMAINS.append(domain) + + +def is_finn_op(op_type): + "Return whether given op_type string is a QONNX or FINN custom op" + is_finn = False + for domain in _QONNX_DOMAINS: + is_finn = is_finn or op_type.startswith(domain) + return is_finn + def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True): "Return a QONNX CustomOp instance for the given ONNX node, if it exists." diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 363aa501..7c1560ed 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -32,6 +32,9 @@ import string import warnings +import qonnx +import qonnx.custom_op +import qonnx.custom_op.registry from qonnx.core.datatype import DataType # TODO solve by moving onnx-dependent fxns to onnx.py @@ -63,8 +66,7 @@ def qonnx_make_model(graph_proto, **kwargs): def is_finn_op(op_type): - "Return whether given op_type string is a QONNX or FINN custom op" - return op_type.startswith("finn") or op_type.startswith("qonnx.custom_op") or op_type.startswith("onnx.brevitas") + return qonnx.custom_op.registry.is_finn_op(op_type) def get_num_default_workers(): From c84226ffada1088cc14ad022ccda262f70eaae05 Mon Sep 17 00:00:00 2001 From: jvreca Date: Fri, 13 Sep 2024 15:46:34 +0200 Subject: [PATCH 2/2] moved function get_preferred_onnx_opset to avoid circular commit --- src/qonnx/custom_op/registry.py | 7 +++++-- src/qonnx/util/basic.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index 65b8b2ba..cefc0973 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -28,8 +28,6 @@ import importlib -from qonnx.util.basic import get_preferred_onnx_opset - _QONNX_DOMAINS = ["finn", "qonnx.custom_op", "onnx.brevitas"] @@ -45,6 +43,11 @@ def is_finn_op(op_type): return is_finn +def get_preferred_onnx_opset(): + "Return preferred ONNX opset version for QONNX" + return 11 + + def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True): "Return a QONNX CustomOp instance for the given ONNX node, if it exists." op_type = node.op_type diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 7c1560ed..d91844da 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -51,7 +51,7 @@ def get_preferred_onnx_opset(): "Return preferred ONNX opset version for QONNX" - return 11 + return qonnx.custom_op.registry.get_preferred_onnx_opset() def qonnx_make_model(graph_proto, **kwargs):