-
This is my gist to convert a Whisper model fully to HQQ: https://gist.github.com/huseinzol05/70daae3a4557616f315e7744ba3fcc93. The speed does not seem faster than Flash Attention 2 on 30-second examples, but a simple matmul is faster: https://gist.github.com/huseinzol05/ff59996034604d17c1e53074e9adc03f
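For anyone who wants to reproduce the isolated-matmul comparison, here is a minimal sketch. It assumes the HQQLinear(layer, quant_config=...) constructor shown in the HQQ examples and the plain PyTorch backend; with that backend the quantized layer typically dequantizes on the fly and will be slower than fp16, and the speed-up only shows up once an optimized int4 kernel (e.g. the torchao backend used below) is patched in.

import copy
import time
import torch
from hqq.core.quantize import BaseQuantizeConfig, HQQBackend, HQQLinear

device, dtype = "cuda:0", torch.float16
layer = torch.nn.Linear(4096, 4096, bias=False).to(device=device, dtype=dtype)
x = torch.randn(1, 4096, dtype=dtype, device=device)

# Quantize a copy so the fp16 layer stays intact for the comparison
HQQLinear.set_backend(HQQBackend.PYTORCH)
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, axis=1)
qlayer = HQQLinear(copy.deepcopy(layer), quant_config=quant_config, compute_dtype=dtype, device=device)

def bench(fn, iters=200):
    for _ in range(10):  # warm-up
        fn(x)
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        fn(x)
    torch.cuda.synchronize()
    return (time.time() - t0) / iters

print("fp16 linear:", bench(layer), "sec / call")
print("hqq linear :", bench(qlayer), "sec / call")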
-
So I was able to run a benchmark and compare with vanilla fp16. It is not as straightforward because the encoder and decoder require different logic.

import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

model_id = "openai/whisper-medium"
compute_dtype = torch.bfloat16  # please don't change this
device = "cuda:0"

model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=compute_dtype)
processor = AutoProcessor.from_pretrained(model_id)

##############################################################################
# No quantize
# model = model.to(device)
##############################################################################
from hqq.models.hf.base import AutoHQQHFModel
from hqq.core.quantize import *

# Please keep nbits=4 and axis=1
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, axis=1)
HQQLinear.set_backend(HQQBackend.PYTORCH)
AutoHQQHFModel.quantize_model(model.model.encoder, quant_config=quant_config, compute_dtype=compute_dtype, device=device)
AutoHQQHFModel.quantize_model(model.model.decoder, quant_config=quant_config, compute_dtype=compute_dtype, device=device)

# Replace HQQLinear layers' matmuls to support int4 mm (torchao backend for the decoder only)
import hqq.models.base as hqq_base
hqq_base._QUANT_LAYERS = [torch.nn.Linear, HQQLinear]
from hqq.utils.patching import prepare_for_inference
AutoHQQHFModel.set_auto_linear_tags(model.model.encoder)
prepare_for_inference(model.model.encoder)
AutoHQQHFModel.set_auto_linear_tags(model.model.decoder)
prepare_for_inference(model.model.decoder, backend="torchao_int4")

# Compile both forwards with full graphs; the first iterations are slow (warm-up)
model.model.encoder.forward = torch.compile(model.model.encoder.forward, mode="reduce-overhead", fullgraph=True)
model.model.decoder.forward = torch.compile(model.model.decoder.forward, mode="reduce-overhead", fullgraph=True)
##############################################################################
import time
import numpy as np

# Encoder benchmark: one 30-second mel-spectrogram chunk (80 bins x 3000 frames)
encoder_input = torch.randn([1, 80, 3000], dtype=compute_dtype, device=device)

def run_encoder():
    with torch.no_grad():
        model.model.encoder(encoder_input)
    torch.cuda.synchronize()

t = []
for _ in range(200):
    t1 = time.time()
    run_encoder()
    t2 = time.time()
    t.append(t2 - t1)
print("Encoder", np.mean(t[-100:]), "sec / sample")

# Decoder benchmark: a single-token forward pass (no KV cache)
decoder_input = torch.randint(0, 1000, [1, 1], dtype=torch.int64, device=device)

def run_decoder():
    with torch.no_grad():
        out = model.model.decoder(decoder_input)
    torch.cuda.synchronize()

t = []
for _ in range(200):
    t1 = time.time()
    run_decoder()
    t2 = time.time()
    t.append(t2 - t1)
print("Decoder", np.mean(t[-100:]), "sec / sample")
-
I tested the torch.compile code and it works. But it does not work on distil models (distil-whisper/distil-large-v3). How can I solve this?

Error message:

TorchRuntimeError: Failed running call_module L__self___conv1(*(FakeTensor(..., device='cuda:0', size=(1, 80, 3000), dtype=torch.bfloat16),), **{}):
Invalid channel dimensions

from user code:
File "/usr/local/lib/python3.10/dist-packages/transformers/models/whisper/modeling_whisper.py", line 1172, in forward
inputs_embeds = nn.functional.gelu(self.conv1(input_features))

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
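A likely cause, though only an assumption here: distil-large-v3 is derived from large-v3, which uses 128 mel bins instead of 80, so the hardcoded [1, 80, 3000] dummy encoder input fails at conv1. Reading the bin count from the model config sidesteps the hardcoding:

# Build the dummy encoder input from the model config instead of hardcoding 80 mel bins
n_mels = model.config.num_mel_bins  # 80 for whisper-medium, 128 for (distil-)large-v3
encoder_input = torch.randn([1, n_mels, 3000], dtype=compute_dtype, device=device)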
-
@huseinzol05 @kadirnar But token decoding is definitely significantly faster with HQQ using the torchao backend and fullgraph compilation when you measure it alone, as I shared here. Depending on the size of the cache, that speed-up will decrease a bit.
-
Feature request to add static cache support to Whisper: huggingface/transformers#30707
-
@mobicham Is the code for testing with long-form audio (>30 s) available publicly?
-
@mobicham Can I use torch.compile with HQQ optimization? |
-
Moving the HQQ-Whisper conversation here.