Introduce dynamic int8 quantization API #19263
Conversation
Commits:
- Add `quantize` to `Layer`
- Add `quantizers`
- …or backwards compatibility
Codecov Report
Attention: Patch coverage is …

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master   #19263      +/-   ##
==========================================
- Coverage   80.14%   75.66%    -4.49%
==========================================
  Files         341      365       +24
  Lines       36163    39763     +3600
  Branches     7116     7709      +593
==========================================
+ Hits        28982    30085     +1103
- Misses       5578     8012     +2434
- Partials     1603     1666       +63
```
Thanks for the PR -- awesome work!
I have fixed the issue in Colab:

```python
# prompt
output = gpt2_lm.generate("My trip to Yosemite was", max_length=200)
# outputs
```

# mixed_bfloat16
My trip to Yosemite was pretty much a blur, but I managed to catch a few glimpses of the area and I was able to make out a few small, isolated, rocky ridges. The trail is a little steep, but the scenery is very scenic. The view is beautiful and I really enjoyed the views. The hike is a bit longer than the typical route but I enjoyed it a lot. It is a short, steep, and very quiet hike that takes me about 5-10 minutes to walk.
I was able to make out a couple of interesting points that you will see in the pictures below.
I was able to make out some interesting rocks that I had never seen before in Yosemite.
I did notice a few other interesting features that I didn't expect, but I was not sure what they were. I was able to make out a few of these features in the photos and then I was able to make out a couple other interesting rocks. I was able to
# quantized int8
My trip to Yosemite was pretty crazy. The sun is always shining, so we went to Yosemite to see it. I was a little disappointed that it didn't have the best view, because the views were so good. The views of the Yosemite Valley were really good, but there were some things that were really not great.
I was a huge fan of The Big Apple, but I didn't really like it. It was a bit like the big old house, with all the windows blown out. It was really cool to see all the different views of Yosemite. The view was really good, but there were some things that were just just really, really, terrible.
I was really happy that the views were good. It really was a good experience. I was really glad I had a good time, because they were really nice.
There's a lot of stuff in this book that is really interesting to me. I really enjoyed it, and I think that the

I didn't see any obvious performance degradation in gpt2 after quantization.
Thank you! 👍
I've thought about it for a bit, but I can't make sense of it. Most of the … Can you try with Gemma and Mistral to see if …
I think I have identified the root cause...
Run 1:
Gemma output (before quantization):
What is Keras?
Keras is a high-level neural network API that provides a flexible and efficient interface for building and training neural networks. It is designed to be easy to use for both advanced researchers and novice users, and to provide a high-level API that hides the complexity of the underlying algorithms.
Gemma output (after quantization):
What is Keras?
Keras is a high-level neural network library that provides a flexible and efficient interface for building and running neural networks. It is designed to be easy to use and to provide a high-level API that is easy to learn and use.
Keras is built on top of Theano

Run 2:
Gemma output (before quantization):
What is Keras?
Keras is a high-level neural network API that provides a flexible and efficient interface for building and training neural networks. It is designed to be easy to use for both advanced researchers and novice users, and to provide a high-level API that hides the complexity of the underlying algorithms.
Gemma output (after quantization):
What is Keras?
Keras is an open source deep learning library that provides a high-level API for developing neural networks. It is designed to be easy to use and extend, and to provide a consistent interface for working with different types of neural networks.
Keras is built on top of Theano,

The benchmark script:

```python
import time

import kagglehub
import keras_nlp

import keras

kagglehub.login()

# Setup Keras
keras.config.set_dtype_policy("bfloat16")
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

# Warmup
_ = gemma_lm.generate("What is Keras?", max_length=64)

# Predict
t0 = time.time()
output = gemma_lm.generate("What is Keras?", max_length=64)
cost_time = time.time() - t0
print("Gemma output (before quantization):")
print(output)
print(f"Cost time: {cost_time}")

# Quantize
gemma_lm.quantize("int8")
gemma_lm.generate_function = None  # force `make_generate_function`
print("Finish quantization")

# Warmup
_ = gemma_lm.generate("What is Keras?", max_length=64)

# Run in int8
t0 = time.time()
output = gemma_lm.generate("What is Keras?", max_length=64)
cost_time = time.time() - t0
print("Gemma output (after quantization):")
print(output)
print(f"Cost time: {cost_time}")

weight_in_int8 = []
for weight in gemma_lm.weights:
    if weight.dtype == "int8":
        weight_in_int8.append(weight.name)
print("Number of weights in int8:", len(weight_in_int8))
```
This PR introduces a dynamic int8 quantization API, following discussions with @fchollet.
Highlights
- Add `FloatDTypePolicy` and `QuantizedDTypePolicy`
- Move `self.dtype_policy` from `Layer` to `Operation` for dispatching the new `quantized_call`
- Add `keras.quantizers.*`
- Add `quantize("int8")` in `keras.Model`
- Add `quantize("int8")` in `keras.layers.Dense` and `keras.layers.EinsumDense`
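A minimal usage sketch of the new `quantize("int8")` entry point; the tiny model and shapes here are illustrative, not taken from the PR:

```python
import numpy as np
import keras

# Build a small model; layer sizes are arbitrary for illustration.
model = keras.Sequential([
    keras.layers.Dense(16, activation="relu"),
    keras.layers.Dense(4),
])
x = np.random.rand(2, 8).astype("float32")
y_float = model(x)      # builds the model and runs the float path

model.quantize("int8")  # quantizes Dense/EinsumDense weights in place
y_int8 = model(x)       # now dispatched through `quantized_call`

# The two outputs should agree closely for well-scaled activations.
diff = keras.ops.convert_to_numpy(y_float) - keras.ops.convert_to_numpy(y_int8)
print(np.max(np.abs(diff)))
```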
Notes
We want to reduce the memory footprint and improve the inference speed of large models using int8 quantization.
Currently, we utilize post-training dynamic and symmetric quantization, which is a much simpler technique than other quantization methods.
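For intuition, here is a minimal NumPy sketch of what dynamic symmetric int8 quantization does. The function names and per-tensor granularity are illustrative assumptions; the PR's actual implementation lives behind `keras.quantizers.*`:

```python
import numpy as np

# A minimal sketch of post-training dynamic symmetric int8 quantization.
def quantize_symmetric(x):
    # "Symmetric": the zero-point is 0, so the scale maps max|x| to 127.
    # "Dynamic": the scale is computed from the tensor at call time.
    scale = max(np.max(np.abs(x)) / 127.0, 1e-12)
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

x = np.random.randn(4, 8).astype("float32")
q, scale = quantize_symmetric(x)
print(np.max(np.abs(dequantize(q, scale) - x)))  # small reconstruction error
```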
It's important to note that this simple quantization process works best for weights and activations that are centered at 0. As a result, `compute_dtype="float16"` doesn't work well due to issues with underflow/overflow.
When encountering LoRA-enabled layers, the quantization process will first merge the LoRA weights before proceeding (see the sketch below).
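A rough NumPy sketch of that merge-then-quantize order; `kernel`, `lora_a`, `lora_b`, and the rank follow common LoRA conventions and are assumptions, not this PR's exact variable names:

```python
import numpy as np

# Merge-then-quantize for a LoRA-enabled layer (illustrative names).
rng = np.random.default_rng(0)
kernel = rng.standard_normal((64, 32), dtype=np.float32)
lora_a = rng.standard_normal((64, 2), dtype=np.float32)  # rank=2
lora_b = rng.standard_normal((2, 32), dtype=np.float32)

# 1) Fold the low-rank update back into the kernel.
merged = kernel + lora_a @ lora_b

# 2) Quantize the merged kernel symmetrically to int8.
scale = np.max(np.abs(merged)) / 127.0
kernel_int8 = np.clip(np.round(merged / scale), -127, 127).astype(np.int8)
```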
Note that `compute_dtype` in `QuantizedDTypePolicy` is inherited from `FloatDTypePolicy`. This is necessary because setting it to int8 would cause the autocasting to fail in the quantized layer.

Results on MNIST
The benchmarked model stacks `Dense` (or `EinsumDense`), `BatchNormalization`, and `ReLU` layers, optionally with LoRA enabled via `enable_lora(rank=2)`. The comparison covers `Dense` and `EinsumDense` variants in three states: `self.lora_enabled=True`, `self.lora_enabled=False`, and `self.lora_enabled=False` (merged).
The standalone demo script using the new API:

```shell
# export KERAS_BACKEND=...
python3 demo.py --dtype-policy float32
python3 demo.py --dtype-policy float32 --use-einsum
python3 demo.py --dtype-policy mixed_bfloat16
python3 demo.py --dtype-policy mixed_bfloat16 --use-einsum
```
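The demo script itself isn't reproduced above. A hypothetical reconstruction consistent with the flags, the MNIST setup, and the layer stack described earlier might look like this (the real `demo.py` in the PR may differ):

```python
import argparse

import numpy as np
import keras

# Hypothetical reconstruction of demo.py's CLI, matching the flags above.
parser = argparse.ArgumentParser()
parser.add_argument("--dtype-policy", default="float32")
parser.add_argument("--use-einsum", action="store_true")
args = parser.parse_args()

keras.config.set_dtype_policy(args.dtype_policy)

# MNIST, flattened to vectors.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784).astype("float32") / 255.0
x_test = x_test.reshape(-1, 784).astype("float32") / 255.0

# Dense vs. EinsumDense hidden layer, as in the MNIST comparison.
if args.use_einsum:
    hidden = keras.layers.EinsumDense("ab,bc->ac", output_shape=256, bias_axes="c")
else:
    hidden = keras.layers.Dense(256)

inputs = keras.Input(shape=(784,))
x = hidden(inputs)
x = keras.layers.BatchNormalization()(x)
x = keras.layers.ReLU()(x)
outputs = keras.layers.Dense(10, activation="softmax")(x)
model = keras.Model(inputs, outputs)

model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
model.fit(x_train, y_train, epochs=1, batch_size=128)
print("float eval:", model.evaluate(x_test, y_test, verbose=0))

# Quantize in place and re-evaluate.
model.quantize("int8")
print("int8 eval:", model.evaluate(x_test, y_test, verbose=0))
```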
Acknowledgments