Update doc
charlifu committed Jun 13, 2024
1 parent 3fcc82d commit e4e3592
Showing 1 changed file with 16 additions and 2 deletions: ROCm_performance.md
@@ -21,9 +21,23 @@ The custom PagedAttention kernel is enabled for dtype: fp16, block-size=16, head

## Fp8 Quantization

To use fp8 quantization, the first step is to quantize your model to the fp8 format.

By default, rocm-vllm accepts quantized weights generated by the Quark quantizer. To quantize your model with Quark, install quark and run:

```
python3 quantize_quark.py --model_dir [llama2 checkpoint folder] \
    --output_dir output_dir \
    --quant_scheme w_fp8_a_fp8_o_fp8 \
    --num_calib_data 128 \
    --export_safetensors \
    --no_weight_matrix_merge
```
For more details, please refer to Quark's documentation.

To use ammo instead, please follow these [instructions](https://github.com/ROCm/vllm/tree/main/examples/fp8/quantizer) and set `VLLM_FP8_USE_AMMO=1`.
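
For instance, a minimal sketch of enabling the ammo path from Python, assuming the environment variable must be set before vLLM initializes the model (exporting it in your shell works equally well):

```
import os

# Assumption: VLLM_FP8_USE_AMMO is read when vLLM loads the model,
# so set it before constructing the vllm.LLM object.
os.environ["VLLM_FP8_USE_AMMO"] = "1"
```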

Both quantizers generate a safetensors file that contains the quantized weights and the corresponding scaling factors of your model; place this file under your model folder. You can then run the model with fp8 quantization using vllm: when creating the `vllm.LLM` object, add two parameters, `quantization="fp8"` and `quantization_param_path={relative path of the safetensors file within your model folder}`.
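
For example, a minimal sketch of running an fp8-quantized model this way; the model folder path and the safetensors filename are placeholders, not names from this document:

```
from vllm import LLM, SamplingParams

# Placeholder paths: the safetensors file produced by the quantizer is assumed
# to have been copied into the model folder.
llm = LLM(
    model="/path/to/llama2-7b",                           # model folder
    quantization="fp8",                                   # enable fp8 quantization
    quantization_param_path="quant_params.safetensors",   # relative to the model folder
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```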

## Gemm Tuning for Fp8
