adding fp8 gemm tuner to gradlib
formatting

add instructions
charlifu authored and mawong-amd committed Jun 7, 2024
1 parent 12d3b25 commit e96b25e
Showing 4 changed files with 366 additions and 54 deletions.
21 changes: 21 additions & 0 deletions ROCm_performance.md
@@ -18,3 +18,24 @@ Define the following environment symbol: `PYTORCH_TUNABLEOP_ENABLED=1` in order
On ROCm, for better performance, a custom paged attention kernel is available; switch it on with the env variable `VLLM_USE_ROCM_CUSTOM_PAGED_ATTN=1`.
Currently, this env variable is enabled by default. To fall back to the PagedAttention v2 kernel, set the env variable to 0.
The custom PagedAttention kernel is enabled for dtype fp16, block-size 16, head-size 128, and max context length <= 16k, with a GQA ratio (num_heads // num_kv_heads) between 1 and 16. In all other cases, we fall back to the PagedAttention v2 kernel.

## FP8 Quantization

To use fp8 quantization, the first step is to quantize your model to the fp8 format. Please follow these [instructions](https://github.com/ROCm/vllm/tree/main/examples/fp8/quantizer) to generate a safetensors file containing the quantized weights and the corresponding scaling factors of your model. The safetensors file should be placed in your model folder.

Then we can run the model with fp8 quantization using vLLM. When creating the `vllm.LLM` object, two additional parameters should be passed: `quantization="fp8"` and `quantization_param_path={relative path of the safetensors file within your model folder}`, as in the sketch below.
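
A minimal sketch of creating an fp8-quantized model with vLLM; the model directory and safetensors filename below are placeholders for your own:

```python
# Minimal sketch, assuming a ROCm build of vLLM with fp8 support.
from vllm import LLM

llm = LLM(
    model="/path/to/your/model",                        # placeholder: your model folder
    quantization="fp8",
    quantization_param_path="fp8_scales.safetensors",   # placeholder: quantized safetensors file in that folder
)
outputs = llm.generate("Hello, my name is")
print(outputs[0].outputs[0].text)
```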

## GEMM Tuning for FP8

To get better performance out of fp8 quantization, we need to tune the GEMM kernels using the shapes of all the GEMMs executed by the model.

To collect all the GEMM shapes used during model execution, set the env variable `TUNE_FP8=1` and run the model as usual (see the sketch below). This produces a file called `/tmp/fp8_shapes.csv`.
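
A hedged sketch of the shape-collection step, driven from Python (paths and prompts are placeholders); the key point is that `TUNE_FP8` must be set before the model is constructed:

```python
# Sketch of collecting fp8 GEMM shapes; set TUNE_FP8 before vLLM builds the model.
import os
os.environ["TUNE_FP8"] = "1"

from vllm import LLM

llm = LLM(
    model="/path/to/your/model",                        # placeholder
    quantization="fp8",
    quantization_param_path="fp8_scales.safetensors",   # placeholder
)
# Run a few representative prompts so every GEMM shape is exercised;
# the shapes are recorded to /tmp/fp8_shapes.csv.
llm.generate(["Hello, my name is", "The capital of France is"])
```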

Next, run gradlib to find the best solutions for these shapes:

```
python3 gradlib/gradlib/fp8_gemm_tunner.py --input_file /tmp/fp8_shapes.csv --tuned_file /tmp/tuned_fp8_16.csv
```
where `/tmp/tuned_fp8_16.csv` will be used by our fp8 GEMM linear layer.

Now, when running inference with fp8, we use the tuned GEMM solutions for best performance.
165 changes: 117 additions & 48 deletions gradlib/csrc/hipbsolgemm.cu
@@ -24,6 +24,7 @@
#include <hipblaslt/hipblaslt-ext.hpp>

#include <iostream>
#include <algorithm>
#include <limits>
#include <map>
#include <string>
@@ -115,6 +116,8 @@ namespace {

bool cout_print = false;

torch::Tensor dTensor;

//std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
}

@@ -132,23 +135,24 @@ std::vector<int> hipblasLtMatmul_findallsols_wrapper(
const void *beta,
void *c,
int ldc,
hipDataType dtype,
hipDataType intype,
hipDataType outtype,
hipStream_t &stream)
{
int flag { 0 };
hipblasLtMatrixLayout_t matA, matB, matC;
hipblasLtMatmulDesc_t matmul;
if (op_A == HIPBLAS_OP_N) {
CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, m, k, lda));
CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, m, k, lda));
} else {
CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, k, m, lda));
CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, k, m, lda));
}
if (op_B == HIPBLAS_OP_N) {
CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, k, n, ldb));
CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, k, n, ldb));
} else {
CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, n, k, ldb));
CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, n, k, ldb));
}
CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, dtype, m, n, ldc));
CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, outtype, m, n, ldc));
CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F));
CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t)));
@@ -163,10 +167,10 @@ std::vector<int> hipblasLtMatmul_findallsols_wrapper(
CHECK_HIPBLAS_ERROR(hipblaslt_ext::getAllAlgos(handle, hipblaslt_ext::GemmType::HIPBLASLT_GEMM,
op_A,
op_B,
dtype,
dtype,
dtype,
dtype,
intype,
intype,
outtype,
outtype,
HIPBLAS_COMPUTE_32F,
heuristicResult));

@@ -211,12 +215,16 @@ hipblasStatus_t hipblasLtMatmul_sol_wrapper(
const void *alpha,
const void *a,
int lda,
const void *scaleA,
const void *b,
int ldb,
const void *scaleB,
const void *beta,
void *c,
int ldc,
hipDataType dtype,
const void *scaleC,
hipDataType intype,
hipDataType outtype,
hipStream_t &stream,
int solution_index=-1)
{
@@ -232,21 +240,27 @@ hipblasStatus_t hipblasLtMatmul_sol_wrapper(
hipblasLtMatrixLayout_t matA, matB, matC;
hipblasLtMatmulDesc_t matmul;
if (op_A == HIPBLAS_OP_N) {
CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, m, k, lda));
CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, m, k, lda));
} else {
CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, k, m, lda));
CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, k, m, lda));
}
if (op_B == HIPBLAS_OP_N) {
CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, k, n, ldb));
CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, k, n, ldb));
} else {
CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, n, k, ldb));
CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, n, k, ldb));
}
CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, dtype, m, n, ldc));
CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, outtype, m, n, ldc));
CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F));
CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t)));
CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(int32_t)));
CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scaleA, sizeof(scaleA)));
CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scaleB, sizeof(scaleB)));
CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute(
matmul, HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, &scaleC, sizeof(scaleC)));
//nvtxRangePop();
// if heuristic does not exist in the map, do search and push into the map
//auto gemm_key { MatMulConfig { op_A, op_B, m, n, k, dtype } };
@@ -257,7 +271,7 @@ hipblasStatus_t hipblasLtMatmul_sol_wrapper(
std::cout << "Warning! HipbSolId Gemm Fallback Path used for solution index <0" << std::endl;
if (cout_print) {
std::cout << (op_A == HIPBLAS_OP_N ? "N" : "T") << (op_B == HIPBLAS_OP_N ? "N" : "T")
<< " (" << m << ", " << n << ", " << k << "), dtype: " << dtype
<< " (" << m << ", " << n << ", " << k << "), dtype: " << intype
<< ", (lda, ldb, ldc): (" << lda << ", " << ldb << ", " << ldc << "), " << std::endl;
}
//std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult(request_solutions);
@@ -385,7 +399,11 @@ hipblasStatus_t hipblasLtMatmul_sol_wrapper(
torch::Tensor HipbSolIdxBlas(
const torch::Tensor& mat1,
const torch::Tensor& mat2,
const int solution_index
const int solution_index,
at::optional<py::object> Type = at::nullopt,
at::optional<torch::Tensor> scale1 = at::nullopt,
at::optional<torch::Tensor> scale2 = at::nullopt,
at::optional<torch::Tensor> scaleOut = at::nullopt
)
{
auto mat1_strides { mat1.strides() };
@@ -402,8 +420,10 @@ torch::Tensor HipbSolIdxBlas(
);
TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0");

auto abcType { mat1.options().dtype() };
auto options { at::TensorOptions().dtype(abcType).device(at::kCUDA) };
auto inType { mat1.options().dtype() };
auto outType = inType.toScalarType();
if (Type.has_value()) outType = torch::python::detail::py_object_to_dtype(Type.value());
auto options { at::TensorOptions().dtype(outType).device(at::kCUDA) };
auto result { torch::empty({ mat1_sizes[0], mat2_sizes[1] }, options) };
// std::cout << " | result info: size: " << result.sizes() << " stride: " << result.strides() << std::endl;

@@ -448,35 +468,61 @@ torch::Tensor HipbSolIdxBlas(
int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0];
int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0];
int64_t result_ld = result.stride(transpose_result ? 0 : 1);
// std::cout << " | (m, n, k): " << m << ", " << n << ", " << k << std::endl
// << " | (lda, ldb, ldc): " << mat1_ld << ", " << mat2_ld << ", " << result_ld << std::endl;

void * d_scale1 = nullptr, * d_scale2 = nullptr, * d_scaleOut = nullptr;
if (scale1.has_value()) {
d_scale1 = static_cast<void*>(scale1.value().data_ptr());
}
if (scale2.has_value()) {
d_scale2 = static_cast<void*>(scale2.value().data_ptr());
}
if (scaleOut.has_value()) {
d_scaleOut = static_cast<void*>(scaleOut.value().data_ptr());
}


hipDataType hipblasType;
if (abcType == at::kHalf) {
hipblasType = HIP_R_16F;
} else if (abcType == at::kBFloat16) {
hipblasType = HIP_R_16BF;
} else if (abcType == at::kFloat) {
hipblasType = HIP_R_32F;
hipDataType hipblasInType, hipblasOutType;
if (inType == at::kHalf) {
hipblasInType = HIP_R_16F;
} else if (inType == at::kBFloat16) {
hipblasInType = HIP_R_16BF;
} else if (inType == at::kFloat) {
hipblasInType = HIP_R_32F;
} else if (inType == at::kFloat8_e4m3fnuz) {
hipblasInType = HIP_R_8F_E4M3_FNUZ;
} else {
assert(false && "Wrong datatype!");
}

if (outType == at::kHalf) {
hipblasOutType = HIP_R_16F;
} else if (outType == at::kBFloat16) {
hipblasOutType = HIP_R_16BF;
} else if (outType == at::kFloat) {
hipblasOutType = HIP_R_32F;
} else if (outType == at::kFloat8_e4m3fnuz) {
hipblasOutType = HIP_R_8F_E4M3_FNUZ;
} else {
assert(false && "Wrong datatype!");
}
void *ptrA { static_cast<void *>((transpose_result ? mat2 : mat1).data_ptr()) };
void *ptrB { static_cast<void *>((transpose_result ? mat1 : mat2).data_ptr()) };
void *ptrC { static_cast<void *>(result.data_ptr()) };
if (transpose_result) std::swap(d_scale1, d_scale2);
auto current_stream { torch::hip::getCurrentHIPStream().stream() };

CHECK_HIPBLAS_ERROR(hipblasLtMatmul_sol_wrapper(
hipblaslt_handle,
transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N,
transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N,
m, n, k,
&one,
ptrA, mat1_ld,
ptrB, mat2_ld,
ptrA, mat1_ld, d_scale1,
ptrB, mat2_ld, d_scale2,
&zero,
ptrC, result_ld,
hipblasType,
ptrC, result_ld, d_scaleOut,
hipblasInType,
hipblasOutType,
current_stream,solution_index));

return result;
@@ -485,7 +531,8 @@ torch::Tensor HipbSolIdxBlas(
//find all hipblas solutions and return them to python land
std::vector<int> HipbFindAllSolIdxBlas(
const torch::Tensor& mat1,
const torch::Tensor& mat2
const torch::Tensor& mat2,
at::optional<py::object> Type = at::nullopt
)
{
auto mat1_strides { mat1.strides() };
@@ -499,8 +546,10 @@ std::vector<int> HipbFindAllSolIdxBlas(
);
TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0");

auto abcType { mat1.options().dtype() };
auto options { at::TensorOptions().dtype(abcType).device(at::kCUDA) };
auto inType { mat1.options().dtype() };
auto outType = inType.toScalarType();
if (Type.has_value()) outType = torch::python::detail::py_object_to_dtype(Type.value());
auto options { at::TensorOptions().dtype(outType).device(at::kCUDA) };
auto result { torch::empty({ mat1_sizes[0], mat2_sizes[1] }, options) };
bool transpose_result = true;
bool transpose_mat1;
@@ -536,13 +585,26 @@ std::vector<int> HipbFindAllSolIdxBlas(
int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0];
int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0];
int64_t result_ld = result.stride(transpose_result ? 0 : 1);
hipDataType hipblasType;
if (abcType == at::kHalf) {
hipblasType = HIP_R_16F;
} else if (abcType == at::kBFloat16) {
hipblasType = HIP_R_16BF;
} else if (abcType == at::kFloat) {
hipblasType = HIP_R_32F;
hipDataType hipblasInType, hipblasOutType;
if (inType == at::kHalf) {
hipblasInType = HIP_R_16F;
} else if (inType == at::kBFloat16) {
hipblasInType = HIP_R_16BF;
} else if (inType == at::kFloat) {
hipblasInType = HIP_R_32F;
} else if (inType == at::kFloat8_e4m3fnuz) {
hipblasInType = HIP_R_8F_E4M3_FNUZ;
} else {
assert(false && "Wrong datatype!");
}
if (outType == at::kHalf) {
hipblasOutType = HIP_R_16F;
} else if (outType == at::kBFloat16) {
hipblasOutType = HIP_R_16BF;
} else if (outType == at::kFloat) {
hipblasOutType = HIP_R_32F;
} else if (outType == at::kFloat8_e4m3fnuz) {
hipblasOutType = HIP_R_8F_E4M3_FNUZ;
} else {
assert(false && "Wrong datatype!");
}
@@ -561,7 +623,8 @@ std::vector<int> HipbFindAllSolIdxBlas(
ptrB, mat2_ld,
&zero,
ptrC, result_ld,
hipblasType,
hipblasInType,
hipblasOutType,
current_stream);

}
@@ -605,6 +668,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("hipb_create_extension", &hipb_create_extension, "create_extension");
m.def("hipb_destroy_extension", &hipb_destroy_extension, "destroy_extension");
m.def("hipb_mm", &HipbSolIdxBlas, "mm");
m.def("hipb_findallsols", &HipbFindAllSolIdxBlas, "hipblas_find_all_sols");
}
m.def("hipb_mm", &HipbSolIdxBlas, "mm",
py::arg("mat1"), py::arg("mat2"),
py::arg("solution_index"),
py::arg("outType")=at::nullopt,
py::arg("scale1")=at::nullopt, py::arg("scale2")=at::nullopt, py::arg("scaleOut")=at::nullopt);
m.def("hipb_findallsols", &HipbFindAllSolIdxBlas, "hipblas_find_all_sols",
py::arg("mat1"), py::arg("mat2"),
py::arg("outType")=at::nullopt);
}
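
As a hedged illustration of the Python-facing API introduced by this binding: the extension module name (`hipbsolidxgemm`), the fp8 dtype availability (`torch.float8_e4m3fnuz`), and the weight layout are assumptions, not verified here; the argument order follows the `py::arg` list above.

```python
# Rough sketch only: module name, fp8 dtype support, and layout are assumptions;
# tensors live on a ROCm GPU.
import torch
import hipbsolidxgemm as hipb

hipb.hipb_create_extension()

# Activation (m, k) and a Linear-style weight (n, k), both cast to fp8 e4m3fnuz.
a = torch.randn(8, 4096, device="cuda").to(torch.float8_e4m3fnuz)
w = torch.randn(4096, 4096, device="cuda").to(torch.float8_e4m3fnuz)
scale_a = torch.tensor(1.0, device="cuda")  # per-tensor scaling factors (illustrative values)
scale_w = torch.tensor(1.0, device="cuda")

# Enumerate candidate hipBLASLt solutions for this shape and output dtype,
# then run one of them, producing an fp16 result scaled by the given factors.
sols = hipb.hipb_findallsols(a, w.t(), torch.float16)
out = hipb.hipb_mm(a, w.t(), sols[0], torch.float16, scale_a, scale_w)
```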