diff --git a/src/brevitas/export/onnx/debug.py b/src/brevitas/export/onnx/debug.py index 85acba31b..3d5c236ac 100644 --- a/src/brevitas/export/onnx/debug.py +++ b/src/brevitas/export/onnx/debug.py @@ -14,6 +14,7 @@ class DebugMarkerFunction(Function): @staticmethod def symbolic(g, input, export_debug_name): ret = g.op('brevitas.onnx::DebugMarker', input, export_debug_name_s=export_debug_name) + ret.setType(input.type()) return ret @staticmethod