Hi everyone,
First of all, thanks for this amazing work!
I am one of the main developers of Brevitas, and I have recently been working on including this optimization in our library.
Here you can find the PR where I am working on it, including an extension to also optimize the scale, based on the suggestion given in this other issue, and some practical experimentation with it.
I've been doing some experiments on CNNs (mostly because they let me iterate quickly across configurations), and I have a few questions that I was hoping you could help me with:
With per-channel quantization, HQQ (zero-point optimization) plus bias correction gives me the same accuracy as bias correction alone, without HQQ. I assume this is more or less expected, since in both cases the goal is to reduce the quantization error through an additional correction term, but I was curious about your opinion.
While extending HQQ to optimize the scale, I noticed that I had to increase beta considerably, otherwise the final accuracy would drop. Apparently, even though the mean error went down, the max error increased far too much, which hurt the quantized model. Do you have any intuition as to why this happens?
If I optimize both the scale and the zero point in the same run (scale first, then zero point), I see a considerable drop in accuracy compared to optimizing only one of them. This seems a bit strange, since a similar setup based on MSE works just fine.
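For reference, the zero-point update these questions revolve around can be sketched as below. This is a minimal NumPy sketch modeled on the proximal solver in the HQQ repository, not its actual code: the function names (shrink_lp, hqq_optimize_zero), the per-row grouping, and the values p=0.7, beta=10, kappa=1.01 are illustrative assumptions. It uses HQQ's convention Wq = round(W * scale + zero), W_r = (Wq - zero) / scale, and records both the mean and the max reconstruction error at every iteration, which is the trade-off described in the beta question. Note that the shrinkage threshold scales with 1/beta: a small beta pushes more of the residual into the sparse "outlier" term W_e, which the zero-point update then effectively ignores, and the early stop only looks at the mean error.

```python
import numpy as np


def shrink_lp(x, beta, p=0.7, eps=1e-8):
    """Generalized soft-thresholding used as the l_p (p < 1) proximal step.

    eps keeps |x| ** (p - 1) finite when x is exactly zero.
    """
    mag = np.abs(x)
    return np.sign(x) * np.maximum(mag - (1.0 / beta) * np.power(mag + eps, p - 1.0), 0.0)


def hqq_optimize_zero(W, nbits=4, beta=10.0, kappa=1.01, iters=20, p=0.7, axis=1):
    """Hypothetical HQQ-style zero-point optimization with a fixed scale.

    Statistics are reduced along `axis` (axis=1: one scale/zero per row).
    Returns the quantized weights, the scale, the best zero-point found, and a
    history of (mean |err|, max |err|) per iteration.
    """
    qmax = 2 ** nbits - 1
    wmin = W.min(axis=axis, keepdims=True)
    wmax = W.max(axis=axis, keepdims=True)
    scale = qmax / np.maximum(wmax - wmin, 1e-8)  # HQQ-style inverse step size
    zero = -wmin * scale                           # min-max initialization
    best_err, best_zero = np.inf, zero
    history = []
    for _ in range(iters):
        Wq = np.clip(np.round(W * scale + zero), 0, qmax)
        err = W - (Wq - zero) / scale
        history.append((np.abs(err).mean(), np.abs(err).max()))
        if history[-1][0] >= best_err:
            break  # early stop: the mean error stopped improving
        best_err, best_zero = history[-1][0], zero
        W_e = shrink_lp(err, beta, p)  # sparse "outlier" part of the residual
        zero = np.mean(Wq - (W - W_e) * scale, axis=axis, keepdims=True)
        beta *= kappa  # increase beta, shrinking the 1/beta threshold next time
    Wq = np.clip(np.round(W * scale + best_zero), 0, qmax)
    return Wq, scale, best_zero, history


rng = np.random.default_rng(0)
W = rng.standard_normal((8, 64))
Wq, scale, zero, history = hqq_optimize_zero(W, beta=10.0)
for i, (mean_err, max_err) in enumerate(history):
    print(f"iter {i}: mean |err| = {mean_err:.5f}, max |err| = {max_err:.5f}")
```

Re-running with different beta values and comparing the two columns is a quick way to see whether the max error grows while the mean error shrinks on a given layer.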
I'll be working on and testing the implementation on a few more use cases, including Transformers and Stable Diffusion, and expanding to per-group quantization. I'll let you know in case I have more questions.
Thanks,
Giuseppe
I haven't tested it on CNNs, only on the dense layers of transformer models (both LLMs and vision models), but I'll try to answer your questions:
Bias correction: it depends on the group size. Assuming you use axis=0 with no reshaping (group size = number of rows), the zero-point is effectively just a bias, so this result is not surprising. If you reduce the group size, however, you should see better results than with the bias term alone.
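The group-size effect can be seen even with plain min-max asymmetric quantization, before any HQQ optimization: every group gets its own scale and zero-point, so smaller groups leave each zero-point a narrower range to cover. The sketch below is a hypothetical NumPy helper (minmax_group_error is not part of HQQ or Brevitas); it groups contiguous elements along each row, which is one common layout but not necessarily how either library reshapes weights, and group_size=None plays the role of the single-group-per-channel case.

```python
import numpy as np


def minmax_group_error(W, nbits=4, group_size=None):
    """Mean |W - dequant(quant(W))| under asymmetric min-max quantization.

    Each row of W is split into contiguous groups of `group_size` elements and
    every group gets its own scale and float zero-point (HQQ-style convention:
    Wq = round(W * scale + zero)). group_size=None means one group per row.
    """
    rows, cols = W.shape
    g = cols if group_size is None else group_size
    if cols % g != 0:
        raise ValueError("number of columns must be divisible by group_size")
    Wg = W.reshape(rows, cols // g, g)  # (rows, n_groups, group_size)
    qmax = 2 ** nbits - 1
    wmin = Wg.min(axis=-1, keepdims=True)
    wmax = Wg.max(axis=-1, keepdims=True)
    scale = qmax / np.maximum(wmax - wmin, 1e-8)
    zero = -wmin * scale
    Wq = np.clip(np.round(Wg * scale + zero), 0, qmax)
    return np.abs(Wg - (Wq - zero) / scale).mean()


rng = np.random.default_rng(0)
W = rng.standard_normal((64, 256))
for gs in (None, 128, 64, 32):
    label = "per-row" if gs is None else f"group {gs}"
    print(f"{label}: mean |err| = {minmax_group_error(W, group_size=gs):.5f}")
```

More zero-points means more free parameters for an optimizer like HQQ to adjust, which is the intuition behind expecting it to pull ahead of a single per-channel correction as the group size shrinks.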
I found that optimizing the scale is a bit tricky because of instabilities. Following the same math, the solution for the scale should be torch.mean((Wq - z) / (Wf - We), axis=axis, keepdim=True), and We should tend to 0 as we converge. I found this to be very unstable, especially in fp16. Instead, I tried a grid search that exhaustively searches the neighborhood of the current scale while making sure the updated scale does not drift too far from the input scale. It only slightly improves the results.
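Both approaches can be sketched in NumPy under the HQQ convention Wq = round(W * scale + zero). closed_form_scale is the formula from the paragraph above (NumPy's keepdims stands in for PyTorch's keepdim); grid_search_scale is a hypothetical take on the neighborhood search, not the author's code: it tries scale * f for f in [1 - delta, 1 + delta], keeps the zero-point fixed, and picks the best factor per row. The function names, delta=0.1, and the 21 grid steps are assumptions.

```python
import numpy as np


def closed_form_scale(W, Wq, zero, W_e, axis=1):
    # Mean-based solve of Wq - zero = scale * (W - W_e) for each group.
    # Entries where W - W_e is close to 0 produce huge ratios, which is one
    # source of the instability mentioned above (worse in fp16).
    return np.mean((Wq - zero) / (W - W_e), axis=axis, keepdims=True)


def row_error(W, scale, zero, qmax, axis=1):
    # Mean absolute reconstruction error per group.
    Wq = np.clip(np.round(W * scale + zero), 0, qmax)
    return np.abs(W - (Wq - zero) / scale).mean(axis=axis, keepdims=True)


def grid_search_scale(W, scale, zero, nbits=4, delta=0.1, steps=21, axis=1):
    # Candidate multipliers around 1.0; 1.0 itself is always included, so the
    # per-group error can never end up worse than with the input scale.
    qmax = 2 ** nbits - 1
    factors = np.unique(np.append(np.linspace(1 - delta, 1 + delta, steps), 1.0))
    errors = np.stack([row_error(W, scale * f, zero, qmax, axis) for f in factors])
    best = np.argmin(errors, axis=0)  # index of the best factor for each group
    return scale * factors[best]


rng = np.random.default_rng(0)
W = rng.standard_normal((8, 64))
qmax = 15
scale = qmax / (W.max(axis=1, keepdims=True) - W.min(axis=1, keepdims=True))
zero = -W.min(axis=1, keepdims=True) * scale
new_scale = grid_search_scale(W, scale, zero)
print(row_error(W, scale, zero, qmax).mean(), row_error(W, new_scale, zero, qmax).mean())
```

Bounding the search to a small multiplicative window is what keeps the updated scale close to its input, at the cost of only exploring local improvements.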
The thing is, if you use autograd to optimize all the parameters (Wq, scale, zero) for thousands of iterations (you can try it here), you'll notice that the most important parameter is the zero-point; the scale only gets updated slightly. I also ran some more experiments, and it's clear that the zero-point matters more than the scale, especially at lower bit widths. So I've left the scale untouched for now.
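The autograd experiment isn't reproduced here, but a cheaper, hypothetical probe of the same question is to start from min-max quantization and give each parameter its own per-row grid search while holding the other fixed. The search ranges (zero shifted by up to half a quantization step, scale by up to 20% in either direction) and the function names are assumptions; the printed gains depend entirely on the weights, so treat this as a way to check the claim on your own layers rather than as evidence for it.

```python
import numpy as np


def row_error(W, scale, zero, qmax):
    # Per-row mean |W - dequant(quant(W))| with Wq = round(W * scale + zero).
    Wq = np.clip(np.round(W * scale + zero), 0, qmax)
    return np.abs(W - (Wq - zero) / scale).mean(axis=1, keepdims=True)


def best_per_row(W, candidates, qmax):
    # candidates: list of (scale, zero) pairs; keep the best one for each row.
    errors = np.stack([row_error(W, s, z, qmax) for s, z in candidates])
    return errors.min(axis=0).mean()


def zero_vs_scale(W, nbits, n=41):
    qmax = 2 ** nbits - 1
    wmin = W.min(axis=1, keepdims=True)
    scale = qmax / np.maximum(W.max(axis=1, keepdims=True) - wmin, 1e-8)
    zero = -wmin * scale
    baseline = row_error(W, scale, zero, qmax).mean()
    # Shift zero by up to half a quantization step (zero lives in the integer domain).
    shifts = np.unique(np.append(np.linspace(-0.5, 0.5, n), 0.0))
    zero_only = best_per_row(W, [(scale, zero + d) for d in shifts], qmax)
    # Rescale by up to 20% in either direction while keeping zero fixed.
    factors = np.unique(np.append(np.linspace(0.8, 1.2, n), 1.0))
    scale_only = best_per_row(W, [(scale * f, zero) for f in factors], qmax)
    return baseline, zero_only, scale_only


rng = np.random.default_rng(0)
W = rng.standard_normal((64, 256))
for nbits in (2, 3, 4, 8):
    base, z_err, s_err = zero_vs_scale(W, nbits)
    print(f"{nbits}-bit: zero-only gain {1 - z_err / base:.2%}, "
          f"scale-only gain {1 - s_err / base:.2%}")
```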
You can still make the scale a trainable parameter with HQQ+, but in my opinion it doesn't really improve the results; maybe it depends on the use case and on the group size and number of bits.
Let me know if you have any other questions, happy to assist!