diff --git a/ROCm_performance.md b/ROCm_performance.md
index bea77d1a27fc4..1c47a818ec852 100644
--- a/ROCm_performance.md
+++ b/ROCm_performance.md
@@ -21,9 +21,23 @@ The custom PagedAttention kernel is enabled for dtype: fp16, block-size=16, head
 
 ## Fp8 Quantization
 
-To use fp8 quantization, first step is to quantize your model to fp8 format. Please follow this [instruction](https://github.com/ROCm/vllm/tree/main/examples/fp8/quantizer) to generating a safetensor file that contains the quantized weights and the corresponding scaling factors of your model. The safetensor file should be placed under your model folder.
+To use fp8 quantization, the first step is to quantize your model to fp8 format.
 
-Then we can run a model with fp8 quantization using vllm. When creating `vllm.LLM` object, two additional parameters should be added: `quantization="fp8"` and `quantization_param_path={relative path of the safetensors with your model path}`.
+By default, rocm-vllm accepts the quantized weights generated by the Quark quantizer. To produce them, install Quark and run the following command:
+
+```
+python3 quantize_quark.py --model_dir [llama2 checkpoint folder] \
+                          --output_dir output_dir \
+                          --quant_scheme w_fp8_a_fp8_o_fp8 \
+                          --num_calib_data 128 \
+                          --export_safetensors \
+                          --no_weight_matrix_merge
+```
+For more details, please refer to Quark's documentation.
+
+To use AMMO instead, please follow these [instructions](https://github.com/ROCm/vllm/tree/main/examples/fp8/quantizer) and set `VLLM_FP8_USE_AMMO=1`.
+
+Both quantizers generate a safetensors file that contains the quantized weights and the corresponding scaling factors of your model. Place this file under your model folder. You can then run the model with fp8 quantization using vllm: when creating the `vllm.LLM` object, add two additional parameters, `quantization="fp8"` and `quantization_param_path={relative path of the safetensors file under your model folder}`.
 
 ## Gemm Tuning for Fp8
 
diff --git a/csrc/quantization/fp8/amd/gemm_kernel.cu b/csrc/quantization/fp8/amd/gemm_kernel.cu
index f8586b77d7792..c199591e0e0f4 100644
--- a/csrc/quantization/fp8/amd/gemm_kernel.cu
+++ b/csrc/quantization/fp8/amd/gemm_kernel.cu
@@ -1,7 +1,7 @@
 #include 
 #include 
-#include 
+#include 
 #include 
 #include 
 #include 
 
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index c7d63df353ca5..f935f6060e65a 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -21,6 +21,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only LLaMA model compatible with HuggingFace weights."""
+import os
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
@@ -441,57 +442,117 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
 
     def load_quantized_weights(self, weights: Iterable[Tuple[str,
                                                               torch.Tensor]]):
-        params_dict = dict(self.named_parameters())
-        #with open("/projects/a.txt", "r") as f:
-        #    j = json.load(f)
-        #    for k, v in j.items():
-        #        params_dict[k].data.copy_(v)
-        quant_shards = [
-            ("mlp.gate_up_proj", "mlp.fc", 0),  # fc is gate_proj
-            ("mlp.gate_up_proj", "mlp.gate", 1),  # gate is up_proj
-        ]
-        quant_map = [
-            ("mlp.down_proj", "mlp.proj"),
-            ("self_attn.o_proj", "attention.dense"),
-            ("self_attn.qkv_proj", "attention.qkv"),
-        ]
-        for name, loaded_weight in weights:
-            #print(name)
-            name = name.replace('transformer', 'model')
-            name = name.replace('kv_cache_scaling_factor',
-                                'qkv.output_scaling_factor')
-            loaded_weight = loaded_weight.to("cuda")
-            if loaded_weight.dtype == torch.int8:
-                loaded_weight[loaded_weight == -128] = 0
-                assert loaded_weight.is_contiguous
-                loaded_weight = loaded_weight.view(torch.float8_e4m3fnuz)
-            for (param_name, weight_name, shard_id) in quant_shards:
-                if weight_name not in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+
+        def load_ammo():
+            params_dict = dict(self.named_parameters())
+            quant_shards = [
+                ("mlp.gate_up_proj", "mlp.fc", 0),  # fc is gate_proj
+                ("mlp.gate_up_proj", "mlp.gate", 1),  # gate is up_proj
+            ]
+            quant_map = [
+                ("mlp.down_proj", "mlp.proj"),
+                ("self_attn.o_proj", "attention.dense"),
+                ("self_attn.qkv_proj", "attention.qkv"),
+            ]
+            for name, loaded_weight in weights:
+                name = name.replace('transformer', 'model')
+                name = name.replace('kv_cache_scaling_factor',
+                                    'qkv.output_scaling_factor')
+                loaded_weight = loaded_weight.to("cuda")
+                if loaded_weight.dtype == torch.int8:
+                    loaded_weight[loaded_weight == -128] = 0
+                    assert loaded_weight.is_contiguous
+                    loaded_weight = loaded_weight.view(torch.float8_e4m3fnuz)
+                for (param_name, weight_name, shard_id) in quant_shards:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id)
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    for (param_name, weight_name) in quant_map:
+                        if weight_name not in name:
+                            continue
+                        name = name.replace(weight_name, param_name)
+                        param = params_dict[name]
+                        if ("activation_scaling_factor" in name
+                                or "weights_scaling_factor" in name
+                                or "output_scaling_factor" in name):
+                            param.data.copy_(loaded_weight)
+                        else:
+                            weight_loader = getattr(param, "weight_loader",
+                                                    default_weight_loader)
+                            weight_loader(param, loaded_weight)
+                        break
+
+        def load_quark():
+            params_dict = dict(self.named_parameters())
+            quant_shards = [
+                ("mlp.gate_up_proj", "mlp.gate_proj", 0),  # fc is gate_proj
+                ("mlp.gate_up_proj", "mlp.up_proj", 1),  # gate is up_proj
+            ]
+            quant_map = [
+                ("mlp.down_proj", "mlp.down_proj"),
+                ("self_attn.o_proj", "self_attn.o_proj"),
+                ("self_attn.qkv_proj", "self_attn.qkv"),
+            ]
+            scaling_factor_map = [
+                ("activation_scaling_factor", "input_quant_scale"),
+                ("weights_scaling_factor", "weight_quant_scale"),
+                ("output_scaling_factor", "output_quant_scale"),
+            ]
+            for name, loaded_weight in weights:
+                if "zero_point" in name:
                     continue
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
-            else:
-                # Skip loading extra bias for GPTQ models.
-                for (param_name, weight_name) in quant_map:
+                if len(loaded_weight.shape) == 0:
+                    loaded_weight = torch.Tensor([loaded_weight])
+                # replace the name for scaling factor
+                for (scale_name, weight_name) in scaling_factor_map:
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, scale_name)
+                if loaded_weight.dtype == torch.int8:
+                    loaded_weight[loaded_weight == -128] = 0
+                    assert loaded_weight.is_contiguous
+                    loaded_weight = loaded_weight.view(torch.float8_e4m3fnuz)
+
+                for (param_name, weight_name, shard_id) in quant_shards:
                     if weight_name not in name:
                         continue
                     name = name.replace(weight_name, param_name)
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
                     param = params_dict[name]
-                    if ("activation_scaling_factor" in name
-                            or "weights_scaling_factor" in name
-                            or "output_scaling_factor" in name):
-                        param.data.copy_(loaded_weight)
-                    else:
-                        weight_loader = getattr(param, "weight_loader",
-                                                default_weight_loader)
-                        weight_loader(param, loaded_weight)
+                    weight_loader = param.weight_loader
+                    weight_loader(param, loaded_weight, shard_id)
                     break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    for (param_name, weight_name) in quant_map:
+                        if weight_name not in name:
+                            continue
+                        name = name.replace(weight_name, param_name)
+                        param = params_dict[name]
+                        if ("activation_scaling_factor" in name
+                                or "weights_scaling_factor" in name
+                                or "output_scaling_factor" in name):
+                            param.data.copy_(loaded_weight)
+                        else:
+                            weight_loader = getattr(param, "weight_loader",
+                                                    default_weight_loader)
+                            weight_loader(param, loaded_weight)
+                        break
+
+        load_func = load_ammo if os.getenv(
+            "VLLM_FP8_USE_AMMO") == "1" else load_quark
+        load_func()
 
     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
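The ROCm_performance.md hunk above describes the serving side only in prose, so here is a minimal usage sketch (not part of the patch) of the flow it documents. It assumes a checkpoint already quantized with Quark, or with AMMO plus `VLLM_FP8_USE_AMMO=1`; the model directory and the safetensors filename below are hypothetical placeholders.

```python
# Usage sketch only -- not part of the diff above. The model directory and the
# safetensors filename are hypothetical placeholders.
from vllm import LLM, SamplingParams

# For an AMMO-quantized checkpoint, export VLLM_FP8_USE_AMMO=1 before starting;
# Quark-quantized checkpoints are the default path, per the diff above.
llm = LLM(
    model="/models/llama-2-7b",  # folder that also holds the quantized safetensors
    quantization="fp8",
    # Path of the quantized safetensors, relative to the model folder above.
    quantization_param_path="llama2-7b-fp8.safetensors",
)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```

The generation call at the end is only a smoke test; once the `LLM` object is constructed with these two parameters, any normal vLLM generation path should pick up the fp8 weights and their scaling factors.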