diff --git a/src/qonnx/custom_op/registry.py b/src/qonnx/custom_op/registry.py index 3540bb5a..cefc0973 100644 --- a/src/qonnx/custom_op/registry.py +++ b/src/qonnx/custom_op/registry.py @@ -28,7 +28,24 @@ import importlib -from qonnx.util.basic import get_preferred_onnx_opset +_QONNX_DOMAINS = ["finn", "qonnx.custom_op", "onnx.brevitas"] + + +def register_custom_domain(domain: str): + _QONNX_DOMAINS.append(domain) + + +def is_finn_op(op_type): + "Return whether given op_type string is a QONNX or FINN custom op" + is_finn = False + for domain in _QONNX_DOMAINS: + is_finn = is_finn or op_type.startswith(domain) + return is_finn + + +def get_preferred_onnx_opset(): + "Return preferred ONNX opset version for QONNX" + return 11 def getCustomOp(node, onnx_opset_version=get_preferred_onnx_opset(), brevitas_exception=True): diff --git a/src/qonnx/util/basic.py b/src/qonnx/util/basic.py index 363aa501..d91844da 100644 --- a/src/qonnx/util/basic.py +++ b/src/qonnx/util/basic.py @@ -32,6 +32,9 @@ import string import warnings +import qonnx +import qonnx.custom_op +import qonnx.custom_op.registry from qonnx.core.datatype import DataType # TODO solve by moving onnx-dependent fxns to onnx.py @@ -48,7 +51,7 @@ def get_preferred_onnx_opset(): "Return preferred ONNX opset version for QONNX" - return 11 + return qonnx.custom_op.registry.get_preferred_onnx_opset() def qonnx_make_model(graph_proto, **kwargs): @@ -63,8 +66,7 @@ def qonnx_make_model(graph_proto, **kwargs): def is_finn_op(op_type): - "Return whether given op_type string is a QONNX or FINN custom op" - return op_type.startswith("finn") or op_type.startswith("qonnx.custom_op") or op_type.startswith("onnx.brevitas") + return qonnx.custom_op.registry.is_finn_op(op_type) def get_num_default_workers():