Skip to content

Commit

Permalink
feat: AWQ quantization for InternVL
Browse files Browse the repository at this point in the history
  • Loading branch information
AlpinDale committed Dec 4, 2024
1 parent 9fc6473 commit a65820e
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 25 deletions.
6 changes: 4 additions & 2 deletions aphrodite/modeling/layers/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,8 @@ def _load_fused_module_from_checkpoint(self, param: BaseAphroditeParameter,
# for the packing.
if isinstance(param, PackedAphroditeParameter
) and param.packed_dim == param.output_dim:
param.adjust_shard_indexes_for_packing(
shard_size, shard_offset = \
param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset)

loaded_weight_shard = loaded_weight.narrow(param.output_dim,
Expand Down Expand Up @@ -753,7 +754,8 @@ def _load_fused_module_from_checkpoint(self, param: BaseAphroditeParameter,
# for the packing.
if isinstance(param, PackedAphroditeParameter
) and param.packed_dim == param.output_dim:
param.adjust_shard_indexes_for_packing(
shard_size, shard_offset = \
param.adjust_shard_indexes_for_packing(
shard_size=shard_size, shard_offset=shard_offset)

loaded_weight_shard = loaded_weight.narrow(param.output_dim,
Expand Down
4 changes: 4 additions & 0 deletions aphrodite/modeling/model_loader/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,10 @@ def get_quant_config(model_config: ModelConfig,
# Read the quantization config from the HF model config, if available.
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
None)
# some vision model may keep quantization_config in their text_config
hf_text_config = getattr(model_config.hf_config, "text_config", None)
if hf_quant_config is None and hf_text_config is not None:
hf_quant_config = getattr(hf_text_config, "quantization_config", None)
if hf_quant_config is None:
# compressed-tensors uses a compressions_config
hf_quant_config = getattr(model_config.hf_config, "compression_config",
Expand Down
35 changes: 13 additions & 22 deletions aphrodite/modeling/models/internlm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def __init__(
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.key_value_groups = int(self.num_heads / self.num_kv_heads)
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
Expand Down Expand Up @@ -119,6 +120,14 @@ def __init__(
cache_config=cache_config,
quant_config=quant_config)

def split_qkv(self, qkv: torch.Tensor):
qkv = qkv.view(-1, self.num_kv_heads, self.key_value_groups + 2, 128)
q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=2)
q = q.reshape(-1, self.q_size)
k = k.reshape(-1, self.kv_size)
v = v.reshape(-1, self.kv_size)
return q, k, v

def forward(
self,
positions: torch.Tensor,
Expand All @@ -127,7 +136,7 @@ def forward(
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
qkv, _ = self.wqkv(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k, v = self.split_qkv(qkv)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.wo(attn_output)
Expand Down Expand Up @@ -321,24 +330,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
if "wqkv" in name:
config = self.config
kv_groups = (config.num_attention_heads //
config.num_key_value_heads)
head_dim = config.hidden_size // config.num_attention_heads
loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
head_dim,
loaded_weight.shape[-1])
wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
dim=1)
wq = wq.reshape(-1, wq.shape[-1])
wk = wk.reshape(-1, wk.shape[-1])
wv = wv.reshape(-1, wv.shape[-1])
weight_loader = param.weight_loader
weight_loader(param, wq, 'q')
weight_loader(param, wk, 'k')
weight_loader(param, wv, 'v')
else:
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
97 changes: 96 additions & 1 deletion tests/models/test_internvl.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import types
from typing import List, Optional, Type
from typing import List, Optional, Tuple, Type

import pytest
import torch
Expand Down Expand Up @@ -178,6 +178,68 @@ def run_test(
)


def run_awq_test(
aphrodite_runner: Type[AphroditeRunner],
image_assets: _ImageAssets,
models: Tuple[str, str],
*,
size_factors: List[float],
dtype: str,
max_tokens: int,
num_logprobs: int,
tensor_parallel_size: int,
distributed_executor_backend: Optional[str] = None,
):
source_model, quant_model = models
images = [asset.pil_image for asset in image_assets]
inputs_per_image = [(
[prompt for _ in size_factors],
[rescale_image_size(image, factor) for factor in size_factors],
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
# NOTE: take care of the order. run Aphrodite first, and then run HF.
# Aphrodite needs a fresh new process without cuda initialization.
# if we run HF first, the cuda initialization will be done and it
# will hurt multiprocessing backend with fork method (the default method).
# max_model_len should be greater than image_feature_size
with aphrodite_runner(source_model,
max_model_len=4096,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as aphrodite_model:
source_outputs_per_image = [
aphrodite_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]
with aphrodite_runner(quant_model,
quantization="awq",
max_model_len=4096,
dtype=dtype,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
enforce_eager=True) as aphrodite_model:
quant_outputs_per_image = [
aphrodite_model.generate_greedy_logprobs(prompts,
max_tokens,
num_logprobs=num_logprobs,
images=images)
for prompts, images in inputs_per_image
]
for source_outputs, quant_outputs in zip(source_outputs_per_image,
quant_outputs_per_image):
# TODO: Check whether using original CLIPVisionModel can improve
# consistency against HF
check_logprobs_close(
outputs_0_lst=source_outputs,
outputs_1_lst=quant_outputs,
name_0="source",
name_1="awq",
)


target_dtype = "half"
if is_cpu():
target_dtype = "bfloat16"
Expand Down Expand Up @@ -214,3 +276,36 @@ def test_models(hf_runner, aphrodite_runner, image_assets, model, size_factors,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)


@pytest.mark.parametrize(
"models", [("OpenGVLab/InternVL2-2B", "OpenGVLab/InternVL2-2B-AWQ")])
@pytest.mark.parametrize(
"size_factors",
[
# No image
[],
# Single-scale
[1.0],
# Single-scale, batched
[1.0, 1.0, 1.0],
# Multi-scale
[0.25, 0.5, 1.0],
],
)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@torch.inference_mode()
def test_awq_models(aphrodite_runner, image_assets, models, size_factors,
dtype: str, max_tokens: int, num_logprobs: int) -> None:
run_awq_test(
aphrodite_runner,
image_assets,
models,
size_factors=size_factors,
dtype=dtype,
max_tokens=max_tokens,
num_logprobs=num_logprobs,
tensor_parallel_size=1,
)

0 comments on commit a65820e

Please sign in to comment.