From 69582e97a8400c5121debbfc00a0ca61076f2ff4 Mon Sep 17 00:00:00 2001 From: AllenTiTaiWang Date: Thu, 13 Apr 2023 18:05:07 +0000 Subject: [PATCH] Improve unflatten | feat(atenlib) --- onnxscript/function_libs/torch_aten/ops/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxscript/function_libs/torch_aten/ops/core.py b/onnxscript/function_libs/torch_aten/ops/core.py index 1ba3f66f6..e9802c39c 100644 --- a/onnxscript/function_libs/torch_aten/ops/core.py +++ b/onnxscript/function_libs/torch_aten/ops/core.py @@ -5877,10 +5877,10 @@ def aten_unflatten(self: TReal, dim: INT64, sizes: INT64): self_size = op.Shape(self) - if dim < 0: - # PyTorch accepts negative dim as reversed counting - self_rank = op.Size(self_size) - dim = self_rank + dim + # PyTorch accepts negative dim as reversed counting + self_rank = op.Size(self_size) + dim = self_rank + dim + dim = dim % self_rank head_start_idx = op.Constant(value_ints=[0]) head_end_idx = op.Reshape(dim, op.Constant(value_ints=[1]))