Supporting quantized weights from Quark by default. (#47)
* support quark

* using torch/all.h

* loading weight from quark output

* support both ammo and quark

* Update doc

* fix load ammo

* fix linter

* fix isort
charlifu authored Jun 13, 2024
1 parent ff24102 commit dc60612
Showing 3 changed files with 123 additions and 48 deletions.
18 changes: 16 additions & 2 deletions ROCm_performance.md
@@ -21,9 +21,23 @@ The custom PagedAttention kernel is enabled for dtype: fp16, block-size=16, head

## Fp8 Quantization

To use fp8 quantization, the first step is to quantize your model to fp8 format. Please follow these [instructions](https://github.com/ROCm/vllm/tree/main/examples/fp8/quantizer) to generate a safetensors file that contains the quantized weights and the corresponding scaling factors of your model. The safetensors file should be placed under your model folder.
To use fp8 quantization, the first step is to quantize your model to fp8 format.

Then we can run a model with fp8 quantization using vllm. When creating the `vllm.LLM` object, two additional parameters should be added: `quantization="fp8"` and `quantization_param_path={relative path of the safetensors with your model path}`.
By default, rocm-vllm accepts the quantized weights generated by the Quark quantizer. To do this, install Quark and run the following command:

```
python3 quantize_quark.py --model_dir [llama2 checkpoint folder] \
--output_dir output_dir \
--quant_scheme w_fp8_a_fp8_o_fp8 \
--num_calib_data 128 \
--export_safetensors \
--no_weight_matrix_merge
```
For more details, please refer to Quark's documentation.

To use AMMO instead, please follow these [instructions](https://github.com/ROCm/vllm/tree/main/examples/fp8/quantizer) and set `VLLM_FP8_USE_AMMO=1`.

Both quantizers generate a safetensors file that contains the quantized weights and the corresponding scaling factors of your model. Place the safetensors file under your model folder. Then we can run the model with fp8 quantization using vLLM. When creating the `vllm.LLM` object, two additional parameters should be added: `quantization="fp8"` and `quantization_param_path={relative path of the safetensors file with respect to your model path}`.
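
As a concrete illustration, here is a minimal sketch of the setup described above. The model path, the safetensors file name, and the prompt are placeholders for your own checkpoint and quantizer output, not values prescribed by this repository:

```
from vllm import LLM, SamplingParams

# Placeholder paths: point `model` at your model folder and
# `quantization_param_path` at the safetensors file the quantizer
# produced inside that folder.
llm = LLM(
    model="/path/to/llama-2-7b",
    quantization="fp8",
    quantization_param_path="llama.safetensors",  # relative to the model folder
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```

If the weights were produced by AMMO rather than Quark, export `VLLM_FP8_USE_AMMO=1` before constructing the `LLM` object, as described above.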

## Gemm Tuning for Fp8

2 changes: 1 addition & 1 deletion csrc/quantization/fp8/amd/gemm_kernel.cu
@@ -1,7 +1,7 @@
#include <cstdint>
#include <cstdio>

#include <torch/extension.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContextLight.h>
151 changes: 106 additions & 45 deletions vllm/model_executor/models/llama.py
@@ -21,6 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only LLaMA model compatible with HuggingFace weights."""
import os
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
@@ -441,57 +442,117 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):

def load_quantized_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]):
params_dict = dict(self.named_parameters())
#with open("/projects/a.txt", "r") as f:
# j = json.load(f)
# for k, v in j.items():
# params_dict[k].data.copy_(v)
quant_shards = [
("mlp.gate_up_proj", "mlp.fc", 0), # fc is gate_proj
("mlp.gate_up_proj", "mlp.gate", 1), # gate is up_proj
]
quant_map = [
("mlp.down_proj", "mlp.proj"),
("self_attn.o_proj", "attention.dense"),
("self_attn.qkv_proj", "attention.qkv"),
]
for name, loaded_weight in weights:
#print(name)
name = name.replace('transformer', 'model')
name = name.replace('kv_cache_scaling_factor',
'qkv.output_scaling_factor')
loaded_weight = loaded_weight.to("cuda")
if loaded_weight.dtype == torch.int8:
loaded_weight[loaded_weight == -128] = 0
assert loaded_weight.is_contiguous
loaded_weight = loaded_weight.view(torch.float8_e4m3fnuz)
for (param_name, weight_name, shard_id) in quant_shards:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:

def load_ammo():
params_dict = dict(self.named_parameters())
quant_shards = [
("mlp.gate_up_proj", "mlp.fc", 0), # fc is gate_proj
("mlp.gate_up_proj", "mlp.gate", 1), # gate is up_proj
]
quant_map = [
("mlp.down_proj", "mlp.proj"),
("self_attn.o_proj", "attention.dense"),
("self_attn.qkv_proj", "attention.qkv"),
]
for name, loaded_weight in weights:
name = name.replace('transformer', 'model')
name = name.replace('kv_cache_scaling_factor',
'qkv.output_scaling_factor')
loaded_weight = loaded_weight.to("cuda")
if loaded_weight.dtype == torch.int8:
loaded_weight[loaded_weight == -128] = 0
assert loaded_weight.is_contiguous
loaded_weight = loaded_weight.view(torch.float8_e4m3fnuz)
for (param_name, weight_name, shard_id) in quant_shards:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
for (param_name, weight_name) in quant_map:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
if ("activation_scaling_factor" in name
or "weights_scaling_factor" in name
or "output_scaling_factor" in name):
param.data.copy_(loaded_weight)
else:
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
break

def load_quark():
params_dict = dict(self.named_parameters())
quant_shards = [
("mlp.gate_up_proj", "mlp.gate_proj", 0), # fc is gate_proj
("mlp.gate_up_proj", "mlp.up_proj", 1), # gate is up_proj
]
quant_map = [
("mlp.down_proj", "mlp.down_proj"),
("self_attn.o_proj", "self_attn.o_proj"),
("self_attn.qkv_proj", "self_attn.qkv"),
]
scaling_factor_map = [
("activation_scaling_factor", "input_quant_scale"),
("weights_scaling_factor", "weight_quant_scale"),
("output_scaling_factor", "output_quant_scale"),
]
for name, loaded_weight in weights:
if "zero_point" in name:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
for (param_name, weight_name) in quant_map:
if len(loaded_weight.shape) == 0:
loaded_weight = torch.Tensor([loaded_weight])
# replace the name for scaling factor
for (scale_name, weight_name) in scaling_factor_map:
if weight_name not in name:
continue
name = name.replace(weight_name, scale_name)
if loaded_weight.dtype == torch.int8:
loaded_weight[loaded_weight == -128] = 0
assert loaded_weight.is_contiguous
loaded_weight = loaded_weight.view(torch.float8_e4m3fnuz)

for (param_name, weight_name, shard_id) in quant_shards:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
if ("activation_scaling_factor" in name
or "weights_scaling_factor" in name
or "output_scaling_factor" in name):
param.data.copy_(loaded_weight)
else:
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
for (param_name, weight_name) in quant_map:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
if ("activation_scaling_factor" in name
or "weights_scaling_factor" in name
or "output_scaling_factor" in name):
param.data.copy_(loaded_weight)
else:
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
break

load_func = load_ammo if os.getenv(
"VLLM_FP8_USE_AMMO") == "1" else load_quark
load_func()

# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
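In the new code, `load_quantized_weights` dispatches between two nested loaders: `load_ammo`, kept behind the `VLLM_FP8_USE_AMMO=1` environment variable, and `load_quark`, the new default. The Quark path mainly remaps the quantizer's scale names onto the parameter names vLLM expects. Below is a small standalone sketch of that renaming step, assuming the mappings shown in the diff; `remap_quark_name` is an illustrative helper, not a function from the codebase:

```
import os

# Quark scale names (right) -> vLLM parameter names (left), as listed
# in scaling_factor_map in the diff.
SCALING_FACTOR_MAP = [
    ("activation_scaling_factor", "input_quant_scale"),
    ("weights_scaling_factor", "weight_quant_scale"),
    ("output_scaling_factor", "output_quant_scale"),
]


def remap_quark_name(name: str) -> str:
    """Rename a Quark scale tensor to the name vLLM's loader expects."""
    for vllm_name, quark_name in SCALING_FACTOR_MAP:
        if quark_name in name:
            return name.replace(quark_name, vllm_name)
    return name


# Loader selection mirrors the diff: AMMO only when the env var is set.
use_ammo = os.getenv("VLLM_FP8_USE_AMMO") == "1"
print("loader:", "load_ammo" if use_ammo else "load_quark")
print(remap_quark_name("model.layers.0.mlp.down_proj.input_quant_scale"))
# -> model.layers.0.mlp.down_proj.activation_scaling_factor
```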
