Skip to content

Commit

Permalink
Improve unflatten | feat(atenlib) (#628)
Browse files Browse the repository at this point in the history
Remove If subgraph
  • Loading branch information
titaiwangms authored Apr 13, 2023
1 parent d49dbfc commit a9fc6ca
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions onnxscript/function_libs/torch_aten/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5877,10 +5877,10 @@ def aten_unflatten(self: TReal, dim: INT64, sizes: INT64):

self_size = op.Shape(self)

if dim < 0:
# PyTorch accepts negative dim as reversed counting
self_rank = op.Size(self_size)
dim = self_rank + dim
# PyTorch accepts negative dim as reversed counting
self_rank = op.Size(self_size)
dim = self_rank + dim
dim = dim % self_rank

head_start_idx = op.Constant(value_ints=[0])
head_end_idx = op.Reshape(dim, op.Constant(value_ints=[1]))
Expand Down

0 comments on commit a9fc6ca

Please sign in to comment.