diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index 2e9486b68..36e5c77a9 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -439,12 +439,16 @@ def split_to_sequence(node: ir.Node, op, state: OptimizerState) -> ReturnValue: else: return None + # If Split returns a single value, we need to wrap it into a list. + if isinstance(split_values, ir.Value): + split_values = [split_values] + keepdims = _get_int_attribute(node, "keepdims", 1) if keepdims is None: return None if keepdims == 0: # squeeze the split dimension if keepdims is 0 - axis_val = op.Constant(value_int=axis, _outputs=[f"{output.name}_axis"]) + axis_val = op.Constant(value_ints=[axis], _outputs=[f"{output.name}_axis"]) squeezed_values = [] for i in range(num_outputs): squeezed = op.Squeeze(