Introduce LayerNorm IR ? #14447

Open · seanshpark opened this issue Dec 13, 2024 · 11 comments

@seanshpark (Contributor) commented Dec 13, 2024:

Transformer models (including ViT) have a LayerNorm Op in the graph.
A Circle model converted from ONNX currently gets the decomposed subgraph instead; would it be better to process it as a single node?


Test code to generate the ONNX models:

import onnx
import torch
import torch.nn as nn


class LayerNormNet(nn.Module):
    def __init__(self):
        super().__init__()
        # normalized_shape = (3, 16): normalize over the last two dimensions
        self.ln = nn.LayerNorm((3, 16))

    def forward(self, x):
        out = self.ln(x)
        return out


net = LayerNormNet()
inp = torch.randn(1, 3, 16)

# opset 11 has no LayerNormalization Op, so the exporter emits the decomposed subgraph
torch.onnx.export(net, inp, "ln11.onnx", opset_version=11)
onnx.shape_inference.infer_shapes_path('ln11.onnx', 'ln11-si.onnx')

# opset 17 introduces LayerNormalization as a single Op
torch.onnx.export(net, inp, "ln17.onnx", opset_version=17)
onnx.shape_inference.infer_shapes_path('ln17.onnx', 'ln17-si.onnx')

ONNX graphs:

[images: exported graph with opset=11 (decomposed subgraph) vs. opset=17 (single LayerNormalization node)]
@seanshpark (Contributor, Author) commented:

@Samsung/one_compiler, @Samsung/one_onert, comments?

@glistening (Contributor) commented:

At least for this year, we haven't needed LayerNorm; RMSNorm was used in our target model instead. So from the runtime perspective it is low priority at the moment. However, if it is necessary for the front-end or for some model I am not aware of, please feel free to add it.

@jinevening (Contributor) commented:

@seanshpark Could you check whether the parameters of the LayerNorm in your example code match the real model? normalized_shape is usually the last dimension in language models. I'm asking because the characteristics of the LN operation differ significantly depending on normalized_shape.


If LN is given as a single Op, our backend device may convert it to InstanceNorm. For example:

Before
Input [N, L, D] -> LayerNorm [N, L, D]

After
Input [N, L, D] -> Transpose [N, D, L] -> Reshape [N, 1, D, L] -> InstanceNorm [N, 1, D, L] -> Reshape [N, D, L] -> Transpose [N, L, D]

I'm not against introducing LN (maybe for backends other than the npu), but for now it would also be possible to just use InstanceNorm.
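
For reference, here is a minimal numerical check (my own sketch, not ONE code) that LayerNorm over the last dimension can indeed be expressed through InstanceNorm statistics. PyTorch is NCHW, so the reshaping below differs from the NHWC Transpose/Reshape sequence above, and LN's affine (gamma/beta over D) must be re-applied separately because InstanceNorm's affine is per-channel:

import torch
import torch.nn.functional as F

N, L, D = 1, 4, 8
x = torch.randn(N, L, D)
gamma = torch.randn(D)  # LayerNorm weight
beta = torch.randn(D)   # LayerNorm bias

# reference: LayerNorm with normalized_shape = (D,)
ref = F.layer_norm(x, (D,), weight=gamma, bias=beta)

# InstanceNorm route: treat each of the N*L positions as one "instance"
# whose spatial extent is the D elements to normalize over
inst = F.instance_norm(x.reshape(N * L, 1, D, 1))  # [N*L, C=1, H=D, W=1]
out = inst.reshape(N, L, D) * gamma + beta         # re-apply LN's affine part

print(torch.allclose(ref, out, atol=1e-5))  # True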

@seanshpark (Contributor, Author) commented:

Could you check whether the parameters of the LayerNorm in your example code match the real model?

Not sure exactly what match means, but from our customer, the input shape is [1, 16384, 128].
The ReduceMean nodes have different attributes.

[image: ONNX graph showing the ReduceMean attributes]

@seanshpark (Contributor, Author) commented:

Got the same attributes with this:

import onnx
import torch
import torch.nn as nn


class LayerNormNet(nn.Module):
    def __init__(self):
        super().__init__()
        # normalized_shape = 128: normalize over the last dimension only
        self.ln = nn.LayerNorm(128)

    def forward(self, x):
        out = self.ln(x)
        return out


net = LayerNormNet()
inp = torch.randn(1, 16384, 128)

torch.onnx.export(net, inp, "ln11.onnx", opset_version=11)
onnx.shape_inference.infer_shapes_path('ln11.onnx', 'ln11-si.onnx')

torch.onnx.export(net, inp, "ln17.onnx", opset_version=17)
onnx.shape_inference.infer_shapes_path('ln17.onnx', 'ln17-si.onnx')
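
For anyone reproducing this: one way to check the ReduceMean attributes of the decomposed opset-11 export is the inspection snippet below (my own snippet, not from the thread; 'ln11.onnx' comes from the script above):

import onnx
from onnx import helper

# print every ReduceMean node with its attributes (axes, keepdims)
m = onnx.load('ln11.onnx')
for node in m.graph.node:
    if node.op_type == 'ReduceMean':
        print(node.name, [helper.printable_attribute(a) for a in node.attribute])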

@jinevening (Contributor) commented:

Got the same attributes with this:

I expected this :)

@jinevening (Contributor) commented:

Introducing LN may lead to a lot of work (including in the npu compiler) with no visible benefit as of now. If onnx2circle generates the sequence in #14447 (comment), no additional work would be required.

@seanshpark (Contributor, Author) commented Dec 13, 2024:

If onnx2circle generates the sequence in #14447 (comment), no additional work would be required.

If I understand this correctly, it would be one of these two ways:
(1) ONNX model has LayerNorm
(2) onnx2circle converts to Transpose -> Reshape -> InstanceNorm -> Reshape -> Transpose

or

(1) ONNX model has LayerNorm
(2) onnx2circle converts to LayerNorm (of Circle IR)
(3) circle2circle converts to Transpose -> Reshape -> InstanceNorm -> Reshape -> Transpose

something like this?

Addendum:

  • if the ONNX model has a decomposed LayerNorm subgraph, we first have to fuse it to LayerNorm (Circle)
  • or support opset_version=17

@jinevening (Contributor) commented:

something like this?

Yes. I imagined the first approach, which does not need a modification of the circle schema.

if the ONNX model has a decomposed LayerNorm subgraph, we first have to fuse it to LayerNorm (Circle)

Even if the ONNX model has a decomposed LN (due to a low opset version), the fusion can be done inside onnx2circle, which does not affect the circle schema.

Please note that this is just my opinion, aimed at minimizing our workload for the current generation of npu. For the next-generation npu, we may need further discussion.
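
For reference, the decomposed pattern such a fuser inside onnx2circle would have to match can be listed from the opset-11 export above (my own snippet; the exact op sequence varies with the torch exporter version):

import onnx

# list the node sequence of the decomposed LayerNorm (opset-11 export);
# typically something like:
#   ReduceMean -> Sub -> ... -> ReduceMean -> Add(eps) -> Sqrt -> Div -> Mul -> Add
m = onnx.load('ln11.onnx')
print([node.op_type for node in m.graph.node])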

@seanshpark (Contributor, Author) commented:

As there is no particular benefit in adding a CircleLayerNorm IR as of now,
I'll close this issue after a day or two if there are no other opinions.
We may open a new issue to implement the #14447 (comment) suggestion.
