
Unwanted nodes during ONNX QCDQ export #1135

Closed
leokonto opened this issue Dec 19, 2024 · 4 comments
Labels
bug Something isn't working

Comments

@leokonto

Hello Brevitas team,

I am currently trying to dynamically generate ONNX models from model descriptions. The models are fairly constrained for now: only Conv2d, MaxPool, FC, and Add layers are supported. If the model I pass is purely sequential, the export in the QCDQ format works perfectly fine. However, once I add an Add layer, unwanted/unexpected nodes are exported as well (see the attached screenshot). At first I thought it might be due to my use of custom quantizers (I constrain the scaling to powers of two), but the problem persists even with the quantizers provided by Brevitas (Int8{Weight,Act}PerTensorFloat).

Do you know what could potentially cause this issue?

I created a static description of a model that I want to export, which looks as follows:

import torch
import torch.nn as nn
import brevitas.nn as qnn
from brevitas.quant import Int8ActPerTensorFloat, Int8WeightPerTensorFloat


class ResNet50(nn.Module):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
        self.input = qnn.QuantIdentity(act_quant=Int8ActPerTensorFloat, return_quant_tensor=True)
        
        self.conv1 = qnn.QuantConv2d(
            in_channels=3,
            out_channels=64,
            kernel_size=(7,7),
            padding=(3,3),
            stride=(2,2),
            weight_quant=Int8WeightPerTensorFloat,
            output_quant=Int8ActPerTensorFloat,
            bias=False,
            return_quant_tensor=True
        )
        self.relu = qnn.QuantReLU(act_quant=Int8ActPerTensorFloat, return_quant_tensor=True)

        self.resconv1 = qnn.QuantConv2d(
            in_channels=64,
            out_channels=256,
            kernel_size=(1,1),
            padding=(0,0),
            stride=(1,1),
            weight_quant=Int8WeightPerTensorFloat,
            output_quant=Int8ActPerTensorFloat,
            bias=False,
            return_quant_tensor=True
        )
        
        self.s1_conv1 = qnn.QuantConv2d(
            in_channels=64,
            out_channels=64,
            kernel_size=(1,1),
            padding=(0,0),
            stride=(1,1),
            bias=False,
            weight_quant=Int8WeightPerTensorFloat,
            output_quant=Int8ActPerTensorFloat,
            return_quant_tensor=True
        )
        
        self.s1_conv2 = qnn.QuantConv2d(
            in_channels=64,
            out_channels=64,
            kernel_size=(3,3),
            padding=(1,1),
            stride=(1,1),
            bias=False,
            weight_quant=Int8WeightPerTensorFloat,
            output_quant=Int8ActPerTensorFloat,
            return_quant_tensor=True
        )
        
        self.s1_conv3 = qnn.QuantConv2d(
            in_channels=64,
            out_channels=256,
            kernel_size=(1,1),
            padding=(0,0),
            stride=(1,1),
            bias=False,
            weight_quant=Int8WeightPerTensorFloat,
            output_quant=Int8ActPerTensorFloat,
            return_quant_tensor=True
        )
        
    def forward(self, x):
        x = self.input(x)
        x = self.relu(x)
        x = self.conv1(x)
        x = torch.nn.functional.max_pool2d(x, kernel_size=(3,3), stride=(2,2))
        y = self.resconv1(x)
        x = self.s1_conv1(x)
        x = self.s1_conv2(x)
        x = self.s1_conv3(x)
        # Elementwise add of two QuantTensors (residual connection)
        return x + y
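
For reference, the QCDQ export itself would look roughly like the following (a sketch assuming the standard brevitas.export entry point, an ImageNet-sized dummy input, and an illustrative output file name):

from brevitas.export import export_onnx_qcdq

model = ResNet50()
model.eval()
# Dummy input matching conv1's 3-channel input; the spatial size is an assumption
dummy = torch.randn(1, 3, 224, 224)
export_onnx_qcdq(model, args=dummy, export_path="resnet50_qcdq.onnx")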

Thank you in advance.

Best regards,
Leo

P.S. I tested it with different Torch versions, but the behavior did not change.

  • Brevitas version: 0.11.0
  • Torch version: 2.3.0
[Screenshot 2024-12-19: exported ONNX graph showing the unexpected extra nodes]
leokonto added the bug label on Dec 19, 2024
@Giuseppe5
Collaborator

Thanks for this issue.
Really appreciate the reproducibility script and the explanations!

I have an intuition. Can you try modifying the last line of the forward function from:

return x + y

to:

return (x + y).value

The idea is that a QuantTensor carries extra quantization metadata (scale, zero_point, etc.) that might cause unwanted nodes in the final ONNX graph.

If that is the issue, it should NOT cause any difference in terms of final results.

Calling .value on a QuantTensor discards all the extra metadata, which may also remove the unwanted nodes. If that doesn't solve the issue, please let me know and I will make some time tomorrow to look into this.
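
For illustration, a rough sketch of the difference, assuming an instance of the ResNet50 class from the issue:

model = ResNet50().eval()
out = model.relu(model.input(torch.randn(1, 3, 224, 224)))
# out is a QuantTensor: the (dequantized) values plus the quantization
# metadata (scale, zero_point, bit_width) that gets traced into the ONNX graph
print(type(out), out.scale, out.zero_point, out.bit_width)
# out.value is the plain torch.Tensor underneath, with that metadata stripped
print(type(out.value))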

@leokonto
Author

leokonto commented Dec 19, 2024

Thanks a lot for the reply. It seems to have fixed the issue!

Just one last question: would returning only .value have any implications for the remaining layers? You mentioned that it should NOT cause any difference in the final result, but what if the addition is not the end of the forward pass and the following layers expect quantized inputs?

@Giuseppe5
Collaborator

Giuseppe5 commented Dec 19, 2024

I'd recommend calling .value only on the last QuantTensor, right before returning the final result, especially if you rely on QuantTensor for gradient propagation or other quantization aspects.
Similarly, if your last layer is a QuantModule, you can simply set return_quant_tensor=False.

If you really want to have a QuantTensor as output and avoid the extra nodes in ONNX, a kind of hacky solution is to wrap your model just before export with a simple nn.Module whose only job is to call .value on the output of the wrapped model.
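
A rough sketch of such a wrapper (the class name is illustrative, not a Brevitas API):

import torch.nn as nn

class ValueUnwrapper(nn.Module):
    """Wraps a model that returns a QuantTensor and exposes only its .value."""

    def __init__(self, wrapped: nn.Module):
        super().__init__()
        self.wrapped = wrapped

    def forward(self, x):
        out = self.wrapped(x)
        # Strip the QuantTensor metadata right before export so no extra ONNX nodes are emitted
        return out.value

# Usage just before export, e.g.:
# export_onnx_qcdq(ValueUnwrapper(model), args=dummy, export_path="resnet50_qcdq.onnx")

This keeps return_quant_tensor=True inside the model while the exported graph only sees plain tensors at the output.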

@Giuseppe5
Collaborator

I am going to close this issue, but feel free to add more comments if you have further questions related to this, and to reach out if you face other problems with Brevitas.
