diff --git a/gradlib/csrc/hipbsolgemm.cu b/gradlib/csrc/hipbsolgemm.cu index e5b7cbc7ea43f..7888abb6e923c 100644 --- a/gradlib/csrc/hipbsolgemm.cu +++ b/gradlib/csrc/hipbsolgemm.cu @@ -1,9 +1,9 @@ // #ifdef __gfx908__ -// // Uncomment ifdef and endif only if you need to undef the HIP_HALF ops below just for gfx908 and not for others -// // below lines enable hip float to half conversion which are disabled by default in hip_fp16.h -// #undef __HIP_NO_HALF_OPERATORS__ -// #undef __HIP_NO_HALF_CONVERSIONS__ -// #endif +// // Uncomment ifdef and endif only if you need to undef the HIP_HALF ops below +// just for gfx908 and not for others +// // below lines enable hip float to half conversion which are disabled by +// default in hip_fp16.h #undef __HIP_NO_HALF_OPERATORS__ #undef +// __HIP_NO_HALF_CONVERSIONS__ #endif #include #include @@ -32,114 +32,93 @@ #include #include "nvToolsExt.h" -//#include - +// #include // #ifdef USE_ROCM -// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) -// #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) -// #endif +// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + +// ROCBLAS_VERSION_MINOR) #define USE_GEMM_FLAGS_FP16_ALT_IMPL +// (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) #endif // #ifdef __HIP_PLATFORM_HCC__ -// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) -// #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) -// #if USE_GEMM_FLAGS_FP16_ALT_IMPL +// #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + +// ROCBLAS_VERSION_MINOR) #define USE_GEMM_FLAGS_FP16_ALT_IMPL +// (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) #if USE_GEMM_FLAGS_FP16_ALT_IMPL // #ifdef ROCM_BACKWARD_PASS_GUARD -// flag = at::BackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0; -// #endif -// #endif -// #endif +// flag = at::BackwardPassGuard::is_backward_pass() ? +// rocblas_gemm_flags_fp16_alt_impl : 0; #endif #endif #endif #ifndef CHECK_HIP_ERROR -#define CHECK_HIP_ERROR(error) \ - if(error != hipSuccess) \ - { \ - fprintf(stderr, \ - "Hip error: '%s'(%d) at %s:%d\n", \ - hipGetErrorString(error), \ - error, \ - __FILE__, \ - __LINE__); \ - exit(EXIT_FAILURE); \ + #define CHECK_HIP_ERROR(error) \ + if (error != hipSuccess) { \ + fprintf(stderr, "Hip error: '%s'(%d) at %s:%d\n", \ + hipGetErrorString(error), error, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ } #endif #ifndef CHECK_HIPBLAS_ERROR -#define CHECK_HIPBLAS_ERROR(error) \ - if(error != HIPBLAS_STATUS_SUCCESS) \ - { \ - fprintf(stderr, \ - "hipBLAS error: '%s'(%d) at %s:%d\n", \ - hipblasStatusToString(error), \ - error, \ - __FILE__, \ - __LINE__); \ - exit(EXIT_FAILURE); \ + #define CHECK_HIPBLAS_ERROR(error) \ + if (error != HIPBLAS_STATUS_SUCCESS) { \ + fprintf(stderr, "hipBLAS error: '%s'(%d) at %s:%d\n", \ + hipblasStatusToString(error), error, __FILE__, __LINE__); \ + exit(EXIT_FAILURE); \ } #endif namespace { - /*thread_local*/ cudaStream_t weight_stream; - // BUG: DLM has event and stream on different devices error - // In multi-GPU scenerio, do names defined in this namespace exist on all devices? - // C++ keyword: thread_local <- maybe this can help? - /*thread_local*/ cudaEvent_t event; +/*thread_local*/ cudaStream_t weight_stream; +// BUG: DLM has event and stream on different devices error +// In multi-GPU scenerio, do names defined in this namespace exist on all +// devices? C++ keyword: thread_local <- maybe this can help? +/*thread_local*/ cudaEvent_t event; + +// hipBLASLt +hipblasLtHandle_t hipblaslt_handle; +hipblasLtMatmulPreference_t preference; +size_t workspace_size = 2 * 128 * 1024 * 1024; +// uint64_t workspace_size = 0; +void* d_workspace; +int request_solutions = 1; +int returnedAlgoCount = 0; + +struct MatMulConfig { + hipblasOperation_t op_A; + hipblasOperation_t op_B; + int M; + int N; + int K; + hipDataType dtype; + + friend auto operator<(const MatMulConfig& left, + const MatMulConfig& right) -> bool { + return std::tie(left.op_A, left.op_B, left.M, left.N, left.K, left.dtype) < + std::tie(right.op_A, right.op_B, right.M, right.N, right.K, + right.dtype); + } +}; - // hipBLASLt - hipblasLtHandle_t hipblaslt_handle; - hipblasLtMatmulPreference_t preference; - size_t workspace_size = 2*128*1024*1024; - //uint64_t workspace_size = 0; - void* d_workspace; - int request_solutions = 1; - int returnedAlgoCount = 0; - - struct MatMulConfig { - hipblasOperation_t op_A; - hipblasOperation_t op_B; - int M; - int N; - int K; - hipDataType dtype; - - friend auto operator<(const MatMulConfig& left, const MatMulConfig& right) -> bool { - return std::tie(left.op_A, left.op_B, left.M, left.N, left.K, left.dtype) < std::tie(right.op_A, right.op_B, right.M, right.N, right.K, right.dtype); - } - }; +// std::map, +// std::vector> heuristic_map; +std::map heuristic_map; - // std::map, std::vector> heuristic_map; - std::map heuristic_map; +hipEvent_t start, stop; +int bench_iters{1}; +int warmup_iters{1}; - hipEvent_t start, stop; - int bench_iters { 1 }; - int warmup_iters { 1 }; +bool cout_print = false; - bool cout_print = false; - - torch::Tensor dTensor; - - //std::vector heuristicResult; -} +torch::Tensor dTensor; + +// std::vector heuristicResult; +} // namespace -//find all hipblaslt solutions for given gemm problem +// find all hipblaslt solutions for given gemm problem std::vector hipblasLtMatmul_findallsols_wrapper( - hipblasLtHandle_t handle, - hipblasOperation_t op_A, - hipblasOperation_t op_B, - int m, int n, int k, - const void *alpha, - const void *a, - int lda, - const void *b, - int ldb, - const void *beta, - void *c, - int ldc, - hipDataType intype, - hipDataType outtype, - hipStream_t &stream) -{ - int flag { 0 }; + hipblasLtHandle_t handle, hipblasOperation_t op_A, hipblasOperation_t op_B, + int m, int n, int k, const void* alpha, const void* a, int lda, + const void* b, int ldb, const void* beta, void* c, int ldc, + hipDataType intype, hipDataType outtype, hipStream_t& stream) { + int flag{0}; hipblasLtMatrixLayout_t matA, matB, matC; hipblasLtMatmulDesc_t matmul; if (op_A == HIPBLAS_OP_N) { @@ -153,50 +132,38 @@ std::vector hipblasLtMatmul_findallsols_wrapper( CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, n, k, ldb)); } CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, outtype, m, n, ldc)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); + CHECK_HIPBLAS_ERROR( + hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t))); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(int32_t))); - - //std::vector heuristicResult(10); - //CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic( - // handle, matmul, matA, matB, matC, matC, - // preference, 10, heuristicResult.data(), &returnedAlgoCount)); + + // std::vector heuristicResult(10); + // CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic( + // handle, matmul, matA, matB, matC, matC, + // preference, 10, heuristicResult.data(), &returnedAlgoCount)); std::vector heuristicResult; - CHECK_HIPBLAS_ERROR(hipblaslt_ext::getAllAlgos(handle, hipblaslt_ext::GemmType::HIPBLASLT_GEMM, - op_A, - op_B, - intype, - intype, - outtype, - outtype, - HIPBLAS_COMPUTE_32F, - heuristicResult)); + CHECK_HIPBLAS_ERROR(hipblaslt_ext::getAllAlgos( + handle, hipblaslt_ext::GemmType::HIPBLASLT_GEMM, op_A, op_B, intype, + intype, outtype, outtype, HIPBLAS_COMPUTE_32F, heuristicResult)); std::vector algoIndex; int returned_algo_count = heuristicResult.size(); - //for (int i = 0; i < returnedAlgoCount; i++) { + // for (int i = 0; i < returnedAlgoCount; i++) { for (int i = 0; i < returned_algo_count; i++) { - auto algo = heuristicResult[i].algo; - size_t ret_workspace_size = 0; - auto status = hipblaslt_ext::matmulIsAlgoSupported(handle, matmul, - alpha, - matA, - matB, - beta, - matC, - matC, - algo, - ret_workspace_size - ); - if (status == HIPBLAS_STATUS_SUCCESS) { - if (ret_workspace_size hipblasLtMatmul_findallsols_wrapper( ///////////////////////////////////////////////////////////////////////////////////////////////////////// /** * hipBLASLt GEMM call -*/ + */ hipblasStatus_t hipblasLtMatmul_sol_wrapper( - hipblasLtHandle_t handle, - hipblasOperation_t op_A, - hipblasOperation_t op_B, - int m, int n, int k, - const void *alpha, - const void *a, - int lda, - const void *scaleA, - const void *b, - int ldb, - const void *scaleB, - const void *beta, - void *c, - int ldc, - const void *scaleC, - hipDataType intype, - hipDataType outtype, - hipStream_t &stream, - int solution_index=-1) -{ + hipblasLtHandle_t handle, hipblasOperation_t op_A, hipblasOperation_t op_B, + int m, int n, int k, const void* alpha, const void* a, int lda, + const void* scaleA, const void* b, int ldb, const void* scaleB, + const void* beta, void* c, int ldc, const void* scaleC, hipDataType intype, + hipDataType outtype, hipStream_t& stream, int solution_index = -1) { // TODO: flag is not supported for hipblasLt yet - int flag { 0 }; - //if (dtype == HIPBLAS_R_16F) { - // use fp16 alt impl for MI200 - // https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices - //flag = rocblas_gemm_flags_fp16_alt_impl; + int flag{0}; + // if (dtype == HIPBLAS_R_16F) { + // use fp16 alt impl for MI200 + // https://pytorch.org/docs/stable/notes/numerical_accuracy.html#reduced-precision-fp16-and-bf16-gemms-and-convolutions-on-amd-instinct-mi200-devices + // flag = rocblas_gemm_flags_fp16_alt_impl; //} - //nvtxRangePushA("hipBLASLt variables creation"); + // nvtxRangePushA("hipBLASLt variables creation"); hipblasLtMatrixLayout_t matA, matB, matC; hipblasLtMatmulDesc_t matmul; if (op_A == HIPBLAS_OP_N) { @@ -250,7 +202,8 @@ hipblasStatus_t hipblasLtMatmul_sol_wrapper( CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, n, k, ldb)); } CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, outtype, m, n, ldc)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); + CHECK_HIPBLAS_ERROR( + hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t))); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( @@ -261,188 +214,195 @@ hipblasStatus_t hipblasLtMatmul_sol_wrapper( matmul, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scaleB, sizeof(scaleB))); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, &scaleC, sizeof(scaleC))); - //nvtxRangePop(); - // if heuristic does not exist in the map, do search and push into the map - //auto gemm_key { MatMulConfig { op_A, op_B, m, n, k, dtype } }; - //if (heuristic_map.count(gemm_key) <= 0) { + // nvtxRangePop(); + // if heuristic does not exist in the map, do search and push into the map + // auto gemm_key { MatMulConfig { op_A, op_B, m, n, k, dtype } }; + // if (heuristic_map.count(gemm_key) <= 0) { std::vector heuristicResult(1); - if (solution_index<0) { - //nvtxRangePushA("hipblasLtMatmulAlgoGetHeuristic"); - std::cout << "Warning! HipbSolId Gemm Fallback Path used for solution index <0" << std::endl; + if (solution_index < 0) { + // nvtxRangePushA("hipblasLtMatmulAlgoGetHeuristic"); + std::cout + << "Warning! HipbSolId Gemm Fallback Path used for solution index <0" + << std::endl; if (cout_print) { - std::cout << (op_A == HIPBLAS_OP_N ? "N" : "T") << (op_B == HIPBLAS_OP_N ? "N" : "T") - << " (" << m << ", " << n << ", " << k << "), dtype: " << intype - << ", (lda, ldb, ldc): (" << lda << ", " << ldb << ", " << ldc << "), " << std::endl; + std::cout << (op_A == HIPBLAS_OP_N ? "N" : "T") + << (op_B == HIPBLAS_OP_N ? "N" : "T") << " (" << m << ", " << n + << ", " << k << "), dtype: " << intype << ", (lda, ldb, ldc): (" + << lda << ", " << ldb << ", " << ldc << "), " << std::endl; } - //std::vector heuristicResult(request_solutions); + // std::vector + // heuristicResult(request_solutions); CHECK_HIPBLAS_ERROR(hipblasLtMatmulAlgoGetHeuristic( - handle, matmul, matA, matB, matC, matC, - preference, request_solutions, heuristicResult.data(), &returnedAlgoCount)); - if((returnedAlgoCount != request_solutions) && cout_print) { + handle, matmul, matA, matB, matC, matC, preference, request_solutions, + heuristicResult.data(), &returnedAlgoCount)); + if ((returnedAlgoCount != request_solutions) && cout_print) { std::cout << "less solution found! request: " << request_solutions << ", found: " << returnedAlgoCount << std::endl; } - //heuristic_map[gemm_key] = heuristicResult[0]; -/* - if (returnedAlgoCount == 1) { - heuristic_map[gemm_key] = heuristicResult[0]; - } else { - // benchmark requested solutions and pick best one - int bestIndex { -1 }; - double bestMs { std::numeric_limits::max() }; - for (int sol { 0 }; sol < returnedAlgoCount; ++sol) { - // warm up - for (int iter { 0 }; iter < warmup_iters; ++iter) { - CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, - alpha, - a, matA, - b, matB, - beta, - c, matC, - c, matC, // In case beta != 0, these runs can overwrite the values in c - // since c and d are the same - // TODO: allocates separate d memory for these runs - &heuristicResult[sol].algo, - d_workspace, workspace_size, - stream)); - } - // performance measuring - double eventMs; - CHECK_HIP_ERROR(hipEventRecord(start, stream)); - for (int iter { 0 }; iter < bench_iters; ++iter) { - CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, - alpha, - a, matA, - b, matB, - beta, - c, matC, - c, matC, // In case beta != 0, these runs can overwrite the values in c - // since c and d are the same - // TODO: allocates separate d memory for these runs - &heuristicResult[sol].algo, - d_workspace, workspace_size, - stream)); - } - CHECK_HIP_ERROR(hipEventRecord(stop, stream)); - CHECK_HIP_ERROR(hipEventSynchronize(stop)); - float temp; - CHECK_HIP_ERROR(hipEventElapsedTime(&temp, start, stop)); - eventMs = double(temp); - eventMs /= bench_iters; - - if (cout_print) { - std::cout << " Sol " << sol << ": average time per iter " << std::to_string(eventMs) << " ms"; - } - if (bestMs > eventMs) { - bestMs = eventMs; - bestIndex = sol; - if (cout_print) { - std::cout << " *" << std::endl; - } + // heuristic_map[gemm_key] = heuristicResult[0]; + /* + if (returnedAlgoCount == 1) { + heuristic_map[gemm_key] = heuristicResult[0]; } else { - if (cout_print) { - std::cout << std::endl; + // benchmark requested solutions and pick best one + int bestIndex { -1 }; + double bestMs { std::numeric_limits::max() }; + for (int sol { 0 }; sol < returnedAlgoCount; ++sol) { + // warm up + for (int iter { 0 }; iter < warmup_iters; ++iter) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, + alpha, + a, matA, + b, matB, + beta, + c, matC, + c, matC, // In case beta != 0, these runs can overwrite the + values in c + // since c and d are the same + // TODO: allocates separate d memory for these runs + &heuristicResult[sol].algo, + d_workspace, workspace_size, + stream)); + } + // performance measuring + double eventMs; + CHECK_HIP_ERROR(hipEventRecord(start, stream)); + for (int iter { 0 }; iter < bench_iters; ++iter) { + CHECK_HIPBLAS_ERROR(hipblasLtMatmul(handle, matmul, + alpha, + a, matA, + b, matB, + beta, + c, matC, + c, matC, // In case beta != 0, these runs can overwrite the + values in c + // since c and d are the same + // TODO: allocates separate d memory for these runs + &heuristicResult[sol].algo, + d_workspace, workspace_size, + stream)); + } + CHECK_HIP_ERROR(hipEventRecord(stop, stream)); + CHECK_HIP_ERROR(hipEventSynchronize(stop)); + float temp; + CHECK_HIP_ERROR(hipEventElapsedTime(&temp, start, stop)); + eventMs = double(temp); + eventMs /= bench_iters; + + if (cout_print) { + std::cout << " Sol " << sol << ": average time per iter " << + std::to_string(eventMs) << " ms"; + } + if (bestMs > eventMs) { + bestMs = eventMs; + bestIndex = sol; + if (cout_print) { + std::cout << " *" << std::endl; + } + } else { + if (cout_print) { + std::cout << std::endl; + } + } } + heuristic_map[gemm_key] = heuristicResult[bestIndex]; } - } - heuristic_map[gemm_key] = heuristicResult[bestIndex]; - } -*/ - //nvtxRangePop(); + */ + // nvtxRangePop(); } else { - std::vector algoIndex(1); - algoIndex[0]=solution_index; - //std::vector tmpAlgo; - CHECK_HIPBLAS_ERROR(hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, heuristicResult)); + std::vector algoIndex(1); + algoIndex[0] = solution_index; + // std::vector tmpAlgo; + CHECK_HIPBLAS_ERROR( + hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, heuristicResult)); } - - //size_t ret_workspace_size = 0; - - //auto status1 = hipblaslt_ext::matmulIsAlgoSupported(handle, matmul, - // alpha, - // matA, - // matB, - // beta, - // matC, - // matC, - // heuristicResult[0].algo, - // ret_workspace_size + + // size_t ret_workspace_size = 0; + + // auto status1 = hipblaslt_ext::matmulIsAlgoSupported(handle, matmul, + // alpha, + // matA, + // matB, + // beta, + // matC, + // matC, + // heuristicResult[0].algo, + // ret_workspace_size //); - //if (status1 == HIPBLAS_STATUS_SUCCESS) { - // std::cout << "Workspace size" << ret_workspace_size << std::endl; + // if (status1 == HIPBLAS_STATUS_SUCCESS) { + // std::cout << "Workspace size" << ret_workspace_size << std::endl; //} else { - // std::cout << "Algo not supported!!!" << std::endl; + // std::cout << "Algo not supported!!!" << std::endl; //} - hipblasStatus_t status = hipblasLtMatmul(handle, matmul, - alpha, - a, matA, - b, matB, - beta, - c, matC, - c, matC, - &heuristicResult[0].algo, - d_workspace, workspace_size, - stream); - - //nvtxRangePushA("hipBLASLt variables deletion"); + hipblasStatus_t status = hipblasLtMatmul( + handle, matmul, alpha, a, matA, b, matB, beta, c, matC, c, matC, + &heuristicResult[0].algo, d_workspace, workspace_size, stream); + + // nvtxRangePushA("hipBLASLt variables deletion"); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescDestroy(matmul)); CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matA)); CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matB)); CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutDestroy(matC)); - //nvtxRangePop(); + // nvtxRangePop(); return status; } ///////////////////////////////////////////////////////////////////////////////////////////////////////// torch::Tensor HipbSolIdxBlas( - const torch::Tensor& mat1, - const torch::Tensor& mat2, - const int solution_index, - at::optional Type = at::nullopt, + const torch::Tensor& mat1, const torch::Tensor& mat2, + const int solution_index, at::optional Type = at::nullopt, at::optional scale1 = at::nullopt, at::optional scale2 = at::nullopt, - at::optional scaleOut = at::nullopt - ) -{ - auto mat1_strides { mat1.strides() }; - auto mat2_strides { mat2.strides() }; - auto mat1_sizes { mat1.sizes() }; - auto mat2_sizes { mat2.sizes() }; - // std::cout << " | mat1 info: size: " << mat1_sizes << " stride: " << mat1_strides << std::endl - // << " | mat2 info: size: " << mat2_sizes << " stride: " << mat2_strides << std::endl; + at::optional scaleOut = at::nullopt) { + auto mat1_strides{mat1.strides()}; + auto mat2_strides{mat2.strides()}; + auto mat1_sizes{mat1.sizes()}; + auto mat2_sizes{mat2.sizes()}; + // std::cout << " | mat1 info: size: " << mat1_sizes << " stride: " << + // mat1_strides << std::endl + // << " | mat2 info: size: " << mat2_sizes << " stride: " << + // mat2_strides << std::endl; TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); - TORCH_CHECK( - mat1.dtype() == mat2.dtype(), - "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() - ); - TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); + TORCH_CHECK(mat1.dtype() == mat2.dtype(), + "expected mat1 and mat2 to have the same dtype, but got: ", + mat1.dtype(), " != ", mat2.dtype()); + TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], + "mat1 dim 1 must match mat2 dim 0"); - auto inType { mat1.options().dtype() }; + auto inType{mat1.options().dtype()}; auto outType = inType.toScalarType(); - if (Type.has_value()) outType = torch::python::detail::py_object_to_dtype(Type.value()); - auto options { at::TensorOptions().dtype(outType).device(at::kCUDA) }; - auto result { torch::empty({ mat1_sizes[0], mat2_sizes[1] }, options) }; - // std::cout << " | result info: size: " << result.sizes() << " stride: " << result.strides() << std::endl; + if (Type.has_value()) + outType = torch::python::detail::py_object_to_dtype(Type.value()); + auto options{at::TensorOptions().dtype(outType).device(at::kCUDA)}; + auto result{torch::empty({mat1_sizes[0], mat2_sizes[1]}, options)}; + // std::cout << " | result info: size: " << result.sizes() << " stride: " << + // result.strides() << std::endl; bool transpose_result = true; bool transpose_mat1; bool transpose_mat2; - if ((mat2_strides[0] == 1) && (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { + if ((mat2_strides[0] == 1) && + (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { transpose_mat2 = false; - } else if ((mat2_strides[1] == 1) && (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { + } else if ((mat2_strides[1] == 1) && + (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { transpose_mat2 = true; } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); } - if ((mat1_strides[0] == 1) && (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { + if ((mat1_strides[0] == 1) && + (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { transpose_mat1 = false; - } else if ((mat1_strides[1] == 1) && (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { + } else if ((mat1_strides[1] == 1) && + (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { transpose_mat1 = true; } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); } if (transpose_result) { @@ -454,14 +414,19 @@ torch::Tensor HipbSolIdxBlas( mat1_sizes = mat2.sizes(); mat2_sizes = mat1.sizes(); } - // std::cout << " | transpose_result: " << (transpose_result ? "true" : "false") << std::endl - // << " | transpose_A: " << (transpose_mat1 ? "true" : "false") << std::endl - // << " | transpose_B: " << (transpose_mat2 ? "true" : "false") << std::endl; - // std::cout << " | A matrix: size: " << mat1_sizes << " stride: " << mat1_strides << std::endl - // << " | B matrix: size: " << mat2_sizes << " stride: " << mat2_strides << std::endl; - - float one { 1.0f }; - float zero { 0.0f }; + // std::cout << " | transpose_result: " << (transpose_result ? "true" : + // "false") << std::endl + // << " | transpose_A: " << (transpose_mat1 ? "true" : "false") << + // std::endl + // << " | transpose_B: " << (transpose_mat2 ? "true" : "false") << + // std::endl; + // std::cout << " | A matrix: size: " << mat1_sizes << " stride: " << + // mat1_strides << std::endl + // << " | B matrix: size: " << mat2_sizes << " stride: " << + // mat2_strides << std::endl; + + float one{1.0f}; + float zero{0.0f}; int64_t m = mat1_sizes[transpose_result ? 1 : 0]; int64_t k = mat1_sizes[transpose_result ? 0 : 1]; int64_t n = mat2_sizes[transpose_result ? 0 : 1]; @@ -469,7 +434,7 @@ torch::Tensor HipbSolIdxBlas( int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; int64_t result_ld = result.stride(transpose_result ? 0 : 1); - void * d_scale1 = nullptr, * d_scale2 = nullptr, * d_scaleOut = nullptr; + void *d_scale1 = nullptr, *d_scale2 = nullptr, *d_scaleOut = nullptr; if (scale1.has_value()) { d_scale1 = static_cast(scale1.value().data_ptr()); } @@ -480,7 +445,6 @@ torch::Tensor HipbSolIdxBlas( d_scaleOut = static_cast(scaleOut.value().data_ptr()); } - hipDataType hipblasInType, hipblasOutType; if (inType == at::kHalf) { hipblasInType = HIP_R_16F; @@ -505,68 +469,65 @@ torch::Tensor HipbSolIdxBlas( } else { assert(false && "Wrong datatype!"); } - void *ptrA { static_cast((transpose_result ? mat2 : mat1).data_ptr()) }; - void *ptrB { static_cast((transpose_result ? mat1 : mat2).data_ptr()) }; - void *ptrC { static_cast(result.data_ptr()) }; + void* ptrA{static_cast((transpose_result ? mat2 : mat1).data_ptr())}; + void* ptrB{static_cast((transpose_result ? mat1 : mat2).data_ptr())}; + void* ptrC{static_cast(result.data_ptr())}; if (transpose_result) std::swap(d_scale1, d_scale2); - auto current_stream { torch::hip::getCurrentHIPStream().stream() }; - + auto current_stream{torch::hip::getCurrentHIPStream().stream()}; + CHECK_HIPBLAS_ERROR(hipblasLtMatmul_sol_wrapper( - hipblaslt_handle, - transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, - transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, - m, n, k, - &one, - ptrA, mat1_ld, d_scale1, - ptrB, mat2_ld, d_scale2, - &zero, - ptrC, result_ld, d_scaleOut, - hipblasInType, - hipblasOutType, - current_stream,solution_index)); + hipblaslt_handle, transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k, &one, ptrA, + mat1_ld, d_scale1, ptrB, mat2_ld, d_scale2, &zero, ptrC, result_ld, + d_scaleOut, hipblasInType, hipblasOutType, current_stream, + solution_index)); return result; } -//find all hipblas solutions and return them to python land +// find all hipblas solutions and return them to python land std::vector HipbFindAllSolIdxBlas( - const torch::Tensor& mat1, - const torch::Tensor& mat2, - at::optional Type = at::nullopt - ) -{ - auto mat1_strides { mat1.strides() }; - auto mat2_strides { mat2.strides() }; - auto mat1_sizes { mat1.sizes() }; - auto mat2_sizes { mat2.sizes() }; + const torch::Tensor& mat1, const torch::Tensor& mat2, + at::optional Type = at::nullopt) { + auto mat1_strides{mat1.strides()}; + auto mat2_strides{mat2.strides()}; + auto mat1_sizes{mat1.sizes()}; + auto mat2_sizes{mat2.sizes()}; TORCH_CHECK(mat1.dim() == 2 && mat2.dim() == 2, "tensors must be 2-D"); - TORCH_CHECK( - mat1.dtype() == mat2.dtype(), - "expected mat1 and mat2 to have the same dtype, but got: ", mat1.dtype(), " != ", mat2.dtype() - ); - TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); + TORCH_CHECK(mat1.dtype() == mat2.dtype(), + "expected mat1 and mat2 to have the same dtype, but got: ", + mat1.dtype(), " != ", mat2.dtype()); + TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], + "mat1 dim 1 must match mat2 dim 0"); - auto inType { mat1.options().dtype() }; + auto inType{mat1.options().dtype()}; auto outType = inType.toScalarType(); - if (Type.has_value()) outType = torch::python::detail::py_object_to_dtype(Type.value()); - auto options { at::TensorOptions().dtype(outType).device(at::kCUDA) }; - auto result { torch::empty({ mat1_sizes[0], mat2_sizes[1] }, options) }; + if (Type.has_value()) + outType = torch::python::detail::py_object_to_dtype(Type.value()); + auto options{at::TensorOptions().dtype(outType).device(at::kCUDA)}; + auto result{torch::empty({mat1_sizes[0], mat2_sizes[1]}, options)}; bool transpose_result = true; bool transpose_mat1; bool transpose_mat2; - if ((mat2_strides[0] == 1) && (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { + if ((mat2_strides[0] == 1) && + (mat2_strides[1] >= std::max(1, mat2_sizes[0]))) { transpose_mat2 = false; - } else if ((mat2_strides[1] == 1) && (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { + } else if ((mat2_strides[1] == 1) && + (mat2_strides[0] >= std::max(1, mat2_sizes[1]))) { transpose_mat2 = true; } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); } - if ((mat1_strides[0] == 1) && (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { + if ((mat1_strides[0] == 1) && + (mat1_strides[1] >= std::max(1, mat1_sizes[0]))) { transpose_mat1 = false; - } else if ((mat1_strides[1] == 1) && (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { + } else if ((mat1_strides[1] == 1) && + (mat1_strides[0] >= std::max(1, mat1_sizes[1]))) { transpose_mat1 = true; } else { - assert(false && "unusual strides detected, may need to clone a contiguous tensor"); + assert(false && + "unusual strides detected, may need to clone a contiguous tensor"); } if (transpose_result) { bool tmp = transpose_mat1; @@ -577,8 +538,8 @@ std::vector HipbFindAllSolIdxBlas( mat1_sizes = mat2.sizes(); mat2_sizes = mat1.sizes(); } - float one { 1.0f }; - float zero { 0.0f }; + float one{1.0f}; + float zero{0.0f}; int64_t m = mat1_sizes[transpose_result ? 1 : 0]; int64_t k = mat1_sizes[transpose_result ? 0 : 1]; int64_t n = mat2_sizes[transpose_result ? 0 : 1]; @@ -608,72 +569,59 @@ std::vector HipbFindAllSolIdxBlas( } else { assert(false && "Wrong datatype!"); } - void *ptrA { static_cast((transpose_result ? mat2 : mat1).data_ptr()) }; - void *ptrB { static_cast((transpose_result ? mat1 : mat2).data_ptr()) }; - void *ptrC { static_cast(result.data_ptr()) }; - auto current_stream { torch::hip::getCurrentHIPStream().stream() }; + void* ptrA{static_cast((transpose_result ? mat2 : mat1).data_ptr())}; + void* ptrB{static_cast((transpose_result ? mat1 : mat2).data_ptr())}; + void* ptrC{static_cast(result.data_ptr())}; + auto current_stream{torch::hip::getCurrentHIPStream().stream()}; return hipblasLtMatmul_findallsols_wrapper( - hipblaslt_handle, - transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, - transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, - m, n, k, - &one, - ptrA, mat1_ld, - ptrB, mat2_ld, - &zero, - ptrC, result_ld, - hipblasInType, - hipblasOutType, - current_stream); - + hipblaslt_handle, transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, + transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k, &one, ptrA, + mat1_ld, ptrB, mat2_ld, &zero, ptrC, result_ld, hipblasInType, + hipblasOutType, current_stream); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// -void hipb_create_extension() -{ - //CHECK_HIP_ERROR(hipStreamCreate(&weight_stream)); - //CHECK_HIP_ERROR(hipEventCreateWithFlags(&event, cudaEventDisableTiming)); +void hipb_create_extension() { + // CHECK_HIP_ERROR(hipStreamCreate(&weight_stream)); + // CHECK_HIP_ERROR(hipEventCreateWithFlags(&event, cudaEventDisableTiming)); // hipBLASLt CHECK_HIPBLAS_ERROR(hipblasLtCreate(&hipblaslt_handle)); CHECK_HIP_ERROR(hipMalloc(&d_workspace, workspace_size)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceCreate(&preference)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceSetAttribute( - preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size))); - - //CHECK_HIP_ERROR(hipEventCreate(&start)); - //CHECK_HIP_ERROR(hipEventCreate(&stop)); + preference, HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, + sizeof(workspace_size))); + + // CHECK_HIP_ERROR(hipEventCreate(&start)); + // CHECK_HIP_ERROR(hipEventCreate(&stop)); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// -void hipb_destroy_extension() -{ - //CHECK_HIP_ERROR(hipStreamDestroy(weight_stream)); - //CHECK_HIP_ERROR(hipEventDestroy(event)); +void hipb_destroy_extension() { + // CHECK_HIP_ERROR(hipStreamDestroy(weight_stream)); + // CHECK_HIP_ERROR(hipEventDestroy(event)); - // hipBLASLt - CHECK_HIPBLAS_ERROR(hipblasLtDestroy(hipblaslt_handle)); - CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceDestroy(preference)); - CHECK_HIP_ERROR(hipFree(d_workspace)); + // hipBLASLt + CHECK_HIPBLAS_ERROR(hipblasLtDestroy(hipblaslt_handle)); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulPreferenceDestroy(preference)); + CHECK_HIP_ERROR(hipFree(d_workspace)); - //CHECK_HIP_ERROR(hipEventDestroy(start)); - //CHECK_HIP_ERROR(hipEventDestroy(stop)); + // CHECK_HIP_ERROR(hipEventDestroy(start)); + // CHECK_HIP_ERROR(hipEventDestroy(stop)); } ///////////////////////////////////////////////////////////////////////////////////////////////////////// -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) -{ +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("hipb_create_extension", &hipb_create_extension, "create_extension"); m.def("hipb_destroy_extension", &hipb_destroy_extension, "destroy_extension"); - m.def("hipb_mm", &HipbSolIdxBlas, "mm", - py::arg("mat1"), py::arg("mat2"), - py::arg("solution_index"), - py::arg("outType")=at::nullopt, - py::arg("scale1")=at::nullopt, py::arg("scale2")=at::nullopt, py::arg("scaleOut")=at::nullopt); - m.def("hipb_findallsols", &HipbFindAllSolIdxBlas, "hipblas_find_all_sols", - py::arg("mat1"), py::arg("mat2"), - py::arg("outType")=at::nullopt); + m.def("hipb_mm", &HipbSolIdxBlas, "mm", py::arg("mat1"), py::arg("mat2"), + py::arg("solution_index"), py::arg("outType") = at::nullopt, + py::arg("scale1") = at::nullopt, py::arg("scale2") = at::nullopt, + py::arg("scaleOut") = at::nullopt); + m.def("hipb_findallsols", &HipbFindAllSolIdxBlas, "hipblas_find_all_sols", + py::arg("mat1"), py::arg("mat2"), py::arg("outType") = at::nullopt); } \ No newline at end of file diff --git a/gradlib/gradlib/fp8_gemm_tuner.py b/gradlib/gradlib/fp8_gemm_tuner.py index c4d18d8952d47..61df1933f8658 100644 --- a/gradlib/gradlib/fp8_gemm_tuner.py +++ b/gradlib/gradlib/fp8_gemm_tuner.py @@ -5,7 +5,6 @@ from pathlib import Path import hipbsolidxgemm -import numpy as np import pandas as pd import torch import torch.nn.functional as F @@ -29,7 +28,8 @@ def __init__(self, m, n, k, indtype, outdtype): device='cuda').to(self.indtype) self.weights = torch.randn((self.m, self.k), device='cuda').to(self.indtype) - #weights2 is used in measurement/warm iters to ensure HBM fetch for weight tensors + # weights2 is used in measurement/warm iters to ensure HBM + # fetch for weight tensors self.weights2 = torch.randn((self.nb, self.m, self.k), device='cuda').to(self.indtype) self.blob = torch.ones(128 * 1024 * 1024, @@ -72,11 +72,11 @@ def check_gemm_ref(self, libtype, solidx): def hipb_time_sol(self, solidx, cold_iters=2, warm_iters=10): #print('>>>hipbtime',solidx) for i in range(cold_iters): - c = hipbsolidxgemm.hipb_mm(self.inp, self.weights.t(), solidx, - self.outdtype) + hipbsolidxgemm.hipb_mm(self.inp, self.weights.t(), solidx, + self.outdtype) self.start.record() for i in range(warm_iters): - c = hipbsolidxgemm.hipb_mm( + hipbsolidxgemm.hipb_mm( self.inp, self.weights2[random.randint(0, self.nb - 1)].t(), solidx, self.outdtype) self.end.record() @@ -92,7 +92,8 @@ def hipb_time_all_sols(self, fast_mode=0, top_sols=0): coldi = 2 warmi = 2 solutions = self.hipb_sols - if top_sols: solutions = self.hipb_top_sols + if top_sols: + solutions = self.hipb_top_sols gtimes = {} for solidx in solutions: gtimes[solidx] = self.hipb_time_sol(solidx, @@ -184,8 +185,8 @@ def find_best_sols(self): def generate_mk_sets(model_dir, tp=1): - f = open(f'{model_dir}/config.json') - data = json.load(f) + with open(f'{model_dir}/config.json') as f: + data = json.load(f) hidden_size = data['hidden_size'] intermediate_size = data['intermediate_size'] total_num_heads = data['num_attention_heads']