Commit 11726d3: add instructions
charlifu committed Jun 7, 2024 (1 parent: 0a9e750)
Showing 2 changed files with 28 additions and 7 deletions.
ROCm_performance.md (21 additions, 0 deletions)
@@ -18,3 +18,24 @@ Define the following environment symbol: `PYTORCH_TUNABLEOP_ENABLED=1` in order
On ROCm, a custom paged attention kernel is available for better performance; it is switched on via the env variable `VLLM_USE_ROCM_CUSTOM_PAGED_ATTN=1`.
Currently, this env variable is enabled by default. To fall back to the PagedAttention v2 kernel, set the variable to 0.
The custom PagedAttention kernel is enabled for dtype fp16, block-size 16, head-size 128, and max context length <= 16k, with a GQA ratio (num_heads // num_kv_heads) between 1 and 16. In all other cases, we fall back to the PagedAttention v2 kernel.
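
For example, a minimal sketch of forcing the v2 fallback from Python (the model name is a placeholder, and the variable must be set before the engine initializes):

```
import os

# Illustrative assumption: opt out of the custom ROCm kernel and fall
# back to PagedAttention v2. Set this before vLLM is imported/initialized.
os.environ["VLLM_USE_ROCM_CUSTOM_PAGED_ATTN"] = "0"

from vllm import LLM

llm = LLM(model="facebook/opt-125m")  # placeholder model
print(llm.generate("Hello, ROCm!"))
```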

## Fp8 Quantization

To use fp8 quantization, the first step is to quantize your model to the fp8 format. Please follow these [instructions](https://github.com/ROCm/vllm/tree/main/examples/fp8/quantizer) to generate a safetensors file containing the quantized weights and the corresponding scaling factors of your model. The safetensors file should be placed under your model folder.

Then we can run the model with fp8 quantization using vLLM. When creating the `vllm.LLM` object, two additional parameters should be passed: `quantization="fp8"` and `quantization_param_path={relative path of the safetensors file within your model folder}`.
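
For example, a minimal sketch (the model path and safetensors filename below are placeholders for your own quantized model):

```
from vllm import LLM

# Placeholder paths for illustration; point these at your model folder
# and the quantized safetensors file generated in the previous step.
llm = LLM(
    model="/path/to/your/model",
    quantization="fp8",
    quantization_param_path="fp8_weights.safetensors",
)
outputs = llm.generate("The capital of France is")
```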

## Gemm Tuning for Fp8

To get better performance from fp8 quantization, we need to tune the gemm kernels using the shapes of all the gemms executed by the model.

To obtain all the gemm shapes used during the execution of the model, set the env variable `TUNE_FP8=1` and then run the model as usual. This will produce a file called `/tmp/fp8_shapes.csv`.
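
A minimal sketch of a capture run (paths are placeholders; the env variable must be set before vLLM loads):

```
import os

# Illustrative assumption: enable shape capture so every fp8 gemm
# records its (M, N, K) shape to /tmp/fp8_shapes.csv.
os.environ["TUNE_FP8"] = "1"

from vllm import LLM

llm = LLM(
    model="/path/to/your/model",  # placeholder
    quantization="fp8",
    quantization_param_path="fp8_weights.safetensors",  # placeholder
)
llm.generate("a representative prompt that exercises the model")
```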

Next, run gradlib to obtain the best solutions for these shapes:

```
python3 gradlib/gradlib/fp8_gemm_tunner.py --input_file /tmp/fp8_shapes.csv --tuned_file /tmp/tuned_fp8_16.csv
```
where `/tmp/tuned_fp8_16.csv` will be used by our fp8 gemm linear layer.
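
The tuned file maps each gemm shape to a solution index. A minimal sketch of how it is consumed (mirroring the loading logic in `fp8_rocm.py` below, which is where the column names come from):

```
import pandas as pd

# Each row of the tuned CSV pairs a gemm shape (M, N, K) with the
# winning solution index stored in the "solidx" column.
df = pd.read_csv("/tmp/tuned_fp8_16.csv")
tuned = {(row["M"], row["N"], row["K"]): row["solidx"]
         for _, row in df.iterrows()}
```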

Now, when running inference with fp8, the tuned gemms are used for best performance.
vllm/model_executor/layers/quantization/fp8_rocm.py (7 additions, 7 deletions)
@@ -33,10 +33,10 @@ def __init__(self) -> None:
#print(f"Integral Cross factor = {self.factor}")
if gemm_type == "fp8_8":
self.gemm_method = Fp8RocmLinearMethod.apply_fp8_8
tuned_filename = "/projects/tuned_fp8_8.csv"
tuned_filename = "/tmp/tuned_fp8_8.csv"
elif gemm_type == "fp8_16":
self.gemm_method = Fp8RocmLinearMethod.apply_fp8_16
tuned_filename = "/projects/tuned_fp8_16.csv"
tuned_filename = "/tmp/tuned_fp8_16.csv"
else:
raise Exception(f"Unknown fp8 gemm type: {gemm_type}")
try:
@@ -49,7 +49,7 @@ def __init__(self) -> None:
m = shape["M"]
n = shape["N"]
k = shape["K"]
algo = shape["algo"]
algo = shape["solidx"]
self._tuned[(m, n, k)] = algo

@classmethod
@@ -225,13 +225,13 @@ def apply_fp8_16(

if os.getenv("TUNE_FP8") == "1":
try:
df = pd.read_csv("/projects/fp8_shapes.csv")
df = pd.read_csv("/tmp/fp8_shapes.csv")
except:
df = pd.DataFrame(columns=["M", "N", "K"])
df = pd.concat(
[df, pd.DataFrame({"M": [m], "N": [n], "K": [k]})]
).drop_duplicates()
df.to_csv("/projects/fp8_shapes.csv", index=False)
df.to_csv("/tmp/fp8_shapes.csv", index=False)
algo = 0
res = vllm_ops.fp8_gemm_16(x8, weight.t(), asf, wsf, int(algo))
return res
@@ -258,13 +258,13 @@ def apply_fp8_8(

if os.getenv("TUNE_FP8") == "1":
try:
df = pd.read_csv("/projects/fp8_shapes.csv")
df = pd.read_csv("/tmp/fp8_shapes.csv")
except:
df = pd.DataFrame(columns=["M", "N", "K"])
df = pd.concat(
[df, pd.DataFrame({"M": [m], "N": [n], "K": [k]})]
).drop_duplicates()
df.to_csv("/projects/fp8_shapese.csv", index=False)
df.to_csv("/tmp/fp8_shapese.csv", index=False)
algo = 0

res = vllm_ops.fp8_gemm(x8, weight.t(), asf, wsf, osf, int(algo))
