Adding fp8 to gradlib (#44)
* adding fp8 gemm tuner to gradlib

* formatting

* add instructions

* Linting

* adding fp8 gemm tuner to gradlib

formatting

add instructions

* Linting fp8 gradlib

* fix merging issue of ROCm_performance.md

* delete fp8_gemm_tuner.py

* Fix linting for triton: unmeld if with constexpr

* update tutorial

* Fix linting again

* fix typo

---------

Co-authored-by: Matthew Wong <[email protected]>
charlifu and mawong-amd authored Jun 10, 2024
1 parent 95b3acc commit d254de7
Showing 4 changed files with 731 additions and 404 deletions.
21 changes: 21 additions & 0 deletions ROCm_performance.md
@@ -18,3 +18,24 @@ Define the following environment symbol: `PYTORCH_TUNABLEOP_ENABLED=1` in order
On ROCm, a custom paged attention kernel is available for better performance; enable it with the env variable `VLLM_USE_ROCM_CUSTOM_PAGED_ATTN=1`.
Currently, this env variable is enabled by default. To fall back to the PagedAttention v2 kernel, set the env variable to 0.
The custom PagedAttention kernel is enabled for dtype fp16, block-size 16, head-size 128, max context length <= 16k, and a GQA ratio (num_heads // num_kv_heads) between 1 and 16. In all other cases, we fall back to the PagedAttention v2 kernel.
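For example, a minimal sketch of forcing the PagedAttention v2 fallback from Python (an assumption here is that the variable is read when vLLM initializes, so it must be set before that; the model path is a placeholder):

```
import os

# Set before vLLM initializes: 1 (default) uses the custom ROCm kernel,
# 0 falls back to the PagedAttention v2 kernel.
os.environ["VLLM_USE_ROCM_CUSTOM_PAGED_ATTN"] = "0"

from vllm import LLM  # noqa: E402

llm = LLM(model="/path/to/your-model")  # placeholder model path
```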

## Fp8 Quantization

To use fp8 quantization, the first step is to quantize your model to fp8 format. Please follow these [instructions](https://github.com/ROCm/vllm/tree/main/examples/fp8/quantizer) to generate a safetensors file that contains the quantized weights and the corresponding scaling factors of your model. The safetensors file should be placed under your model folder.

Then we can run the model with fp8 quantization using vLLM. When creating the `vllm.LLM` object, two additional parameters should be added: `quantization="fp8"` and `quantization_param_path={relative path of the safetensors file within your model folder}`.
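For example, a minimal sketch (the model path and safetensors filename are placeholders; substitute your own):

```
from vllm import LLM, SamplingParams

llm = LLM(
    model="/path/to/your-model",                                  # placeholder model folder
    quantization="fp8",
    quantization_param_path="your_model_fp8_scales.safetensors",  # relative to the model folder
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```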

## Gemm Tuning for Fp8

To get better performance with fp8 quantization, we need to tune the GEMMs using the shapes encountered during the execution of the model.

To capture all the GEMM shapes used during the execution of the model, set the env variable `TUNE_FP8=1` and then run the model as usual. This produces a file called `/tmp/fp8_shapes.csv`.
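For example, a minimal sketch of a shape-capture run (paths and filenames are placeholders; an assumption here is that `TUNE_FP8` must be visible to the process before the model is built):

```
import os

# Enable GEMM shape capture before constructing the model.
os.environ["TUNE_FP8"] = "1"

from vllm import LLM  # noqa: E402

llm = LLM(
    model="/path/to/your-model",                                  # placeholder
    quantization="fp8",
    quantization_param_path="your_model_fp8_scales.safetensors",  # placeholder
)
llm.generate(["a representative prompt"])
# After this run, the captured shapes are written to /tmp/fp8_shapes.csv.
```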

Next, run gradlib to obtain the best solutions for these shapes:

```
python3 gradlib/gradlib/fp8_gemm_tuner.py --input_file /tmp/fp8_shapes.csv --tuned_file /tmp/tuned_fp8_16.csv
```
where `/tmp/tuned_fp8_16.csv` is the tuned file that will be used by our fp8 GEMM linear layer.

Now, when running inference with fp8, the tuned GEMM solutions are used for best performance.