
Introduce dynamic int8 quantization API #19263

Merged: 36 commits from add-quantize-v2 into keras-team:master on Mar 8, 2024

Conversation

@james77777778 (Contributor) commented on Mar 7, 2024:

This PR introduces a dynamic int8 quantization API, following discussions with @fchollet.

Highlights

  • Introduce specialized subclasses FloatDTypePolicy and QuantizedDTypePolicy
  • Move self.dtype_policy from Layer to Operation for dispatching the new quantized_call
  • Introduce keras.quantizers.*
  • Introduce quantize("int8") in keras.Model
  • Introduce quantize("int8") in keras.layers.Dense and keras.layers.EinsumDense

Notes

We want to reduce the memory footprint and improve the inference speed of large models using int8 quantization.
Currently, we use post-training dynamic and symmetric quantization, which is much simpler than other quantization methods.

Note that this simple quantization scheme works best for weights and activations that are centered at 0.
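For reference, a minimal NumPy sketch of the abs-max symmetric scheme described above (illustrative only; not the exact keras.quantizers implementation):

import numpy as np

def abs_max_quantize(x, axis=-1, value_range=(-127, 127), epsilon=1e-7):
    # Symmetric quantization: the largest magnitude along `axis` maps to 127.
    scale = value_range[1] / (np.max(np.abs(x), axis=axis, keepdims=True) + epsilon)
    x_q = np.clip(np.round(x * scale), *value_range).astype("int8")
    return x_q, scale

x = np.random.randn(4, 8).astype("float32")
x_q, scale = abs_max_quantize(x)
x_dq = x_q.astype("float32") / scale   # dequantize: x ≈ x_q / scale
print(np.max(np.abs(x - x_dq)))        # small reconstruction error

Since the scale is derived from max(|x|), values with a large offset from 0 waste most of the int8 range, which is why centering matters.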

compute_dtype="float16" doesn't work well due to issues with underflow/overflow.

For LoRA-enabled layers, the quantization process first merges the LoRA weights into the kernel before proceeding (see the sketch below).
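A sketch of that merge step (hypothetical names and scaling factor; the actual layer attributes may differ):

def merge_lora_into_kernel(kernel, lora_a, lora_b, scale=1.0):
    # W_merged = W + scale * (A @ B); after merging, the LoRA factors can be
    # discarded and the merged kernel is what gets quantized to int8.
    return kernel + scale * (lora_a @ lora_b)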

Note that compute_dtype in QuantizedDTypePolicy is inherited from FloatDTypePolicy. This is necessary because setting it to int8 would cause the autocasting to fail in the quantized layer.
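To illustrate how the quantized call stays in a float compute dtype, here is a hand-rolled sketch of a quantized dense forward pass (not the layer's actual implementation; it reuses abs_max_quantize from the sketch above, with the kernel assumed to be quantized per output channel via abs_max_quantize(kernel, axis=0)):

def quantized_dense(inputs, kernel_q, kernel_scale, bias=None):
    # inputs: float (batch, in_dim); kernel_q: int8 (in_dim, out_dim);
    # kernel_scale: per-output-channel scale from quantizing the float kernel.
    inputs_q, inputs_scale = abs_max_quantize(inputs, axis=-1)  # dynamic, per call
    acc = inputs_q.astype("int32") @ kernel_q.astype("int32")   # int8 x int8 -> int32
    outputs = acc.astype("float32") / (inputs_scale * kernel_scale)
    return outputs if bias is None else outputs + bias

The weights are stored in int8, but the accumulation result and the returned activations are float, so autocasting and downstream layers behave as usual.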

Results on MNIST

  • model architecture
    • Dense or EinsumDense
    • BatchNormalization
    • ReLU
  • fine-tuning with enable_lora(rank=2)
  • inference time: batch size=1024
    • float: self.lora_enabled=True
    • int8: self.lora_enabled=False (merged)
| backend | dtype_policy | layer | float acc. | int8 acc. | float inference time | int8 inference time | inference time ratio |
| --- | --- | --- | --- | --- | --- | --- | --- |
| tensorflow | float32 | Dense | 0.95990 | 0.96000 | 0.00395s | 0.00198s | 0.501 |
| tensorflow | mixed_bfloat16 | Dense | 0.96110 | 0.96110 | 0.00265s | 0.00200s | 0.755 |
| tensorflow | float32 | EinsumDense | 0.95950 | 0.95920 | 0.00384s | 0.00188s | 0.490 |
| tensorflow | mixed_bfloat16 | EinsumDense | 0.95980 | 0.95970 | 0.00258s | 0.00200s | 0.775 |
| jax | float32 | Dense | 0.96130 | 0.96160 | 0.00304s | 0.00132s | 0.434 |
| jax | mixed_bfloat16 | Dense | 0.95290 | 0.95300 | 0.00177s | 0.00133s | 0.751 |
| jax | float32 | EinsumDense | 0.96170 | 0.96160 | 0.00302s | 0.00132s | 0.437 |
| jax | mixed_bfloat16 | EinsumDense | 0.95720 | 0.95680 | 0.00176s | 0.00125s | 0.710 |
| torch | float32 | Dense | 0.96050 | 0.96070 | 0.00834s | 0.01182s | 1.417 (slower) |
| torch | float32 | EinsumDense | 0.96010 | 0.95990 | 0.00895s | 0.01317s | 1.472 (slower) |

The slower inference with the torch backend may be caused by additional casting and eager-mode execution.

The standalone demo script using the new API:

# Usage (select the backend first, e.g. `export KERAS_BACKEND=...`):
#   python3 demo.py --dtype-policy float32
#   python3 demo.py --dtype-policy float32 --use-einsum
#   python3 demo.py --dtype-policy mixed_bfloat16
#   python3 demo.py --dtype-policy mixed_bfloat16 --use-einsum
import argparse
import os
import time

import numpy as np

import keras
from keras import backend
from keras import dtype_policies
from keras import layers
from keras import models
from keras import ops
from keras import saving


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dtype-policy",
        default="float32",
        choices=["float32", "mixed_bfloat16"],
    )
    parser.add_argument("--use-einsum", action="store_true")
    return parser.parse_args()


def build_model(num_layers=32, units=1024, use_einsum=False):
    inputs = layers.Input([28, 28])
    x = layers.Flatten()(inputs)
    for _ in range(num_layers):
        if use_einsum:
            x = layers.EinsumDense("ab,bc->ac", output_shape=[units])(x)
        else:
            x = layers.Dense(units)(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
    outputs = layers.Dense(10, use_bias=True, activation="softmax")(x)
    model = models.Model(inputs, outputs)
    return model


def enable_lora(model):
    for layer in model.layers:
        if hasattr(layer, "enable_lora"):
            layer.enable_lora(2)


def benchmark(model, batch_size=1024, input_shape=(28, 28), iterations=200):
    def fn(x):
        return model(x, training=False)

    if backend.backend() == "tensorflow":
        import tensorflow as tf

        jit_fn = tf.function(fn, jit_compile=True)
    elif backend.backend() == "jax":
        import jax

        jit_fn = jax.jit(fn)
    elif backend.backend() == "torch":
        jit_fn = fn
    else:
        jit_fn = fn

    # warmup
    x = ops.ones([batch_size, *input_shape])
    for _ in range(10):
        _ = ops.convert_to_numpy(jit_fn(x))

    times = []
    for _ in range(iterations):
        t0 = time.time()
        _ = ops.convert_to_numpy(jit_fn(x))
        t1 = time.time()
        times.append(t1 - t0)
    avg_time = sum(times) / len(times)
    return avg_time


def main():
    args = get_args()

    # Set dtype policy
    dtype = args.dtype_policy
    dtype_policies.dtype_policy.set_dtype_policy(dtype)
    print(f"Global dtype policy: {dtype_policies.dtype_policy.dtype_policy()}")

    # Model / data parameters
    use_einsum = args.use_einsum
    num_classes = 10
    input_shape = (28, 28, 1)
    epochs = 1

    # Load the data and split it between train and test sets
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
    x_train = x_train.astype("float32") / 255
    x_test = x_test.astype("float32") / 255
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    y_train = keras.utils.to_categorical(y_train, num_classes)
    y_test = keras.utils.to_categorical(y_test, num_classes)

    model = build_model(num_layers=32, units=1024, use_einsum=use_einsum)
    model.compile(
        loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
    )

    """Train float model"""
    print("=====Start training float model=====")
    model.fit(
        x_train, y_train, batch_size=128, epochs=epochs, validation_split=0.1
    )
    print(f"Performance of {dtype}:")
    score = model.evaluate(x_test, y_test, verbose=0)
    print(f"  Test accuracy: {score[1]:.5f}")
    avg_time = benchmark(model, input_shape=input_shape)
    print(f"  Avg. inference time (batch_size=1024): {avg_time:.5f}s")
    model.save("model_fp32.keras")

    """Enable lora"""
    print("=====Enable lora weights=====")
    enable_lora(model)

    """Fine-tuning lora weights"""
    model.compile(
        loss="categorical_crossentropy",
        optimizer="adam",
        metrics=["accuracy"],
    )
    model.fit(
        x_train,
        y_train,
        batch_size=128,
        epochs=epochs,
        validation_split=0.1,
    )
    print("Performance of fine-tuned lora weights:")
    score = model.evaluate(x_test, y_test, verbose=0)
    print(f"  Test accuracy: {score[1]:.5f}")
    avg_time = benchmark(model, input_shape=input_shape)
    print(f"  Avg. inference time (batch_size=1024): {avg_time:.5f}s")

    """Quantize to int8 weights"""
    model.quantize(mode="int8")
    int8_model = model
    int8_model.compile(loss="categorical_crossentropy", metrics=["accuracy"])
    print("Performance of quantized model:")
    score = int8_model.evaluate(x_test, y_test, verbose=0)
    print(f"  Test accuracy: {score[1]:.5f}")
    avg_time = benchmark(int8_model, input_shape=input_shape)
    print(f"  Avg. inference time (batch_size=1024): {avg_time:.5f}s")

    """Saving & loading"""
    int8_model.save("model_int8.keras")
    reloaded_int8_model = saving.load_model("model_int8.keras")
    reloaded_score = reloaded_int8_model.evaluate(x_test, y_test, verbose=0)
    print(f"Reloaded int8 model test accuracy: {reloaded_score[1]:.5f}")
    print("Size of saved model:")
    print(f"  fp32: {os.path.getsize('model_fp32.keras') >> 20}MB")
    print(f"  int8: {os.path.getsize('model_int8.keras') >> 20}MB")

    """Cleanup"""
    os.remove("model_fp32.keras")
    os.remove("model_int8.keras")


if __name__ == "__main__":
    main()

Acknowledgments

codecov-commenter commented on Mar 7, 2024:

Codecov Report

Attention: Patch coverage is 76.81941%, with 86 lines in your changes missing coverage. Please review.

Project coverage is 75.66%. Comparing base (c8700f4) to head (50c6e74).
Report is 79 commits behind head on master.

| Files | Patch % | Lines |
| --- | --- | --- |
| keras/layers/core/einsum_dense.py | 65.27% | 38 Missing and 12 partials ⚠️ |
| keras/layers/core/dense.py | 76.47% | 7 Missing and 5 partials ⚠️ |
| keras/quantizers/__init__.py | 76.66% | 4 Missing and 3 partials ⚠️ |
| keras/layers/layer.py | 53.84% | 3 Missing and 3 partials ⚠️ |
| keras/dtype_policies/dtype_policy.py | 92.98% | 3 Missing and 1 partial ⚠️ |
| keras/backend/torch/numpy.py | 50.00% | 1 Missing and 1 partial ⚠️ |
| keras/ops/operation.py | 80.00% | 1 Missing and 1 partial ⚠️ |
| keras/dtype_policies/__init__.py | 87.50% | 1 Missing ⚠️ |
| keras/models/model.py | 94.44% | 0 Missing and 1 partial ⚠️ |
| keras/quantizers/quantizers.py | 97.05% | 1 Missing ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #19263      +/-   ##
==========================================
- Coverage   80.14%   75.66%   -4.49%     
==========================================
  Files         341      365      +24     
  Lines       36163    39763    +3600     
  Branches     7116     7709     +593     
==========================================
+ Hits        28982    30085    +1103     
- Misses       5578     8012    +2434     
- Partials     1603     1666      +63     
| Flag | Coverage Δ |
| --- | --- |
| keras | 75.51% <76.81%> (-4.48%) ⬇️ |
| keras-jax | 59.70% <75.47%> (-3.36%) ⬇️ |
| keras-numpy | 54.24% <71.15%> (-2.84%) ⬇️ |
| keras-tensorflow | 61.21% <76.28%> (-3.44%) ⬇️ |
| keras-torch | 60.32% <76.01%> (-3.55%) ⬇️ |


@fchollet (Collaborator) left a comment:


Thanks for the PR -- awesome work!

@james77777778 (Contributor, Author) commented:

@fchollet

I have fixed the issue in Model.quantize where double quantization might occur with the KerasNLP GPT-2 model.

Colab:
https://colab.research.google.com/drive/1DDChRgkEzgUmr1k1N_NbLha1_rjx3asw?usp=sharing

# prompt
output = gpt2_lm.generate("My trip to Yosemite was", max_length=200)
# outputs

# mixed_bfloat16
My trip to Yosemite was pretty much a blur, but I managed to catch a few glimpses of the area and I was able to make out a few small, isolated, rocky ridges. The trail is a little steep, but the scenery is very scenic. The view is beautiful and I really enjoyed the views. The hike is a bit longer than the typical route but I enjoyed it a lot. It is a short, steep, and very quiet hike that takes me about 5-10 minutes to walk.
I was able to make out a couple of interesting points that you will see in the pictures below.
I was able to make out some interesting rocks that I had never seen before in Yosemite.
I did notice a few other interesting features that I didn't expect, but I was not sure what they were. I was able to make out a few of these features in the photos and then I was able to make out a couple other interesting rocks. I was able to

# quantized int8
My trip to Yosemite was pretty crazy. The sun is always shining, so we went to Yosemite to see it. I was a little disappointed that it didn't have the best view, because the views were so good. The views of the Yosemite Valley were really good, but there were some things that were really not great.

I was a huge fan of The Big Apple, but I didn't really like it. It was a bit like the big old house, with all the windows blown out. It was really cool to see all the different views of Yosemite. The view was really good, but there were some things that were just just really, really, terrible.

I was really happy that the views were good. It really was a good experience. I was really glad I had a good time, because they were really nice.

There's a lot of stuff in this book that is really interesting to me. I really enjoyed it, and I think that the

I didn't see any obvious degradation in GPT-2's outputs after quantization.
However, I'm unsure how to benchmark the inference speed of GPT-2: I've tried generate(...), but the speed didn't improve.

@fchollet (Collaborator) left a comment:


Thank you! 👍

@google-ml-butler bot added the kokoro:force-run and ready to pull labels on Mar 8, 2024
@fchollet (Collaborator) commented on Mar 8, 2024:

> However, I'm unsure how to benchmark the inference speed of GPT-2. I've tried generate(...) but the speed didn't improve.

I've thought about it for a bit, but I can't make sense of it. Most of the generate() time is the forward pass of the model, and most of that is EinsumDense calls. So if the inference speed of EinsumDense is improved, then the speed of generate() should be improved as well.

Can you try with Gemma and Mistral to see if generate() gets faster?

@fchollet merged commit ce06c65 into keras-team:master on Mar 8, 2024
9 checks passed
@google-ml-butler bot removed the awaiting review and ready to pull labels on Mar 8, 2024
@james77777778 (Contributor, Author) commented:

> Can you try with Gemma and Mistral to see if generate() gets faster?

@fchollet

I think I have identified the root cause: it appears that Colab's T4 may not support int8 x int8 -> int32 acceleration, since I did get a speed-up with Gemma on my RTX 4070.

  • keras.config.set_dtype_policy("bfloat16") (only this fits in 11 GB of VRAM)
  • prompt="What is Keras?"
| backend | bfloat16 | int8 | ratio |
| --- | --- | --- | --- |
| tensorflow | 0.781s | 0.712s | 0.911 |
| jax | 0.694s | 0.525s | 0.756 |

Run with tensorflow backend:

Gemma output (before quantization):
What is Keras?

Keras is a high-level neural network API that provides a flexible and efficient interface for building and training neural networks. It is designed to be easy to use for both advanced researchers and novice users, and to provide a high-level API that hides the complexity of the underlying algorithms.

Gemma output (after quantization):
What is Keras?

Keras is a high-level neural network library that provides a flexible and efficient interface for building and running neural networks. It is designed to be easy to use and to provide a high-level API that is easy to learn and use.

Keras is built on top of Theano

Run with jax backend:

Gemma output (before quantization):
What is Keras?

Keras is a high-level neural network API that provides a flexible and efficient interface for building and training neural networks. It is designed to be easy to use for both advanced researchers and novice users, and to provide a high-level API that hides the complexity of the underlying algorithms.

Gemma output (after quantization):
What is Keras?

Keras is an open source deep learning library that provides a high-level API for developing neural networks. It is designed to be easy to use and extend, and to provide a consistent interface for working with different types of neural networks.

Keras is built on top of Theano,

The benchmark script:

import time

import kagglehub
import keras_nlp

import keras

kagglehub.login()

# Setup Keras
keras.config.set_dtype_policy("bfloat16")

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")

# Warmup
_ = gemma_lm.generate("What is Keras?", max_length=64)

# Predict
t0 = time.time()
output = gemma_lm.generate("What is Keras?", max_length=64)
cost_time = time.time() - t0
print("Gemma output (before quantization):")
print(output)
print(f"Cost time: {cost_time}")

# Quantize
gemma_lm.quantize("int8")
gemma_lm.generate_function = None  # force `make_generate_function`
print("Finish quantization")

# Warmup
_ = gemma_lm.generate("What is Keras?", max_length=64)

# Run in int8
t0 = time.time()
output = gemma_lm.generate("What is Keras?", max_length=64)
cost_time = time.time() - t0
print("Gemma output (after quantization):")
print(output)
print(f"Cost time: {cost_time}")

weight_in_int8 = []
for weight in gemma_lm.weights:
    if weight.dtype == "int8":
        weight_in_int8.append(weight.name)
print("Number of weights in int8:", len(weight_in_int8))

@james77777778 deleted the add-quantize-v2 branch on March 8, 2024 at 14:09