diff --git a/ROCm_performance.md b/ROCm_performance.md
index 180c848a21950..83c4e05b941a9 100644
--- a/ROCm_performance.md
+++ b/ROCm_performance.md
@@ -18,3 +18,24 @@ Define the following environment symbol: `PYTORCH_TUNABLEOP_ENABLED=1` in order
 On ROCm, to have better performance, a custom paged attention is available by switching on the env variable: `VLLM_USE_ROCM_CUSTOM_PAGED_ATTN=1`.
 Currently, this env variable is enabled by default. To fallback to PagedAttention v2 kernel assign the env variable to 0.
 The custom PagedAttention kernel is enabled for dtype: fp16, block-size=16, head-size=128, and max context length <= 16k, with GQA ratio (num_heads//num_kv_heads) between 1 to 16. On all the other cases, we fallback to PagedAttention v2 kernel.
+
+## Fp8 Quantization
+
+To use fp8 quantization, the first step is to quantize your model to the fp8 format. Please follow these [instructions](https://github.com/ROCm/vllm/tree/main/examples/fp8/quantizer) to generate a safetensors file that contains the quantized weights and the corresponding scaling factors of your model. The safetensors file should be placed under your model folder.
+
+Then we can run the model with fp8 quantization using vLLM. When creating the `vllm.LLM` object, two additional parameters should be passed: `quantization="fp8"` and `quantization_param_path={relative path of the safetensors file within your model folder}`.
+
+## Gemm Tuning for Fp8
+
+To get better fp8 performance, the GEMMs need to be tuned with the shapes encountered during the execution of the model.
+
+To obtain all the GEMM shapes used during the execution of the model, set the env variable `TUNE_FP8=1` and then run the model as usual. This produces a file called `/tmp/fp8_shapes.csv`.
+
+Next, run gradlib to obtain the best solutions for these shapes:
+
+```
+python3 gradlib/gradlib/fp8_gemm_tunner.py --input_file /tmp/fp8_shapes.csv --tuned_file /tmp/tuned_fp8_16.csv
+```
+where `/tmp/tuned_fp8_16.csv` will be used by our fp8 GEMM linear layer.
+
+Now, when running inference with fp8, the tuned GEMMs are used for best performance.
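As a quick illustration of the two parameters described above, here is a minimal, hedged sketch of launching vLLM with an fp8-quantized model; the model folder and safetensors filename are placeholders, not part of this patch:

```python
# Hypothetical example of the fp8 path described in ROCm_performance.md.
# "./my-model-fp8" and "fp8_scales.safetensors" are placeholder names.
from vllm import LLM, SamplingParams

llm = LLM(
    model="./my-model-fp8",                             # folder that also holds the quantized safetensors file
    quantization="fp8",                                 # select the fp8 linear method
    quantization_param_path="fp8_scales.safetensors",   # quantized weights + scaling factors, relative to the model folder
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```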
\ No newline at end of file diff --git a/gradlib/csrc/hipbsolgemm.cu b/gradlib/csrc/hipbsolgemm.cu index bf15fb1297667..e5b7cbc7ea43f 100644 --- a/gradlib/csrc/hipbsolgemm.cu +++ b/gradlib/csrc/hipbsolgemm.cu @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -115,6 +116,8 @@ namespace { bool cout_print = false; + torch::Tensor dTensor; + //std::vector heuristicResult; } @@ -132,23 +135,24 @@ std::vector hipblasLtMatmul_findallsols_wrapper( const void *beta, void *c, int ldc, - hipDataType dtype, + hipDataType intype, + hipDataType outtype, hipStream_t &stream) { int flag { 0 }; hipblasLtMatrixLayout_t matA, matB, matC; hipblasLtMatmulDesc_t matmul; if (op_A == HIPBLAS_OP_N) { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, m, k, lda)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, m, k, lda)); } else { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, k, m, lda)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, k, m, lda)); } if (op_B == HIPBLAS_OP_N) { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, k, n, ldb)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, k, n, ldb)); } else { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, n, k, ldb)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, n, k, ldb)); } - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, dtype, m, n, ldc)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, outtype, m, n, ldc)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t))); @@ -163,10 +167,10 @@ std::vector hipblasLtMatmul_findallsols_wrapper( CHECK_HIPBLAS_ERROR(hipblaslt_ext::getAllAlgos(handle, hipblaslt_ext::GemmType::HIPBLASLT_GEMM, op_A, op_B, - dtype, - dtype, - dtype, - dtype, + intype, + intype, + outtype, + outtype, HIPBLAS_COMPUTE_32F, heuristicResult)); @@ -211,12 +215,16 @@ hipblasStatus_t hipblasLtMatmul_sol_wrapper( const void *alpha, const void *a, int lda, + const void *scaleA, const void *b, int ldb, + const void *scaleB, const void *beta, void *c, int ldc, - hipDataType dtype, + const void *scaleC, + hipDataType intype, + hipDataType outtype, hipStream_t &stream, int solution_index=-1) { @@ -232,21 +240,27 @@ hipblasStatus_t hipblasLtMatmul_sol_wrapper( hipblasLtMatrixLayout_t matA, matB, matC; hipblasLtMatmulDesc_t matmul; if (op_A == HIPBLAS_OP_N) { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, m, k, lda)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, m, k, lda)); } else { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, dtype, k, m, lda)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matA, intype, k, m, lda)); } if (op_B == HIPBLAS_OP_N) { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, k, n, ldb)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, k, n, ldb)); } else { - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, dtype, n, k, ldb)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matB, intype, n, k, ldb)); } - CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, dtype, m, n, ldc)); + CHECK_HIPBLAS_ERROR(hipblasLtMatrixLayoutCreate(&matC, outtype, m, n, ldc)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLAS_COMPUTE_32F, HIP_R_32F)); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, 
HIPBLASLT_MATMUL_DESC_TRANSA, &op_A, sizeof(int32_t))); CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( matmul, HIPBLASLT_MATMUL_DESC_TRANSB, &op_B, sizeof(int32_t))); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_A_SCALE_POINTER, &scaleA, sizeof(scaleA))); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_B_SCALE_POINTER, &scaleB, sizeof(scaleB))); + CHECK_HIPBLAS_ERROR(hipblasLtMatmulDescSetAttribute( + matmul, HIPBLASLT_MATMUL_DESC_D_SCALE_POINTER, &scaleC, sizeof(scaleC))); //nvtxRangePop(); // if heuristic does not exist in the map, do search and push into the map //auto gemm_key { MatMulConfig { op_A, op_B, m, n, k, dtype } }; @@ -257,7 +271,7 @@ hipblasStatus_t hipblasLtMatmul_sol_wrapper( std::cout << "Warning! HipbSolId Gemm Fallback Path used for solution index <0" << std::endl; if (cout_print) { std::cout << (op_A == HIPBLAS_OP_N ? "N" : "T") << (op_B == HIPBLAS_OP_N ? "N" : "T") - << " (" << m << ", " << n << ", " << k << "), dtype: " << dtype + << " (" << m << ", " << n << ", " << k << "), dtype: " << intype << ", (lda, ldb, ldc): (" << lda << ", " << ldb << ", " << ldc << "), " << std::endl; } //std::vector heuristicResult(request_solutions); @@ -385,7 +399,11 @@ hipblasStatus_t hipblasLtMatmul_sol_wrapper( torch::Tensor HipbSolIdxBlas( const torch::Tensor& mat1, const torch::Tensor& mat2, - const int solution_index + const int solution_index, + at::optional Type = at::nullopt, + at::optional scale1 = at::nullopt, + at::optional scale2 = at::nullopt, + at::optional scaleOut = at::nullopt ) { auto mat1_strides { mat1.strides() }; @@ -402,8 +420,10 @@ torch::Tensor HipbSolIdxBlas( ); TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); - auto abcType { mat1.options().dtype() }; - auto options { at::TensorOptions().dtype(abcType).device(at::kCUDA) }; + auto inType { mat1.options().dtype() }; + auto outType = inType.toScalarType(); + if (Type.has_value()) outType = torch::python::detail::py_object_to_dtype(Type.value()); + auto options { at::TensorOptions().dtype(outType).device(at::kCUDA) }; auto result { torch::empty({ mat1_sizes[0], mat2_sizes[1] }, options) }; // std::cout << " | result info: size: " << result.sizes() << " stride: " << result.strides() << std::endl; @@ -448,35 +468,61 @@ torch::Tensor HipbSolIdxBlas( int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0]; int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; int64_t result_ld = result.stride(transpose_result ? 
0 : 1); - // std::cout << " | (m, n, k): " << m << ", " << n << ", " << k << std::endl - // << " | (lda, ldb, ldc): " << mat1_ld << ", " << mat2_ld << ", " << result_ld << std::endl; + + void * d_scale1 = nullptr, * d_scale2 = nullptr, * d_scaleOut = nullptr; + if (scale1.has_value()) { + d_scale1 = static_cast(scale1.value().data_ptr()); + } + if (scale2.has_value()) { + d_scale2 = static_cast(scale2.value().data_ptr()); + } + if (scaleOut.has_value()) { + d_scaleOut = static_cast(scaleOut.value().data_ptr()); + } + - hipDataType hipblasType; - if (abcType == at::kHalf) { - hipblasType = HIP_R_16F; - } else if (abcType == at::kBFloat16) { - hipblasType = HIP_R_16BF; - } else if (abcType == at::kFloat) { - hipblasType = HIP_R_32F; + hipDataType hipblasInType, hipblasOutType; + if (inType == at::kHalf) { + hipblasInType = HIP_R_16F; + } else if (inType == at::kBFloat16) { + hipblasInType = HIP_R_16BF; + } else if (inType == at::kFloat) { + hipblasInType = HIP_R_32F; + } else if (inType == at::kFloat8_e4m3fnuz) { + hipblasInType = HIP_R_8F_E4M3_FNUZ; + } else { + assert(false && "Wrong datatype!"); + } + + if (outType == at::kHalf) { + hipblasOutType = HIP_R_16F; + } else if (outType == at::kBFloat16) { + hipblasOutType = HIP_R_16BF; + } else if (outType == at::kFloat) { + hipblasOutType = HIP_R_32F; + } else if (outType == at::kFloat8_e4m3fnuz) { + hipblasOutType = HIP_R_8F_E4M3_FNUZ; } else { assert(false && "Wrong datatype!"); } void *ptrA { static_cast((transpose_result ? mat2 : mat1).data_ptr()) }; void *ptrB { static_cast((transpose_result ? mat1 : mat2).data_ptr()) }; void *ptrC { static_cast(result.data_ptr()) }; + if (transpose_result) std::swap(d_scale1, d_scale2); auto current_stream { torch::hip::getCurrentHIPStream().stream() }; - + CHECK_HIPBLAS_ERROR(hipblasLtMatmul_sol_wrapper( hipblaslt_handle, transpose_mat1 ? HIPBLAS_OP_T : HIPBLAS_OP_N, transpose_mat2 ? HIPBLAS_OP_T : HIPBLAS_OP_N, m, n, k, &one, - ptrA, mat1_ld, - ptrB, mat2_ld, + ptrA, mat1_ld, d_scale1, + ptrB, mat2_ld, d_scale2, &zero, - ptrC, result_ld, - hipblasType, + ptrC, result_ld, d_scaleOut, + hipblasInType, + hipblasOutType, current_stream,solution_index)); return result; @@ -485,7 +531,8 @@ torch::Tensor HipbSolIdxBlas( //find all hipblas solutions and return them to python land std::vector HipbFindAllSolIdxBlas( const torch::Tensor& mat1, - const torch::Tensor& mat2 + const torch::Tensor& mat2, + at::optional Type = at::nullopt ) { auto mat1_strides { mat1.strides() }; @@ -499,8 +546,10 @@ std::vector HipbFindAllSolIdxBlas( ); TORCH_CHECK(mat1_sizes[1] == mat2_sizes[0], "mat1 dim 1 must match mat2 dim 0"); - auto abcType { mat1.options().dtype() }; - auto options { at::TensorOptions().dtype(abcType).device(at::kCUDA) }; + auto inType { mat1.options().dtype() }; + auto outType = inType.toScalarType(); + if (Type.has_value()) outType = torch::python::detail::py_object_to_dtype(Type.value()); + auto options { at::TensorOptions().dtype(outType).device(at::kCUDA) }; auto result { torch::empty({ mat1_sizes[0], mat2_sizes[1] }, options) }; bool transpose_result = true; bool transpose_mat1; @@ -536,13 +585,26 @@ std::vector HipbFindAllSolIdxBlas( int64_t mat1_ld = mat1_strides[(transpose_mat1 == transpose_result) ? 1 : 0]; int64_t mat2_ld = mat2_strides[(transpose_mat2 == transpose_result) ? 1 : 0]; int64_t result_ld = result.stride(transpose_result ? 
0 : 1); - hipDataType hipblasType; - if (abcType == at::kHalf) { - hipblasType = HIP_R_16F; - } else if (abcType == at::kBFloat16) { - hipblasType = HIP_R_16BF; - } else if (abcType == at::kFloat) { - hipblasType = HIP_R_32F; + hipDataType hipblasInType, hipblasOutType; + if (inType == at::kHalf) { + hipblasInType = HIP_R_16F; + } else if (inType == at::kBFloat16) { + hipblasInType = HIP_R_16BF; + } else if (inType == at::kFloat) { + hipblasInType = HIP_R_32F; + } else if (inType == at::kFloat8_e4m3fnuz) { + hipblasInType = HIP_R_8F_E4M3_FNUZ; + } else { + assert(false && "Wrong datatype!"); + } + if (outType == at::kHalf) { + hipblasOutType = HIP_R_16F; + } else if (outType == at::kBFloat16) { + hipblasOutType = HIP_R_16BF; + } else if (outType == at::kFloat) { + hipblasOutType = HIP_R_32F; + } else if (outType == at::kFloat8_e4m3fnuz) { + hipblasOutType = HIP_R_8F_E4M3_FNUZ; } else { assert(false && "Wrong datatype!"); } @@ -561,7 +623,8 @@ std::vector HipbFindAllSolIdxBlas( ptrB, mat2_ld, &zero, ptrC, result_ld, - hipblasType, + hipblasInType, + hipblasOutType, current_stream); } @@ -605,6 +668,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("hipb_create_extension", &hipb_create_extension, "create_extension"); m.def("hipb_destroy_extension", &hipb_destroy_extension, "destroy_extension"); - m.def("hipb_mm", &HipbSolIdxBlas, "mm"); - m.def("hipb_findallsols", &HipbFindAllSolIdxBlas, "hipblas_find_all_sols"); -} + m.def("hipb_mm", &HipbSolIdxBlas, "mm", + py::arg("mat1"), py::arg("mat2"), + py::arg("solution_index"), + py::arg("outType")=at::nullopt, + py::arg("scale1")=at::nullopt, py::arg("scale2")=at::nullopt, py::arg("scaleOut")=at::nullopt); + m.def("hipb_findallsols", &HipbFindAllSolIdxBlas, "hipblas_find_all_sols", + py::arg("mat1"), py::arg("mat2"), + py::arg("outType")=at::nullopt); +} \ No newline at end of file diff --git a/gradlib/gradlib/fp8_gemm_tunner.py b/gradlib/gradlib/fp8_gemm_tunner.py new file mode 100644 index 0000000000000..71d117ff03336 --- /dev/null +++ b/gradlib/gradlib/fp8_gemm_tunner.py @@ -0,0 +1,222 @@ +import torch +import os +import argparse +import hipbsolidxgemm +import numpy as np +import torch.nn.functional as F +import pandas as pd +import json +import random +from pathlib import Path + +hipbsolidxgemm.hipb_create_extension() + +rtol = 1e-5 +atol = 1 + +class Fp8Gemm: + def __init__(self, m, n, k, indtype, outdtype): + self.m = m + self.k = k + self.n = n + self.indtype = indtype + self.outdtype = outdtype + self.nb = 37 + self.inp = torch.randn((self.n, self.k), device='cuda').to(self.indtype) + self.weights = torch.randn((self.m, self.k), device='cuda').to(self.indtype) + #weights2 is used in measurement/warm iters to ensure HBM fetch for weight tensors + self.weights2 = torch.randn((self.nb, self.m, self.k), device='cuda').to(self.indtype) + self.blob = torch.ones(128*1024*1024, dtype=torch.float32, device='cuda') + self.topn = 20 #number of top solutions from each source + self.hipb_sols = [] + self.rtol = 1e-5 + self.atol = 1 + self.start = torch.cuda.Event(enable_timing=True) + self.end = torch.cuda.Event(enable_timing=True) + + + def find_hipblas_sols(self): + sols = hipbsolidxgemm.hipb_findallsols(self.inp, self.weights.t(), self.outdtype) + print('M N K', self.m, self.n, self.k,'>>> Total hipb solutions', len(sols), flush=True) + #print(sols) + self.hipb_sols = sols + + + def check_gemm_ref(self, libtype, solidx): + ref = F.linear(self.inp.to(torch.float32),self.weights.to(torch.float32)).to(self.outdtype) + c = 
hipbsolidxgemm.hipb_mm(self.inp,self.weights.t(),solidx,self.outdtype) + if torch.allclose(c, ref, atol=self.atol, rtol=self.rtol): + #print('>>>',libtype,'Solidx',solidx,'passed reference test') + return True + else: + print('>>>','Solidx',solidx,'FAILED reference test', flush=True) + print(ref, flush=True) + print(c, flush=True) + return False + + def hipb_time_sol(self, solidx, cold_iters=2, warm_iters=10): + #print('>>>hipbtime',solidx) + for i in range(cold_iters): + c = hipbsolidxgemm.hipb_mm(self.inp,self.weights.t(),solidx,self.outdtype) + self.start.record() + for i in range(warm_iters): + c = hipbsolidxgemm.hipb_mm(self.inp,self.weights2 [random.randint(0,self.nb-1)].t(),solidx,self.outdtype) + self.end.record() + torch.cuda.synchronize() + gtime = self.start.elapsed_time(self.end)/warm_iters + #print('>>> Solidx GTime',solidx,gtime,'ms') + return gtime + + def hipb_time_all_sols(self, fast_mode=0, top_sols=0): + coldi=20; warmi=20 + if fast_mode: coldi=2; warmi=2 + solutions = self.hipb_sols + if top_sols: solutions = self.hipb_top_sols + gtimes = {} + for solidx in solutions: + gtimes[solidx] = self.hipb_time_sol(solidx, cold_iters=coldi, warm_iters=warmi) + self.hipb_gtimedf = pd.DataFrame.from_dict(gtimes,orient='index',columns=['gtimems']).sort_values(by='gtimems') + self.hipb_gtimedf.to_csv('/tmp/hipb_gtimedf.csv') + print('>>> HipBlasLt top solutions, Fast Mode',fast_mode) + print(self.hipb_gtimedf.head(self.topn)) + + def warmup(self, warmi=500): + for i in range(warmi): + self.blob = self.blob + 0.00001 + + def functional_check_topn_fastest(self): + hipb_topn = [] + for solidx in self.hipb_gtimedf.index[:self.topn]: + if self.check_gemm_ref(libtype='hipblaslt',solidx=solidx): + hipb_topn.append(solidx) + self.hipb_top_sols = hipb_topn + + def find_fastest_solution(self): + self.find_hipblas_sols() + self.warmup() + self.hipb_time_all_sols(fast_mode=1) + self.functional_check_topn_fastest() + self.warmup() + self.hipb_time_all_sols(fast_mode=0,top_sols=1) + if len(self.hipb_gtimedf)>0: + best_hipb_time = self.hipb_gtimedf.gtimems.iloc[0] + self.best_solidx = self.hipb_gtimedf.index[0] + self.best_soltime = best_hipb_time + else: + print('>>> No hipblas solutions found!',flush=True) + self.best_solidx = 0 + self.best_soltime = 0 + print('>>> Fastest Solution is',self.best_solidx,self.best_soltime,flush=True) + + +class Fp8GemmTuner: + def __init__(self, indtype, outdtype, tuned_file=None): + self.gemm_problems = pd.DataFrame(columns=['M','N','K']) + self.indtype = indtype + self.outdtype = outdtype + self.tuned_file = tuned_file + if Path(tuned_file).is_file(): + self.gdf = pd.read_csv(tuned_file) + else: + self.gdf = None + + def add_gemm(self, m, n, k): + if (self.gdf is None or (self.gdf[(self.gdf['M'] == m) & (self.gdf['N'] == n) & (self.gdf['K'] == k)].empty)): + entry = {'M':[m], 'N':[n], 'K':[k]} + df = pd.DataFrame(entry) + self.gemm_problems = pd.concat([self.gemm_problems, df],ignore_index=True) + else: + print(f">>>Info: Found Duplicate shape(M:{m}, N:{n}, K:{k}), skipping") + + def find_best_sols(self): + df = self.gemm_problems + soldf = pd.DataFrame() + for i in range(len(df)): + ds = df.iloc[i] + gemmobj = Fp8Gemm(ds['M'], ds['N'], ds['K'], + indtype=self.indtype, + outdtype=self.outdtype) + gemmobj.find_fastest_solution() + soldf.loc[i,'solidx'] = gemmobj.best_solidx + soldf.loc[i,'soltimems'] = gemmobj.best_soltime + soldf['indtype'] = self.indtype + soldf['outdtype'] = self.outdtype + finaldf = pd.concat([self.gemm_problems, soldf],axis=1) + finaldf = 
pd.concat([finaldf, self.gdf]) + finaldf.to_csv(self.tuned_file, index=False) + print(finaldf) + + +def generate_mk_sets(model_dir, tp=1): + f = open(f'{model_dir}/config.json') + data = json.load(f) + hidden_size = data['hidden_size'] + intermediate_size = data['intermediate_size'] + total_num_heads = data['num_attention_heads'] + total_num_kv_heads = data['num_key_value_heads'] + head_dim = hidden_size // total_num_heads + return [((total_num_heads + (2*total_num_kv_heads)) * head_dim // tp, hidden_size), + (hidden_size, hidden_size // tp), + (intermediate_size *2 // tp, hidden_size), + (hidden_size, intermediate_size // tp)], hidden_size + +def get_dtype(dtype_str): + dtype = torch.float8_e4m3fnuz + if dtype_str == 'f32': + dtype = torch.float32 + elif dtype_str == 'bf16': + dtype = torch.bfloat16 + elif dtype_str == 'f16': + dtype = torch.float16 + elif dtype_str == 'f8': + dtype = torch.float8_e4m3fnuz + else: + print('>>> Warning! Invalid dtype', dtype_str, 'using default dtype f8') + return dtype + + +def list_of_ints(arg): + return list(map(int, arg.split(','))) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--model_dir", type=str, default=os.getenv('GTUNE_MODEL', ""), help="Enter the location of your model directory") + parser.add_argument("--tuned_file", type=str, default=os.getenv('GTUNE_TUNED', "tuned.csv"), help="output file for tuned gemm solutions") + parser.add_argument("--input_file", type=str, default=os.getenv('GTUNE_INPUT', None), help="list of gemms to tune for, mutually exclusive with model_dir") + parser.add_argument("--tp", type=int, default=os.getenv('GTUNE_TP', 1), help="Tensor parallelism to be used.") + parser.add_argument("--indtype", type=str, default='f8', help="dtype f32 f16 bf16 fp8") + parser.add_argument("--outdtype", type=str, default='f16', help="dtype f32 f16 bf16 fp8") + parser.add_argument("--batch_size", type=int, default=os.getenv('GTUNE_BATCH_SIZE', 1), help="Batch size to tune for") + parser.add_argument("--nsets", type=list_of_ints, default=[1, 512, 1024, 2048, 3072, 4096, 8192, 16384], help="N sizes to tune for: 1,128,2048") + args = parser.parse_args() + + indtype = get_dtype(args.indtype) + outdtype = get_dtype(args.outdtype) + + gtuner = Fp8GemmTuner(indtype, outdtype, args.tuned_file) + nsets = [i * args.batch_size for i in args.nsets] + if args.input_file: + print(f">>> Loading {args.input_file}") + if not Path(args.input_file).is_file(): + print(f">>> ERROR: {args.input_file} does not exist. Exiting") + exit(1) + shapes = pd.read_csv(args.input_file) + for i in range(len(shapes)): + ds = shapes.iloc[i] + gtuner.add_gemm(ds['M'],ds['N'],ds['K']) + else: + if not args.model_dir: + print(">>> Warning! NO MODEL SPECIFIED. 
Tuning for LL2 13B TP1") + #LL2 13B sizes + mksets = [(15360, 5120), (5120, 5120), (27648, 5120), (5120, 13824)] + gtuner.add_gemm(m=32000, n=1, k=5120) # logits gemm + else: + mksets, hidden_size = generate_mk_sets(args.model_dir, args.tp) + gtuner.add_gemm(m=32000//args.tp, n=1 * args.batch_size, k=hidden_size) #TODO: Handle cases where vocab_size is not divisible by tp + + for n in sorted(nsets): + for m, k in mksets: + gtuner.add_gemm(m, n, k) + + gtuner.find_best_sols() \ No newline at end of file diff --git a/vllm/model_executor/layers/quantization/fp8_rocm.py b/vllm/model_executor/layers/quantization/fp8_rocm.py index ddccc5825c8a4..ddb83a6ed452e 100644 --- a/vllm/model_executor/layers/quantization/fp8_rocm.py +++ b/vllm/model_executor/layers/quantization/fp8_rocm.py @@ -30,10 +30,10 @@ def __init__(self) -> None: #print(f"Integral Cross factor = {self.factor}") if gemm_type == "fp8_8": self.gemm_method = Fp8RocmLinearMethod.apply_fp8_8 - tuned_filename = "/projects/tuned_fp8_8.csv" + tuned_filename = "/tmp/tuned_fp8_8.csv" elif gemm_type == "fp8_16": self.gemm_method = Fp8RocmLinearMethod.apply_fp8_16 - tuned_filename = "/projects/tuned_fp8_16.csv" + tuned_filename = "/tmp/tuned_fp8_16.csv" else: raise ValueError(f"Unknown fp8 gemm type: {gemm_type}") try: @@ -50,7 +50,7 @@ def __init__(self) -> None: m = shape["M"] n = shape["N"] k = shape["K"] - algo = shape["algo"] + algo = shape["solidx"] self._tuned[(m, n, k)] = algo @classmethod @@ -224,7 +224,7 @@ def apply_fp8_16( if os.getenv("TUNE_FP8") == "1": try: - df = pd.read_csv("/projects/fp8_shapes.csv") + df = pd.read_csv("/tmp/fp8_shapes.csv") except (IOError, pd.errors.EmptyDataError, pd.errors.ParserError): df = pd.DataFrame(columns=["M", "N", "K"]) @@ -234,7 +234,7 @@ def apply_fp8_16( "N": [n], "K": [k] })]).drop_duplicates() - df.to_csv("/projects/fp8_shapes.csv", index=False) + df.to_csv("/tmp/fp8_shapes.csv", index=False) algo = 0 res = vllm_ops.fp8_gemm_16(x8, weight.t(), asf, wsf, int(algo)) return res @@ -271,7 +271,7 @@ def apply_fp8_8( "N": [n], "K": [k] })]).drop_duplicates() - df.to_csv("/projects/fp8_shapes.csv", index=False) + df.to_csv("/tmp/fp8_shapes.csv", index=False) algo = 0 res = vllm_ops.fp8_gemm(x8, weight.t(), asf, wsf, osf, int(algo))
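For completeness, here is a hedged, standalone sketch (not part of this patch) of how a solution index recorded in `/tmp/tuned_fp8_16.csv` by `fp8_gemm_tunner.py` can be replayed through the extended `hipb_mm` binding; the GEMM shape and the per-tensor scales below are placeholders:

```python
# Hypothetical check of one tuned fp8 GEMM; shape and scale values are placeholders.
import pandas as pd
import torch
import hipbsolidxgemm

hipbsolidxgemm.hipb_create_extension()

tuned = pd.read_csv("/tmp/tuned_fp8_16.csv")   # written by fp8_gemm_tunner.py
m, n, k = 5120, 1, 5120                        # pick one of the tuned shapes
row = tuned[(tuned["M"] == m) & (tuned["N"] == n) & (tuned["K"] == k)].iloc[0]
solidx = int(row["solidx"])                    # same column fp8_rocm.py now reads

inp = torch.randn((n, k), device="cuda").to(torch.float8_e4m3fnuz)
weight = torch.randn((m, k), device="cuda").to(torch.float8_e4m3fnuz)
scale_in = torch.tensor(1.0, device="cuda")    # placeholder per-tensor scaling factors
scale_w = torch.tensor(1.0, device="cuda")

# hipb_mm(mat1, mat2, solution_index, outType, scale1, scale2, scaleOut) per the new binding
out = hipbsolidxgemm.hipb_mm(inp, weight.t(), solidx, torch.float16, scale_in, scale_w)
print(out.shape)                               # expected: (n, m)
```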