From d254de7f0b27b2c6513ac6f8e472a01de6c49a46 Mon Sep 17 00:00:00 2001 From: Charlie Fu Date: Mon, 10 Jun 2024 11:40:08 -0500 Subject: [PATCH] Adding fp8 to gradlib (#44) * adding fp8 gemm tunner to gradlib * formatting * add instructions * Linting * adding fp8 gemm tunner to gradlib formatting add instructions * Linting fp8 gradlib * fix merging issue of ROCm_performance.md * delete fp8_gemm_tuner.py * Fix linting for triton: unmeld if with constexpr * update tutorial * Fix linting again * fix typo --------- Co-authored-by: Matthew Wong --- ROCm_performance.md | 21 + gradlib/csrc/hipbsolgemm.cu | 813 +++++++++--------- gradlib/gradlib/fp8_gemm_tuner.py | 289 +++++++ .../layers/quantization/fp8_rocm.py | 12 +- 4 files changed, 731 insertions(+), 404 deletions(-) create mode 100644 gradlib/gradlib/fp8_gemm_tuner.py diff --git a/ROCm_performance.md b/ROCm_performance.md index 180c848a21950..bea77d1a27fc4 100644 --- a/ROCm_performance.md +++ b/ROCm_performance.md @@ -18,3 +18,24 @@ Define the following environment symbol: `PYTORCH_TUNABLEOP_ENABLED=1` in order On ROCm, to have better performance, a custom paged attention is available by switching on the env variable: `VLLM_USE_ROCM_CUSTOM_PAGED_ATTN=1`. Currently, this env variable is enabled by default. To fallback to PagedAttention v2 kernel assign the env variable to 0. The custom PagedAttention kernel is enabled for dtype: fp16, block-size=16, head-size=128, and max context length <= 16k, with GQA ratio (num_heads//num_kv_heads) between 1 to 16. On all the other cases, we fallback to PagedAttention v2 kernel. + +## Fp8 Quantization + +To use fp8 quantization, first step is to quantize your model to fp8 format. Please follow this [instruction](https://github.com/ROCm/vllm/tree/main/examples/fp8/quantizer) to generating a safetensor file that contains the quantized weights and the corresponding scaling factors of your model. The safetensor file should be placed under your model folder. + +Then we can run a model with fp8 quantization using vllm. When creating `vllm.LLM` object, two additional parameters should be added: `quantization="fp8"` and `quantization_param_path={relative path of the safetensors with your model path}`. + +## Gemm Tuning for Fp8 + +To get better performance of fp8 quantization, we will need to tune the gemm with the information of all the shapes used in the execution of the model. + +To obtain all the shapes of gemms during the execution of the model, set the env value `TUNE_FP8=1` and then run the model as usual. We will get the a file called `/tmp/fp8_shapes.csv`. + +Next, run gradlib to obtain the best solutions of these shapes: + +``` +python3 gradlib/gradlib/fp8_gemm_tuner.py --input_file /tmp/fp8_shapes.csv --tuned_file /tmp/tuned_fp8_16.csv +``` +where `/tmp/tuned_fp8_16` will be used by our fp8 gemm linear layer. + +Now, when running inference with fp8, we are using the tuned gemm for best performance. \ No newline at end of file diff --git a/gradlib/csrc/hipbsolgemm.cu b/gradlib/csrc/hipbsolgemm.cu index bf15fb1297667..7888abb6e923c 100644 --- a/gradlib/csrc/hipbsolgemm.cu +++ b/gradlib/csrc/hipbsolgemm.cu @@ -1,9 +1,9 @@ // #ifdef __gfx908__ -// // Uncomment ifdef and endif only if you need to undef the HIP_HALF ops below just for gfx908 and not for others -// // below lines enable hip float to half conversion which are disabled by default in hip_fp16.h -// #undef __HIP_NO_HALF_OPERATORS__ -// #undef __HIP_NO_HALF_CONVERSIONS__ -// #endif +// // Uncomment ifdef and endif only if you need to undef the HIP_HALF ops below +// just for gfx908 and not for others +// // below lines enable hip float to half conversion which are disabled by +// default in hip_fp16.h #undef __HIP_NO_HALF_OPERATORS__ #undef +// __HIP_NO_HALF_CONVERSIONS__ #endif #include #include @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -31,168 +32,138 @@ #include #include "nvToolsExt.h" -//#include - +// #include // #ifdef USE_ROCM -// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) -// #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) -// #endif +// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + +// ROCBLAS_VERSION_MINOR) #define USE_GEMM_FLAGS_FP16_ALT_IMPL +// (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) #endif // #ifdef __HIP_PLATFORM_HCC__ -// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) -// #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) -// #if USE_GEMM_FLAGS_FP16_ALT_IMPL +// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + +// ROCBLAS_VERSION_MINOR) #define USE_GEMM_FLAGS_FP16_ALT_IMPL +// (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) #if USE_GEMM_FLAGS_FP16_ALT_IMPL // #ifdef ROCM_BACKWARD_PASS_GUARD -// flag = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0; -// #endif -// #endif -// #endif +// flag = at::BackwardPassGuard::is_backward_pass() ? +// rocblas_gemm_flags_fp16_alt_impl : 0; #endif #endif #endif #ifndef CHECK_HIP_ERROR -#define CHECK_HIP_ERROR(error) \ - if(error != hipSuccess) \ - { \ - fprintf(stderr, \ - "Hip error: '%s'(%d) at %s:%d\n", \ - hipGetErrorString(error), \ - error, \ - __FILE__, \ - __LINE__); \ - exit(EXIT_FAILURE); \ + #define CHECK_HIP_ERROR(error) \ + if (error != hipSuccess) { \ + fprintf(stderr, "Hip error: '%s'(%d) at %s:%d\n", \ + hipGetErrorString(error), error, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ } #endif #ifndef CHECK_HIPBLAS_ERROR -#define CHECK_HIPBLAS_ERROR(error) \ - if(error != HIPBLAS_STATUS_SUCCESS) \ - { \ - fprintf(stderr, \ - "hipBLAS error: '%s'(%d) at %s:%d\n", \ - hipblasStatusToString(error), \ - error, \ - __FILE__, \ - __LINE__); \ - exit(EXIT_FAILURE); \ + #define CHECK_HIPBLAS_ERROR(error) \ + if (error != HIPBLAS_STATUS_SUCCESS) { \ + fprintf(stderr, "hipBLAS error: '%s'(%d) at %s:%d\n", \ + hipblasStatusToString(error), error, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ } #endif namespace { - /*thread_local*/ cudaStream_t weight_stream; - // BUG: DLM has event and stream on different devices error - // In multi-GPU scenerio, do names defined in this namespace exist on all devices? - // C++ keyword: thread_local <- maybe this can help? - /*thread_local*/ cudaEvent_t event; +/*thread_local*/ cudaStream_t weight_stream; +// BUG: DLM has event and stream on different devices error +// In multi-GPU scenerio, do names defined in this namespace exist on all +// devices? C++ keyword: thread_local <- maybe this can help? +/*thread_local*/ cudaEvent_t event; + +// hipBLASLt +hipblasLtHandle_t hipblaslt_handle; +hipblasLtMatmulPreference_t preference; +size_t workspace_size = 2 * 128 * 1024 * 1024; +// uint64_t workspace_size = 0; +void* d_workspace; +int request_solutions = 1; +int returnedAlgoCount = 0; + +struct MatMulConfig { + hipblasOperation_t op_A; + hipblasOperation_t op_B; + int M; + int N; + int K; + hipDataType dtype; + + friend auto operator<(const MatMulConfig& left, + const MatMulConfig& right) -> bool { + return std::tie(left.op_A, left.op_B, left.M, left.N, left.K, left.dtype) < + std::tie(right.op_A, right.op_B, right.M, right.N, right.K, + right.dtype); + } +}; - // hipBLASLt - hipblasLtHandle_t hipblaslt_handle; - hipblasLtMatmulPreference_t preference; - size_t workspace_size = 2*128*1024*1024; - //uint64_t workspace_size = 0; - void* d_workspace; - int request_solutions = 1; - int returnedAlgoCount = 0; - - struct MatMulConfig { - hipblasOperation_t op_A; - hipblasOperation_t op_B; - int M; - int N; - int K; - hipDataType dtype; - - friend auto operator<(const MatMulConfig& left, const MatMulConfig& right) -> bool { - return std::tie(left.op_A, left.op_B, left.M, left.N, left.K, left.dtype) < std::tie(right.op_A, right.op_B, right.M, right.N, right.K, right.dtype); - } - }; +// std::map, +// std::vector> heuristic_map; +std::map heuristic_map; - // std::map, std::vector> heuristic_map; - std::map heuristic_map; +hipEvent_t start, stop; +int bench_iters{1}; +int warmup_iters{1}; - hipEvent_t start, stop; - int bench_iters { 1 }; - int warmup_iters { 1 }; +bool cout_print = false; - bool cout_print = false; - - //std::vector heuristicResult; -} +torch::Tensor dTensor; -//find all hipblaslt solutions for given gemm problem +// std::vector heuristicResult; +} // namespace + +// find all hipblaslt solutions for given gemm problem std::vector hipblasLtMatmul_findallsols_wrapper( - hipblasLtHandle_t handle, - hipblasOperation_t op_A, - hipblasOperation_t op_B, - int m, int n, int k, - const void *alpha, - const void *a, - int lda, - const void *b, - int ldb, - const void *beta, - void *c, - int ldc, - hipDataType dtype, - hipStream_t &stream) -{ - int flag { 0 }; + hipblasLtHandle_t handle, hipblasOperation_t op_A, hipblasOperation_t op_B, + int m, int n, int k, const void* alpha, const void* a, int lda, + const void* b, int ldb, const void* beta, void* c, int ldc, + hipDataType intype, hipDataType outtype, hipStream_t& stream) { + int flag{0}; hipblasLtMatrixLayout_t matA, matB, matC; hipblasLtMatmulDesc_t matmul; if (op_A == HIPBLAS_OP_N) { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, m, k, lda)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, m, k, lda)); } else { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, k, m, lda)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, k, m, lda)); } if (op_B == HIPBLAS_OP_N) { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, k, n, ldb)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, k, n, ldb)); } else { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, n, k, ldb)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, n, k, ldb)); } - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, dtype, m, n, ldc)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, outtype, m, n, ldc)); + CHECK_HIPBLAS_ERROR( + hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t))); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(int32_t))); - - //std::vector heuristicResult(10); - //CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic( - // handle, matmul, matA, matB, matC, matC, - // preference, 10, heuristicResult.data(), &returnedAlgoCount)); + + // std::vector heuristicResult(10); + // CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic( + // handle, matmul, matA, matB, matC, matC, + // preference, 10, heuristicResult.data(), &returnedAlgoCount)); std::vector heuristicResult; - CHECK_HIPBLAS_ERROR(hipblaslt_ext::getAllAlgos(handle, hipblaslt_ext::GemmType::HIPBLASLT_GEMM, - op_A, - op_B, - dtype, - dtype, - dtype, - dtype, - HIPBLAS_COMPUTE_32F, - heuristicResult)); + CHECK_HIPBLAS_ERROR(hipblaslt_ext::getAllAlgos( + handle, hipblaslt_ext::GemmType::HIPBLASLT_GEMM, op_A, op_B, intype, + intype, outtype, outtype, HIPBLAS_COMPUTE_32F, heuristicResult)); std::vector algoIndex; int returned_algo_count = heuristicResult.size(); - //for (int i = 0; i < returnedAlgoCount; i++) { + // for (int i = 0; i < returnedAlgoCount; i++) { for (int i = 0; i < returned_algo_count; i++) { - auto algo = heuristicResult[i].algo; - size_t ret_workspace_size = 0; - auto status = hipblaslt_ext::matmulIsAlgoSupported(handle, matmul, - alpha, - matA, - matB, - beta, - matC, - matC, - algo, - ret_workspace_size - ); - if (status == HIPBLAS_STATUS_SUCCESS) { - if (ret_workspace_size hipblasLtMatmul_findallsols_wrapper( ///////////////////////////////////////////////////////////////////////////////////////////////////////// /** * hipBLASLt GEMM call -*/ + */ hipblasStatus_t hipblasLtMatmul_sol_wrapper( - hipblasLtHandle_t handle, - hipblasOperation_t op_A, - hipblasOperation_t op_B, - int m, int n, int k, - const void *alpha, - const void *a, - int lda, - const void *b, - int ldb, - const void *beta, - void *c, - int ldc, - hipDataType dtype, - hipStream_t &stream, - int solution_index=-1) -{ + hipblasLtHandle_t handle, hipblasOperation_t op_A, hipblasOperation_t op_B, + int m, int n, int k, const void* alpha, const void* a, int lda, + const void* scaleA, const void* b, int ldb, const void* scaleB, + const void* beta, void* c, int ldc, const void* scaleC, hipDataType intype, + hipDataType outtype, hipStream_t& stream, int solution_index = -1) { // TODO: flag is not supported for hipblasLt yet - int flag { 0 }; - //if (dtype == HIPBLAS_R_16F) { - // use fp16 alt impl for MI200 - // https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices - //flag = rocblas_gemm_flags_fp16_alt_impl; + int flag{0}; + // if (dtype == HIPBLAS_R_16F) { + // use fp16 alt impl for MI200 + // https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + // flag = rocblas_gemm_flags_fp16_alt_impl; //} - //nvtxRangePushA("hipBLASLt variables creation"); + // nvtxRangePushA("hipBLASLt variables creation"); hipblasLtMatrixLayout_t matA, matB, matC; hipblasLtMatmulDesc_t matmul; if (op_A == HIPBLAS_OP_N) { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, m, k, lda)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, m, k, lda)); } else { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, k, m, lda)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, k, m, lda)); } if (op_B == HIPBLAS_OP_N) { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, k, n, ldb)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, k, n, ldb)); } else { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, n, k, ldb)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, n, k, ldb)); } - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, dtype, m, n, ldc)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, outtype, m, n, ldc)); + CHECK_HIPBLAS_ERROR( + hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t))); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(int32_t))); - //nvtxRangePop(); - // if heuristic does not exist in the map, do search and push into the map - //auto gemm_key { MatMulConfig { op_A, op_B, m, n, k, dtype } }; - //if (heuristic_map.count(gemm_key) <= 0) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scaleA, sizeof(scaleA))); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scaleB, sizeof(scaleB))); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, &scaleC, sizeof(scaleC))); + // nvtxRangePop(); + // if heuristic does not exist in the map, do search and push into the map + // auto gemm_key { MatMulConfig { op_A, op_B, m, n, k, dtype } }; + // if (heuristic_map.count(gemm_key) <= 0) { std::vector heuristicResult(1); - if (solution_index<0) { - //nvtxRangePushA("hipblasLtMatmulAlgoGetHeuristic"); - std::cout << "Warning! HipbSolId Gemm Fallback Path used for solution index <0" << std::endl; + if (solution_index < 0) { + // nvtxRangePushA("hipblasLtMatmulAlgoGetHeuristic"); + std::cout + << "Warning! HipbSolId Gemm Fallback Path used for solution index <0" + << std::endl; if (cout_print) { - std::cout << (op_A == HIPBLAS_OP_N ? "N" : "T") << (op_B == HIPBLAS_OP_N ? "N" : "T") - << " (" << m << ", " << n << ", " << k << "), dtype: " << dtype - << ", (lda, ldb, ldc): (" << lda << ", " << ldb << ", " << ldc << "), " << std::endl; + std::cout << (op_A == HIPBLAS_OP_N ? "N" : "T") + << (op_B == HIPBLAS_OP_N ? "N" : "T") << " (" << m << ", " << n + << ", " << k << "), dtype: " << intype << ", (lda, ldb, ldc): (" + << lda << ", " << ldb << ", " << ldc << "), " << std::endl; } - //std::vector heuristicResult(request_solutions); + // std::vector + // heuristicResult(request_solutions); CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic( - handle, matmul, matA, matB, matC, matC, - preference, request_solutions, heuristicResult.data(), &returnedAlgoCount)); - if((returnedAlgoCount != request_solutions) && cout_print) { + handle, matmul, matA, matB, matC, matC, preference, request_solutions, + heuristicResult.data(), &returnedAlgoCount)); + if ((returnedAlgoCount != request_solutions) && cout_print) { std::cout << "less solution found! request: " << request_solutions << ", found: " << returnedAlgoCount << std::endl; } - //heuristic_map[gemm_key] = heuristicResult[0]; -/* - if (returnedAlgoCount == 1) { - heuristic_map[gemm_key] = heuristicResult[0]; - } else { - // benchmark requested solutions and pick best one - int bestIndex { -1 }; - double bestMs { std::numeric_limits::max() }; - for (int sol { 0 }; sol < returnedAlgoCount; ++sol) { - // warm up - for (int iter { 0 }; iter < warmup_iters; ++iter) { - CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, - alpha, - a, matA, - b, matB, - beta, - c, matC, - c, matC, // In case beta != 0, these runs can overwrite the values in c - // since c and d are the same - // TODO: allocates separate d memory for these runs - &heuristicResult[sol].algo, - d_workspace, workspace_size, - stream)); - } - // performance measuring - double eventMs; - CHECK_HIP_ERROR(hipEventRecord(start, stream)); - for (int iter { 0 }; iter < bench_iters; ++iter) { - CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, - alpha, - a, matA, - b, matB, - beta, - c, matC, - c, matC, // In case beta != 0, these runs can overwrite the values in c - // since c and d are the same - // TODO: allocates separate d memory for these runs - &heuristicResult[sol].algo, - d_workspace, workspace_size, - stream)); - } - CHECK_HIP_ERROR(hipEventRecord(stop, stream)); - CHECK_HIP_ERROR(hipEventSynchronize(stop)); - float temp; - CHECK_HIP_ERROR(hipEventElapsedTime(&temp, start, stop)); - eventMs = double(temp); - eventMs /= bench_iters; - - if (cout_print) { - std::cout << " Sol " << sol << ": average time per iter " << std::to_string(eventMs) << " ms"; - } - if (bestMs > eventMs) { - bestMs = eventMs; - bestIndex = sol; - if (cout_print) { - std::cout << " *" << std::endl; - } + // heuristic_map[gemm_key] = heuristicResult[0]; + /* + if (returnedAlgoCount == 1) { + heuristic_map[gemm_key] = heuristicResult[0]; } else { - if (cout_print) { - std::cout << std::endl; + // benchmark requested solutions and pick best one + int bestIndex { -1 }; + double bestMs { std::numeric_limits::max() }; + for (int sol { 0 }; sol < returnedAlgoCount; ++sol) { + // warm up + for (int iter { 0 }; iter < warmup_iters; ++iter) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, + alpha, + a, matA, + b, matB, + beta, + c, matC, + c, matC, // In case beta != 0, these runs can overwrite the + values in c + // since c and d are the same + // TODO: allocates separate d memory for these runs + &heuristicResult[sol].algo, + d_workspace, workspace_size, + stream)); + } + // performance measuring + double eventMs; + CHECK_HIP_ERROR(hipEventRecord(start, stream)); + for (int iter { 0 }; iter < bench_iters; ++iter) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, + alpha, + a, matA, + b, matB, + beta, + c, matC, + c, matC, // In case beta != 0, these runs can overwrite the + values in c + // since c and d are the same + // TODO: allocates separate d memory for these runs + &heuristicResult[sol].algo, + d_workspace, workspace_size, + stream)); + } + CHECK_HIP_ERROR(hipEventRecord(stop, stream)); + CHECK_HIP_ERROR(hipEventSynchronize(stop)); + float temp; + CHECK_HIP_ERROR(hipEventElapsedTime(&temp, start, stop)); + eventMs = double(temp); + eventMs /= bench_iters; + + if (cout_print) { + std::cout << " Sol " << sol << ": average time per iter " << + std::to_string(eventMs) << " ms"; + } + if (bestMs > eventMs) { + bestMs = eventMs; + bestIndex = sol; + if (cout_print) { + std::cout << " *" << std::endl; + } + } else { + if (cout_print) { + std::cout << std::endl; + } + } } + heuristic_map[gemm_key] = heuristicResult[bestIndex]; } - } - heuristic_map[gemm_key] = heuristicResult[bestIndex]; - } -*/ - //nvtxRangePop(); + */ + // nvtxRangePop(); } else { - std::vector algoIndex(1); - algoIndex[0]=solution_index; - //std::vector tmpAlgo; - CHECK_HIPBLAS_ERROR(hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, heuristicResult)); + std::vector algoIndex(1); + algoIndex[0] = solution_index; + // std::vector tmpAlgo; + CHECK_HIPBLAS_ERROR( + hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, heuristicResult)); } - - //size_t ret_workspace_size = 0; - - //auto status1 = hipblaslt_ext::matmulIsAlgoSupported(handle, matmul, - // alpha, - // matA, - // matB, - // beta, - // matC, - // matC, - // heuristicResult[0].algo, - // ret_workspace_size + + // size_t ret_workspace_size = 0; + + // auto status1 = hipblaslt_ext::matmulIsAlgoSupported(handle, matmul, + // alpha, + // matA, + // matB, + // beta, + // matC, + // matC, + // heuristicResult[0].algo, + // ret_workspace_size //); - //if (status1 == HIPBLAS_STATUS_SUCCESS) { - // std::cout << "Workspace size" << ret_workspace_size << std::endl; + // if (status1 == HIPBLAS_STATUS_SUCCESS) { + // std::cout << "Workspace size" << ret_workspace_size << std::endl; //} else { - // std::cout << "Algo not supported!!!" << std::endl; + // std::cout << "Algo not supported!!!" << std::endl; //} - hipblasStatus_t status = hipblasLtMatmul(handle, matmul, - alpha, - a, matA, - b, matB, - beta, - c, matC, - c, matC, - &heuristicResult[0].algo, - d_workspace, workspace_size, - stream); - - //nvtxRangePushA("hipBLASLt variables deletion"); + hipblasStatus_t status = hipblasLtMatmul( + handle, matmul, alpha, a, matA, b, matB, beta, c, matC, c, matC, + &heuristicResult[0].algo, d_workspace, workspace_size, stream); + + // nvtxRangePushA("hipBLASLt variables deletion"); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescDestroy(matmul)); CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matA)); CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matB)); CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matC)); - //nvtxRangePop(); + // nvtxRangePop(); return status; } ///////////////////////////////////////////////////////////////////////////////////////////////////////// torch::Tensor HipbSolIdxBlas( - const torch::Tensor& mat1, - const torch::Tensor& mat2, - const int solution_index - ) -{ - auto mat1_strides { mat1.strides() }; - auto mat2_strides { mat2.strides() }; - auto mat1_sizes { mat1.sizes() }; - auto mat2_sizes { mat2.sizes() }; - // std::cout << " | mat1 info: size: " << mat1_sizes << " stride: " << mat1_strides << std::endl - // << " | mat2 info: size: " << mat2_sizes << " stride: " << mat2_strides << std::endl; + const torch::Tensor& mat1, const torch::Tensor& mat2, + const int solution_index, at::optional Type = at::nullopt, + at::optional scale1 = at::nullopt, + at::optional scale2 = at::nullopt, + at::optional scaleOut = at::nullopt) { + auto mat1_strides{mat1.strides()}; + auto mat2_strides{mat2.strides()}; + auto mat1_sizes{mat1.sizes()}; + auto mat2_sizes{mat2.sizes()}; + // std::cout << " | mat1 info: size: " << mat1_sizes << " stride: " << + // mat1_strides << std::endl + // << " | mat2 info: size: " << mat2_sizes << " stride: " << + // mat2_strides << std::endl; TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); - TORCH_CHECK( - mat1.dtype() == mat2.dtype(), - "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() - ); - TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); - - auto abcType { mat1.options().dtype() }; - auto options { at::TensorOptions().dtype(abcType).device(at::kCUDA) }; - auto result { torch::empty({ mat1_sizes[0], mat2_sizes[1] }, options) }; - // std::cout << " | result info: size: " << result.sizes() << " stride: " << result.strides() << std::endl; + TORCH_CHECK(mat1.dtype() == mat2.dtype(), + "expected mat1 and mat2 to have the same dtype, but got: ", + mat1.dtype(), " != ", mat2.dtype()); + TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], + "mat1 dim 1 must match mat2 dim 0"); + + auto inType{mat1.options().dtype()}; + auto outType = inType.toScalarType(); + if (Type.has_value()) + outType = torch::python::detail::py_object_to_dtype(Type.value()); + auto options{at::TensorOptions().dtype(outType).device(at::kCUDA)}; + auto result{torch::empty({mat1_sizes[0], mat2_sizes[1]}, options)}; + // std::cout << " | result info: size: " << result.sizes() << " stride: " << + // result.strides() << std::endl; bool transpose_result = true; bool transpose_mat1; bool transpose_mat2; - if ((mat2_strides[0] == 1) && (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { + if ((mat2_strides[0] == 1) && + (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { transpose_mat2 = false; - } else if ((mat2_strides[1] == 1) && (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { + } else if ((mat2_strides[1] == 1) && + (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { transpose_mat2 = true; } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); } - if ((mat1_strides[0] == 1) && (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { + if ((mat1_strides[0] == 1) && + (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { transpose_mat1 = false; - } else if ((mat1_strides[1] == 1) && (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { + } else if ((mat1_strides[1] == 1) && + (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { transpose_mat1 = true; } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); } if (transpose_result) { @@ -434,90 +414,120 @@ torch::Tensor HipbSolIdxBlas( mat1_sizes = mat2.sizes(); mat2_sizes = mat1.sizes(); } - // std::cout << " | transpose_result: " << (transpose_result ? "true" : "false") << std::endl - // << " | transpose_A: " << (transpose_mat1 ? "true" : "false") << std::endl - // << " | transpose_B: " << (transpose_mat2 ? "true" : "false") << std::endl; - // std::cout << " | A matrix: size: " << mat1_sizes << " stride: " << mat1_strides << std::endl - // << " | B matrix: size: " << mat2_sizes << " stride: " << mat2_strides << std::endl; - - float one { 1.0f }; - float zero { 0.0f }; + // std::cout << " | transpose_result: " << (transpose_result ? "true" : + // "false") << std::endl + // << " | transpose_A: " << (transpose_mat1 ? "true" : "false") << + // std::endl + // << " | transpose_B: " << (transpose_mat2 ? "true" : "false") << + // std::endl; + // std::cout << " | A matrix: size: " << mat1_sizes << " stride: " << + // mat1_strides << std::endl + // << " | B matrix: size: " << mat2_sizes << " stride: " << + // mat2_strides << std::endl; + + float one{1.0f}; + float zero{0.0f}; int64_t m = mat1_sizes[transpose_result ? 1 : 0]; int64_t k = mat1_sizes[transpose_result ? 0 : 1]; int64_t n = mat2_sizes[transpose_result ? 0 : 1]; int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0]; int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; int64_t result_ld = result.stride(transpose_result ? 0 : 1); - // std::cout << " | (m, n, k): " << m << ", " << n << ", " << k << std::endl - // << " | (lda, ldb, ldc): " << mat1_ld << ", " << mat2_ld << ", " << result_ld << std::endl; - - hipDataType hipblasType; - if (abcType == at::kHalf) { - hipblasType = HIP_R_16F; - } else if (abcType == at::kBFloat16) { - hipblasType = HIP_R_16BF; - } else if (abcType == at::kFloat) { - hipblasType = HIP_R_32F; + + void *d_scale1 = nullptr, *d_scale2 = nullptr, *d_scaleOut = nullptr; + if (scale1.has_value()) { + d_scale1 = static_cast(scale1.value().data_ptr()); + } + if (scale2.has_value()) { + d_scale2 = static_cast(scale2.value().data_ptr()); + } + if (scaleOut.has_value()) { + d_scaleOut = static_cast(scaleOut.value().data_ptr()); + } + + hipDataType hipblasInType, hipblasOutType; + if (inType == at::kHalf) { + hipblasInType = HIP_R_16F; + } else if (inType == at::kBFloat16) { + hipblasInType = HIP_R_16BF; + } else if (inType == at::kFloat) { + hipblasInType = HIP_R_32F; + } else if (inType == at::kFloat8_e4m3fnuz) { + hipblasInType = HIP_R_8F_E4M3_FNUZ; } else { assert(false && "Wrong datatype!"); } - void *ptrA { static_cast((transpose_result ? mat2 : mat1).data_ptr()) }; - void *ptrB { static_cast((transpose_result ? mat1 : mat2).data_ptr()) }; - void *ptrC { static_cast(result.data_ptr()) }; - auto current_stream { torch::hip::getCurrentHIPStream().stream() }; + + if (outType == at::kHalf) { + hipblasOutType = HIP_R_16F; + } else if (outType == at::kBFloat16) { + hipblasOutType = HIP_R_16BF; + } else if (outType == at::kFloat) { + hipblasOutType = HIP_R_32F; + } else if (outType == at::kFloat8_e4m3fnuz) { + hipblasOutType = HIP_R_8F_E4M3_FNUZ; + } else { + assert(false && "Wrong datatype!"); + } + void* ptrA{static_cast((transpose_result ? mat2 : mat1).data_ptr())}; + void* ptrB{static_cast((transpose_result ? mat1 : mat2).data_ptr())}; + void* ptrC{static_cast(result.data_ptr())}; + if (transpose_result) std::swap(d_scale1, d_scale2); + auto current_stream{torch::hip::getCurrentHIPStream().stream()}; CHECK_HIPBLAS_ERROR(hipblasLtMatmul_sol_wrapper( - hipblaslt_handle, - transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, - transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, - m, n, k, - &one, - ptrA, mat1_ld, - ptrB, mat2_ld, - &zero, - ptrC, result_ld, - hipblasType, - current_stream,solution_index)); + hipblaslt_handle, transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k, &one, ptrA, + mat1_ld, d_scale1, ptrB, mat2_ld, d_scale2, &zero, ptrC, result_ld, + d_scaleOut, hipblasInType, hipblasOutType, current_stream, + solution_index)); return result; } -//find all hipblas solutions and return them to python land +// find all hipblas solutions and return them to python land std::vector HipbFindAllSolIdxBlas( - const torch::Tensor& mat1, - const torch::Tensor& mat2 - ) -{ - auto mat1_strides { mat1.strides() }; - auto mat2_strides { mat2.strides() }; - auto mat1_sizes { mat1.sizes() }; - auto mat2_sizes { mat2.sizes() }; + const torch::Tensor& mat1, const torch::Tensor& mat2, + at::optional Type = at::nullopt) { + auto mat1_strides{mat1.strides()}; + auto mat2_strides{mat2.strides()}; + auto mat1_sizes{mat1.sizes()}; + auto mat2_sizes{mat2.sizes()}; TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); - TORCH_CHECK( - mat1.dtype() == mat2.dtype(), - "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() - ); - TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); - - auto abcType { mat1.options().dtype() }; - auto options { at::TensorOptions().dtype(abcType).device(at::kCUDA) }; - auto result { torch::empty({ mat1_sizes[0], mat2_sizes[1] }, options) }; + TORCH_CHECK(mat1.dtype() == mat2.dtype(), + "expected mat1 and mat2 to have the same dtype, but got: ", + mat1.dtype(), " != ", mat2.dtype()); + TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], + "mat1 dim 1 must match mat2 dim 0"); + + auto inType{mat1.options().dtype()}; + auto outType = inType.toScalarType(); + if (Type.has_value()) + outType = torch::python::detail::py_object_to_dtype(Type.value()); + auto options{at::TensorOptions().dtype(outType).device(at::kCUDA)}; + auto result{torch::empty({mat1_sizes[0], mat2_sizes[1]}, options)}; bool transpose_result = true; bool transpose_mat1; bool transpose_mat2; - if ((mat2_strides[0] == 1) && (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { + if ((mat2_strides[0] == 1) && + (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { transpose_mat2 = false; - } else if ((mat2_strides[1] == 1) && (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { + } else if ((mat2_strides[1] == 1) && + (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { transpose_mat2 = true; } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); } - if ((mat1_strides[0] == 1) && (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { + if ((mat1_strides[0] == 1) && + (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { transpose_mat1 = false; - } else if ((mat1_strides[1] == 1) && (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { + } else if ((mat1_strides[1] == 1) && + (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { transpose_mat1 = true; } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); } if (transpose_result) { bool tmp = transpose_mat1; @@ -528,83 +538,90 @@ std::vector HipbFindAllSolIdxBlas( mat1_sizes = mat2.sizes(); mat2_sizes = mat1.sizes(); } - float one { 1.0f }; - float zero { 0.0f }; + float one{1.0f}; + float zero{0.0f}; int64_t m = mat1_sizes[transpose_result ? 1 : 0]; int64_t k = mat1_sizes[transpose_result ? 0 : 1]; int64_t n = mat2_sizes[transpose_result ? 0 : 1]; int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0]; int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; int64_t result_ld = result.stride(transpose_result ? 0 : 1); - hipDataType hipblasType; - if (abcType == at::kHalf) { - hipblasType = HIP_R_16F; - } else if (abcType == at::kBFloat16) { - hipblasType = HIP_R_16BF; - } else if (abcType == at::kFloat) { - hipblasType = HIP_R_32F; + hipDataType hipblasInType, hipblasOutType; + if (inType == at::kHalf) { + hipblasInType = HIP_R_16F; + } else if (inType == at::kBFloat16) { + hipblasInType = HIP_R_16BF; + } else if (inType == at::kFloat) { + hipblasInType = HIP_R_32F; + } else if (inType == at::kFloat8_e4m3fnuz) { + hipblasInType = HIP_R_8F_E4M3_FNUZ; + } else { + assert(false && "Wrong datatype!"); + } + if (outType == at::kHalf) { + hipblasOutType = HIP_R_16F; + } else if (outType == at::kBFloat16) { + hipblasOutType = HIP_R_16BF; + } else if (outType == at::kFloat) { + hipblasOutType = HIP_R_32F; + } else if (outType == at::kFloat8_e4m3fnuz) { + hipblasOutType = HIP_R_8F_E4M3_FNUZ; } else { assert(false && "Wrong datatype!"); } - void *ptrA { static_cast((transpose_result ? mat2 : mat1).data_ptr()) }; - void *ptrB { static_cast((transpose_result ? mat1 : mat2).data_ptr()) }; - void *ptrC { static_cast(result.data_ptr()) }; - auto current_stream { torch::hip::getCurrentHIPStream().stream() }; + void* ptrA{static_cast((transpose_result ? mat2 : mat1).data_ptr())}; + void* ptrB{static_cast((transpose_result ? mat1 : mat2).data_ptr())}; + void* ptrC{static_cast(result.data_ptr())}; + auto current_stream{torch::hip::getCurrentHIPStream().stream()}; return hipblasLtMatmul_findallsols_wrapper( - hipblaslt_handle, - transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, - transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, - m, n, k, - &one, - ptrA, mat1_ld, - ptrB, mat2_ld, - &zero, - ptrC, result_ld, - hipblasType, - current_stream); - + hipblaslt_handle, transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k, &one, ptrA, + mat1_ld, ptrB, mat2_ld, &zero, ptrC, result_ld, hipblasInType, + hipblasOutType, current_stream); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// -void hipb_create_extension() -{ - //CHECK_HIP_ERROR(hipStreamCreate(&weight_stream)); - //CHECK_HIP_ERROR(hipEventCreateWithFlags(&event, cudaEventDisableTiming)); +void hipb_create_extension() { + // CHECK_HIP_ERROR(hipStreamCreate(&weight_stream)); + // CHECK_HIP_ERROR(hipEventCreateWithFlags(&event, cudaEventDisableTiming)); // hipBLASLt CHECK_HIPBLAS_ERROR(hipblasLtCreate(&hipblaslt_handle)); CHECK_HIP_ERROR(hipMalloc(&d_workspace, workspace_size)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceCreate(&preference)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceSetAttribute( - preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size))); - - //CHECK_HIP_ERROR(hipEventCreate(&start)); - //CHECK_HIP_ERROR(hipEventCreate(&stop)); + preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, + sizeof(workspace_size))); + + // CHECK_HIP_ERROR(hipEventCreate(&start)); + // CHECK_HIP_ERROR(hipEventCreate(&stop)); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// -void hipb_destroy_extension() -{ - //CHECK_HIP_ERROR(hipStreamDestroy(weight_stream)); - //CHECK_HIP_ERROR(hipEventDestroy(event)); +void hipb_destroy_extension() { + // CHECK_HIP_ERROR(hipStreamDestroy(weight_stream)); + // CHECK_HIP_ERROR(hipEventDestroy(event)); - // hipBLASLt - CHECK_HIPBLAS_ERROR(hipblasLtDestroy(hipblaslt_handle)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceDestroy(preference)); - CHECK_HIP_ERROR(hipFree(d_workspace)); + // hipBLASLt + CHECK_HIPBLAS_ERROR(hipblasLtDestroy(hipblaslt_handle)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceDestroy(preference)); + CHECK_HIP_ERROR(hipFree(d_workspace)); - //CHECK_HIP_ERROR(hipEventDestroy(start)); - //CHECK_HIP_ERROR(hipEventDestroy(stop)); + // CHECK_HIP_ERROR(hipEventDestroy(start)); + // CHECK_HIP_ERROR(hipEventDestroy(stop)); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("hipb_create_extension", &hipb_create_extension, "create_extension"); m.def("hipb_destroy_extension", &hipb_destroy_extension, "destroy_extension"); - m.def("hipb_mm", &HipbSolIdxBlas, "mm"); - m.def("hipb_findallsols", &HipbFindAllSolIdxBlas, "hipblas_find_all_sols"); -} + m.def("hipb_mm", &HipbSolIdxBlas, "mm", py::arg("mat1"), py::arg("mat2"), + py::arg("solution_index"), py::arg("outType") = at::nullopt, + py::arg("scale1") = at::nullopt, py::arg("scale2") = at::nullopt, + py::arg("scaleOut") = at::nullopt); + m.def("hipb_findallsols", &HipbFindAllSolIdxBlas, "hipblas_find_all_sols", + py::arg("mat1"), py::arg("mat2"), py::arg("outType") = at::nullopt); +} \ No newline at end of file diff --git a/gradlib/gradlib/fp8_gemm_tuner.py b/gradlib/gradlib/fp8_gemm_tuner.py new file mode 100644 index 0000000000000..61df1933f8658 --- /dev/null +++ b/gradlib/gradlib/fp8_gemm_tuner.py @@ -0,0 +1,289 @@ +import argparse +import json +import os +import random +from pathlib import Path + +import hipbsolidxgemm +import pandas as pd +import torch +import torch.nn.functional as F + +hipbsolidxgemm.hipb_create_extension() + +rtol = 1e-5 +atol = 1 + + +class Fp8Gemm: + + def __init__(self, m, n, k, indtype, outdtype): + self.m = m + self.k = k + self.n = n + self.indtype = indtype + self.outdtype = outdtype + self.nb = 37 + self.inp = torch.randn((self.n, self.k), + device='cuda').to(self.indtype) + self.weights = torch.randn((self.m, self.k), + device='cuda').to(self.indtype) + # weights2 is used in measurement/warm iters to ensure HBM + # fetch for weight tensors + self.weights2 = torch.randn((self.nb, self.m, self.k), + device='cuda').to(self.indtype) + self.blob = torch.ones(128 * 1024 * 1024, + dtype=torch.float32, + device='cuda') + self.topn = 20 #number of top solutions from each source + self.hipb_sols = [] + self.rtol = 1e-5 + self.atol = 1 + self.start = torch.cuda.Event(enable_timing=True) + self.end = torch.cuda.Event(enable_timing=True) + + def find_hipblas_sols(self): + sols = hipbsolidxgemm.hipb_findallsols(self.inp, self.weights.t(), + self.outdtype) + print('M N K', + self.m, + self.n, + self.k, + '>>> Total hipb solutions', + len(sols), + flush=True) + #print(sols) + self.hipb_sols = sols + + def check_gemm_ref(self, libtype, solidx): + ref = F.linear(self.inp.to(torch.float32), + self.weights.to(torch.float32)).to(self.outdtype) + c = hipbsolidxgemm.hipb_mm(self.inp, self.weights.t(), solidx, + self.outdtype) + if torch.allclose(c, ref, atol=self.atol, rtol=self.rtol): + #print('>>>',libtype,'Solidx',solidx,'passed reference test') + return True + else: + print('>>>', 'Solidx', solidx, 'FAILED reference test', flush=True) + print(ref, flush=True) + print(c, flush=True) + return False + + def hipb_time_sol(self, solidx, cold_iters=2, warm_iters=10): + #print('>>>hipbtime',solidx) + for i in range(cold_iters): + hipbsolidxgemm.hipb_mm(self.inp, self.weights.t(), solidx, + self.outdtype) + self.start.record() + for i in range(warm_iters): + hipbsolidxgemm.hipb_mm( + self.inp, self.weights2[random.randint(0, self.nb - 1)].t(), + solidx, self.outdtype) + self.end.record() + torch.cuda.synchronize() + gtime = self.start.elapsed_time(self.end) / warm_iters + #print('>>> Solidx GTime',solidx,gtime,'ms') + return gtime + + def hipb_time_all_sols(self, fast_mode=0, top_sols=0): + coldi = 20 + warmi = 20 + if fast_mode: + coldi = 2 + warmi = 2 + solutions = self.hipb_sols + if top_sols: + solutions = self.hipb_top_sols + gtimes = {} + for solidx in solutions: + gtimes[solidx] = self.hipb_time_sol(solidx, + cold_iters=coldi, + warm_iters=warmi) + self.hipb_gtimedf = pd.DataFrame.from_dict( + gtimes, orient='index', + columns=['gtimems']).sort_values(by='gtimems') + self.hipb_gtimedf.to_csv('/tmp/hipb_gtimedf.csv') + print('>>> HipBlasLt top solutions, Fast Mode', fast_mode) + print(self.hipb_gtimedf.head(self.topn)) + + def warmup(self, warmi=500): + for i in range(warmi): + self.blob = self.blob + 0.00001 + + def functional_check_topn_fastest(self): + hipb_topn = [] + for solidx in self.hipb_gtimedf.index[:self.topn]: + if self.check_gemm_ref(libtype='hipblaslt', solidx=solidx): + hipb_topn.append(solidx) + self.hipb_top_sols = hipb_topn + + def find_fastest_solution(self): + self.find_hipblas_sols() + self.warmup() + self.hipb_time_all_sols(fast_mode=1) + self.functional_check_topn_fastest() + self.warmup() + self.hipb_time_all_sols(fast_mode=0, top_sols=1) + if len(self.hipb_gtimedf) > 0: + best_hipb_time = self.hipb_gtimedf.gtimems.iloc[0] + self.best_solidx = self.hipb_gtimedf.index[0] + self.best_soltime = best_hipb_time + else: + print('>>> No hipblas solutions found!', flush=True) + self.best_solidx = 0 + self.best_soltime = 0 + print('>>> Fastest Solution is', + self.best_solidx, + self.best_soltime, + flush=True) + + +class Fp8GemmTuner: + + def __init__(self, indtype, outdtype, tuned_file=None): + self.gemm_problems = pd.DataFrame(columns=['M', 'N', 'K']) + self.indtype = indtype + self.outdtype = outdtype + self.tuned_file = tuned_file + if Path(tuned_file).is_file(): + self.gdf = pd.read_csv(tuned_file) + else: + self.gdf = None + + def add_gemm(self, m, n, k): + if (self.gdf is None + or (self.gdf[(self.gdf['M'] == m) & (self.gdf['N'] == n) & + (self.gdf['K'] == k)].empty)): + entry = {'M': [m], 'N': [n], 'K': [k]} + df = pd.DataFrame(entry) + self.gemm_problems = pd.concat([self.gemm_problems, df], + ignore_index=True) + else: + print( + f">>>Info: Found Duplicate shape(M:{m}, N:{n}, K:{k}), skipping" + ) + + def find_best_sols(self): + df = self.gemm_problems + soldf = pd.DataFrame() + for i in range(len(df)): + ds = df.iloc[i] + gemmobj = Fp8Gemm(ds['M'], + ds['N'], + ds['K'], + indtype=self.indtype, + outdtype=self.outdtype) + gemmobj.find_fastest_solution() + soldf.loc[i, 'solidx'] = gemmobj.best_solidx + soldf.loc[i, 'soltimems'] = gemmobj.best_soltime + soldf['indtype'] = self.indtype + soldf['outdtype'] = self.outdtype + finaldf = pd.concat([self.gemm_problems, soldf], axis=1) + finaldf = pd.concat([finaldf, self.gdf]) + finaldf.to_csv(self.tuned_file, index=False) + print(finaldf) + + +def generate_mk_sets(model_dir, tp=1): + with open(f'{model_dir}/config.json') as f: + data = json.load(f) + hidden_size = data['hidden_size'] + intermediate_size = data['intermediate_size'] + total_num_heads = data['num_attention_heads'] + total_num_kv_heads = data['num_key_value_heads'] + head_dim = hidden_size // total_num_heads + return [((total_num_heads + (2 * total_num_kv_heads)) * head_dim // tp, + hidden_size), (hidden_size, hidden_size // tp), + (intermediate_size * 2 // tp, hidden_size), + (hidden_size, intermediate_size // tp)], hidden_size + + +def get_dtype(dtype_str): + dtype = torch.float8_e4m3fnuz + if dtype_str == 'f32': + dtype = torch.float32 + elif dtype_str == 'bf16': + dtype = torch.bfloat16 + elif dtype_str == 'f16': + dtype = torch.float16 + elif dtype_str == 'f8': + dtype = torch.float8_e4m3fnuz + else: + print('>>> Warning! Invalid dtype', dtype_str, + 'using default dtype f8') + return dtype + + +def list_of_ints(arg): + return list(map(int, arg.split(','))) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--model_dir", + type=str, + default=os.getenv('GTUNE_MODEL', ""), + help="Enter the location of your model directory") + parser.add_argument("--tuned_file", + type=str, + default=os.getenv('GTUNE_TUNED', "tuned.csv"), + help="output file for tuned gemm solutions") + parser.add_argument( + "--input_file", + type=str, + default=os.getenv('GTUNE_INPUT', None), + help="list of gemms to tune for, mutually exclusive with model_dir") + parser.add_argument("--tp", + type=int, + default=os.getenv('GTUNE_TP', 1), + help="Tensor parallelism to be used.") + parser.add_argument("--indtype", + type=str, + default='f8', + help="dtype f32 f16 bf16 fp8") + parser.add_argument("--outdtype", + type=str, + default='f16', + help="dtype f32 f16 bf16 fp8") + parser.add_argument("--batch_size", + type=int, + default=os.getenv('GTUNE_BATCH_SIZE', 1), + help="Batch size to tune for") + parser.add_argument("--nsets", + type=list_of_ints, + default=[1, 512, 1024, 2048, 3072, 4096, 8192, 16384], + help="N sizes to tune for: 1,128,2048") + args = parser.parse_args() + + indtype = get_dtype(args.indtype) + outdtype = get_dtype(args.outdtype) + + gtuner = Fp8GemmTuner(indtype, outdtype, args.tuned_file) + nsets = [i * args.batch_size for i in args.nsets] + if args.input_file: + print(f">>> Loading {args.input_file}") + if not Path(args.input_file).is_file(): + print(f">>> ERROR: {args.input_file} does not exist. Exiting") + exit(1) + shapes = pd.read_csv(args.input_file) + for i in range(len(shapes)): + ds = shapes.iloc[i] + gtuner.add_gemm(ds['M'], ds['N'], ds['K']) + else: + if not args.model_dir: + print(">>> Warning! NO MODEL SPECIFIED. Tuning for LL2 13B TP1") + #LL2 13B sizes + mksets = [(15360, 5120), (5120, 5120), (27648, 5120), + (5120, 13824)] + gtuner.add_gemm(m=32000, n=1, k=5120) # logits gemm + else: + mksets, hidden_size = generate_mk_sets(args.model_dir, args.tp) + gtuner.add_gemm( + m=32000 // args.tp, n=1 * args.batch_size, k=hidden_size + ) #TODO: Handle cases where vocab_size is not divisible by tp + + for n in sorted(nsets): + for m, k in mksets: + gtuner.add_gemm(m, n, k) + + gtuner.find_best_sols() diff --git a/vllm/model_executor/layers/quantization/fp8_rocm.py b/vllm/model_executor/layers/quantization/fp8_rocm.py index ddccc5825c8a4..ddb83a6ed452e 100644 --- a/vllm/model_executor/layers/quantization/fp8_rocm.py +++ b/vllm/model_executor/layers/quantization/fp8_rocm.py @@ -30,10 +30,10 @@ def __init__(self) -> None: #print(f"Integral Cross factor = {self.factor}") if gemm_type == "fp8_8": self.gemm_method = Fp8RocmLinearMethod.apply_fp8_8 - tuned_filename = "/projects/tuned_fp8_8.csv" + tuned_filename = "/tmp/tuned_fp8_8.csv" elif gemm_type == "fp8_16": self.gemm_method = Fp8RocmLinearMethod.apply_fp8_16 - tuned_filename = "/projects/tuned_fp8_16.csv" + tuned_filename = "/tmp/tuned_fp8_16.csv" else: raise ValueError(f"Unknown fp8 gemm type: {gemm_type}") try: @@ -50,7 +50,7 @@ def __init__(self) -> None: m = shape["M"] n = shape["N"] k = shape["K"] - algo = shape["algo"] + algo = shape["solidx"] self._tuned[(m, n, k)] = algo @classmethod @@ -224,7 +224,7 @@ def apply_fp8_16( if os.getenv("TUNE_FP8") == "1": try: - df = pd.read_csv("/projects/fp8_shapes.csv") + df = pd.read_csv("/tmp/fp8_shapes.csv") except (IOError, pd.errors.EmptyDataError, pd.errors.ParserError): df = pd.DataFrame(columns=["M", "N", "K"]) @@ -234,7 +234,7 @@ def apply_fp8_16( "N": [n], "K": [k] })]).drop_duplicates() - df.to_csv("/projects/fp8_shapes.csv", index=False) + df.to_csv("/tmp/fp8_shapes.csv", index=False) algo = 0 res = vllm_ops.fp8_gemm_16(x8, weight.t(), asf, wsf, int(algo)) return res @@ -271,7 +271,7 @@ def apply_fp8_8( "N": [n], "K": [k] })]).drop_duplicates() - df.to_csv("/projects/fp8_shapes.csv", index=False) + df.to_csv("/tmp/fp8_shapes.csv", index=False) algo = 0 res = vllm_ops.fp8_gemm(x8, weight.t(), asf, wsf, osf, int(algo))