diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py
index 63080eaf2f11c..399348300554a 100755
--- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py
+++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py
@@ -8,6 +8,10 @@
 import triton
 import triton.language as tl
 from tqdm import tqdm
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+from tuning_utils import (get_full_tuning_space, prune_configs, union_of_list_of_dicts)
 
 import vllm._moe_C as moe_kernels
 from vllm._C import ops
@@ -15,171 +19,26 @@
                                                    invoke_fused_moe_kernel,
                                                    moe_align_block_size)
 
+os.environ["HIP_FORCE_DEV_KERNARG"] = "1"
+os.environ["DEBUG_CLR_GRAPH_PACKET_CAPTURE"] = "1"
+
 
 def main(args):
-    os.environ["HIP_VISIBLE_DEVICES"] = args.GPUID
-    os.environ["HIP_FORCE_DEV_KERNARG"] = "1"
-    os.environ["DEBUG_CLR_GRAPH_PACKET_CAPTURE"] = "1"
-    os.environ["OPTIMIZE_EPILOGUE"] = "1"
-
-    for bs in [
-            1,
-            2,
-            4,
-            8,
-            16,
-            24,
-            32,
-            48,
-            64,
-            96,
-            128,
-            256,
-            512,
-            1024,
-            1536,
-            2048,
-            3072,
-            4096,
-    ]:
-        run_grid(bs, model=args.model, TP=args.TP)
-
-
-## Utilize method from rocm/Triton tuning script
-def get_full_tuning_space():
-    configs = []
-
-    block_mn_range = [16, 32, 64, 128, 256]
-    block_k_range = [16, 32, 64, 128, 256]
-    # split_k_range = [1] #, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24]
-    num_warps_range = [1, 2, 4, 8]
-    group_m_range = [1, 4, 8, 16, 32]
-    # For now we see better perf with num_stages=0 for all gemm configs we care
-    # But keep this explicit so that we do not forget we may need to set it to
-    # other values in the future
-    num_stage_range = [0]
-    waves_per_eu_range = [0]
-    matrix_instr_nonkdim_range = [16, 32]
-    kpack_range = [1, 2]
-
-    for block_m in block_mn_range:
-        for block_n in block_mn_range:
-            for block_k in block_k_range:
-                for num_warps in num_warps_range:
-                    for group_m in group_m_range:
-                        # for split_k in split_k_range:
-                        for num_stages in num_stage_range:
-                            for waves_per_eu in waves_per_eu_range:
-                                for (matrix_instr_nonkdim
-                                     ) in matrix_instr_nonkdim_range:
-                                    for kpack in kpack_range:
-                                        configs.append({
-                                            "BLOCK_SIZE_M": block_m,
-                                            "BLOCK_SIZE_N": block_n,
-                                            "BLOCK_SIZE_K": block_k,
-                                            "GROUP_SIZE_M": group_m,
-                                            "num_warps": num_warps,
-                                            "num_stages": num_stages,
-                                            "waves_per_eu": waves_per_eu,
-                                            "matrix_instr_nonkdim":
-                                            matrix_instr_nonkdim,
-                                            "kpack": kpack,
-                                        })
-
-    return configs
-
-
-## Utilize method from rocm/Triton tuning script
-def prune_configs(M, N, K, configs):
-    pruned_configs = []
-    elemBytes_a = 2  # [DV Note] Hard-coded for float16 (2 bytes)
-    elemBytes_b = 2  # [DV Note] Hard-coded for float16 (2 bytes)
-
-    mfma = 16 if M < 32 or N < 32 else 32
-
-    # TODO (zhanglx): figure out the boundary between large and small gemms
-    large_gemm = False
-    if M >= 2048 and N >= 2048:
-        large_gemm = True
-
-    for config in configs:
-        BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
-        BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
-        BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
-        num_warps = config.get("num_warps")
-        matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
-        # kpack = config.get("kpack")
-        if matrix_instr_nonkdim > mfma:
-            continue
-        if mfma == 4 and BLOCK_SIZE_K < 64:
-            continue
-        # some layouts could not work properly in case
-        # number elements per thread is less 1
-        if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
-            continue
-        SPLIT_K = 1  # config.get("SPLIT_K")
-        GROUP_M = config.get("GROUP_SIZE_M")
-        if (matrix_instr_nonkdim > BLOCK_SIZE_M
-                or matrix_instr_nonkdim > BLOCK_SIZE_N):
-            continue
-        if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
-            continue
-        if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
-            continue
-        # Skip BLOCK_SIZE that is too large compare to M/N
-        # unless BLOCK_SIZE is already small enough
-        if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
-            continue
-        if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
-            continue
-        # skip large split_k when not necessary
-        if SPLIT_K != 1 and not need_split_k(M, N, K):
-            continue
-        # skip split_k that leads to EVEN_K = false
-        leap = SPLIT_K * BLOCK_SIZE_K
-        modv = K % leap
-        if modv != 0:
-            continue
-        # skip large GROUP_M
-        if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
-            continue
-        # out of shared memory resource
-        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
-        LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
-               BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
-        if LDS > 65536:
-            continue
-        # Skip small block sizes and num_warps for large gemm
-        # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
-        if large_gemm:
-            if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
-                continue
-            if BLOCK_SIZE_K < 64:
-                continue
-            if num_warps < 4:
-                continue
-
-        pruned_configs.append(config)
-
-    return pruned_configs
-
-
-def union_of_list_of_dicts(l1, l2):
-    result = []
-    temp_list = l1.copy()
-    temp_list.extend(l2)
-    for myDict in temp_list:
-        if myDict not in result:
-            result.append(myDict)
-
-    return result
-
-
-def need_split_k(SIZE_M, SIZE_N, SIZE_K):
-    return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
-
-
-def run_grid(bs, model, TP):
+    world_size = args.numGPU
+    mp.spawn(wrapper, args=(args,), nprocs=world_size, join=False)
+
+
+def wrapper(rank, args):
+    dist.init_process_group("nccl", world_size=args.numGPU, rank=rank)
+    device_id = rank
+
+    batches = [
+        1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048,
+        3072, 4096
+    ]
+    for i in range(device_id, len(batches), args.numGPU):
+        tune_batch(batches[i], model=args.model, TP=args.modelTP)
+
+
+def tune_batch(bs, model, TP):
+    device_id = torch.distributed.get_rank()
+
     if model == '8x7B':
         d_model = 4096
         model_intermediate_size = 14336
@@ -194,9 +53,6 @@ def run_grid(bs, model, TP):
     tp_size = TP
     num_calls = 100
 
-    num_warmup_trials = 1
-    num_trials = 1
-
     full_configs = get_full_tuning_space()
     M1 = bs * 2
     N1 = model_intermediate_size * 2 // tp_size
@@ -209,16 +65,17 @@
     prune_configs_2 = prune_configs(M2, N2, K2, full_configs)
     configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2)
 
-    print(f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | \
-        {len(prune_configs_2)=} | {len(configs)=}")
 
     best_config = None
     best_time_us = 1e20
 
-    for config in tqdm(configs):
-        # warmup
-        try:
-            for _ in range(num_warmup_trials):
+    progress_bar = tqdm(total=len(configs),
+                        desc=f"bs={bs:4d} device={device_id}",
+                        position=device_id)
+
+    with torch.cuda.device(device_id):
+        for config in configs:
+            progress_bar.update(1)
+            # warmup
+            try:
                 run_timing(
                     num_calls=num_calls,
                     bs=bs,
@@ -229,11 +86,10 @@
                     model_intermediate_size=model_intermediate_size,
                     config=config,
                 )
-        except triton.runtime.autotuner.OutOfResources:
-            continue
+            except triton.runtime.autotuner.OutOfResources:
+                continue
 
-        # benchmark
-        for _ in range(num_trials):
+            # benchmark
             kernel_dur_ms = run_timing(
                 num_calls=num_calls,
                 bs=bs,
@@ -246,20 +102,11 @@
             )
             kernel_dur_us = 1000 * kernel_dur_ms
-            # model_dur_ms = kernel_dur_ms * num_layers
 
             if kernel_dur_us < best_time_us:
                 best_config = config
                 best_time_us = kernel_dur_us
 
-            # print(f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
-            #       f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
-            #       f'{d_model=} {model_intermediate_size=} {num_layers=}')
-
-    # print("best_time_us", best_time_us)
-    # print("best_config", best_config)
-
-    # holds Dict[str, Dict[str, int]]
     filename = get_config_file_name(num_total_experts,
                                     model_intermediate_size // tp_size,
                                     dtype=None)
@@ -286,10 +133,13 @@ def run_timing(
 ) -> float:
     shard_intermediate_size = model_intermediate_size // tp_size
 
+    device_ = "cuda"
+    dtype_ = torch.float16
+
     hidden_states = torch.rand(
         (bs, d_model),
-        device="cuda",
-        dtype=torch.float16,
+        device=device_,
+        dtype=dtype_,
     )
 
     w1 = torch.rand(
@@ -306,7 +156,6 @@
 
     gating_output = F.softmax(
         torch.rand(
-            # (num_calls, bs, num_total_experts), # THIS
             (bs, num_total_experts),
             device=hidden_states.device,
             dtype=torch.float32,
@@ -345,7 +194,7 @@
         topk_weights,
         topk_ids,
         token_expert_indicies,
-        gating_output.float(),  # TODO(woosuk): Optimize this.
+        gating_output.float(),
     )
     del token_expert_indicies  # Not used. Will be used in the future.
 
@@ -376,7 +225,7 @@
     end_event = torch.cuda.Event(enable_timing=True)
 
     start_event.record()
-    for i in range(num_calls):
+    for _ in range(num_calls):
         invoke_fused_moe_kernel(
             hidden_states,
             w1,
@@ -427,28 +276,28 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         prog="benchmark_mixtral_moe_rocm",
-        description="Tune the fused_moe kernel for mixtral.")
+        description="Distributed tuning script for the fused_moe kernel.")
+    parser.add_argument('--model',
+                        type=str,
+                        choices=['8x7B', '8x22B'],
+                        help='The Mixtral model to benchmark')
     parser.add_argument(
-        "--TP",
+        "--modelTP",
         type=int,
         choices=[8, 4, 2, 1],
-        help="Specify the TP value that the actual model will run on",
+        help="Specify the TP value that the model will actually run on",
        required=True,
     )
     parser.add_argument(
-        "--GPUID",
-        type=str,
-        help="This script uses single GPU. Specify the GPU to use for tuning",
-        default="0",
+        "--numGPU",
+        type=int,
+        choices=[8, 4, 2, 1],
+        help="Total number of GPUs to use for tuning",
+        required=True,
    )
-    parser.add_argument('--model',
-                        type=str,
-                        choices=['8x7B', '8x22B'],
-                        help='The Mixtral model to benchmark')
-
     args = parser.parse_args()
     print(f"Running tuning for {args.model} model")
-    print(f"TP is set to: {args.TP}")
-    print(f"GPU-ID being used for tuning: {args.GPUID}")
+    print(f"Model TP is set to: {args.modelTP}")
+    print(f"GPUs being used for tuning: {args.numGPU}")
     sys.exit(main(args))
diff --git a/benchmarks/kernels/tuning_utils.py b/benchmarks/kernels/tuning_utils.py
new file mode 100644
index 0000000000000..120e9de870867
--- /dev/null
+++ b/benchmarks/kernels/tuning_utils.py
@@ -0,0 +1,134 @@
+
+
+## Utilize method from rocm/Triton tuning script
+def get_full_tuning_space():
+    configs = []
+
+    block_mn_range = [16, 32, 64, 128, 256]
+    block_k_range = [16, 32, 64, 128, 256]
+    # split_k_range = [1] #, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24]
+    num_warps_range = [1, 2, 4, 8]
+    group_m_range = [1, 4, 8, 16, 32]
+    # For now we see better perf with num_stages=0 for all gemm configs we care
+    # But keep this explicit so that we do not forget we may need to set it to
+    # other values in the future
+    num_stage_range = [0]
+    waves_per_eu_range = [0]
+    matrix_instr_nonkdim_range = [16, 32]
+    kpack_range = [1, 2]
+
+    for block_m in block_mn_range:
+        for block_n in block_mn_range:
+            for block_k in block_k_range:
+                for num_warps in num_warps_range:
+                    for group_m in group_m_range:
+                        # for split_k in split_k_range:
+                        for num_stages in num_stage_range:
+                            for waves_per_eu in waves_per_eu_range:
+                                for (matrix_instr_nonkdim
+                                     ) in matrix_instr_nonkdim_range:
+                                    for kpack in kpack_range:
+                                        configs.append({
+                                            "BLOCK_SIZE_M": block_m,
+                                            "BLOCK_SIZE_N": block_n,
+                                            "BLOCK_SIZE_K": block_k,
+                                            "GROUP_SIZE_M": group_m,
+                                            "num_warps": num_warps,
+                                            "num_stages": num_stages,
+                                            "waves_per_eu": waves_per_eu,
+                                            "matrix_instr_nonkdim":
+                                            matrix_instr_nonkdim,
+                                            "kpack": kpack,
+                                        })
+
+    return configs
+
+
+## Utilize method from rocm/Triton tuning script
+def prune_configs(M, N, K, configs):
+    pruned_configs = []
+    elemBytes_a = 2  # [DV Note] Hard-coded for float16 (2 bytes)
+    elemBytes_b = 2  # [DV Note] Hard-coded for float16 (2 bytes)
+
+    mfma = 16 if M < 32 or N < 32 else 32
+
+    # TODO (zhanglx): figure out the boundary between large and small gemms
+    large_gemm = False
+    if M >= 2048 and N >= 2048:
+        large_gemm = True
+
+    for config in configs:
+        BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
+        BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
+        BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
+        num_warps = config.get("num_warps")
+        matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
+        # kpack = config.get("kpack")
+        if matrix_instr_nonkdim > mfma:
+            continue
+        if mfma == 4 and BLOCK_SIZE_K < 64:
+            continue
+        # some layouts could not work properly in case
+        # number elements per thread is less 1
+        if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
+            continue
+        SPLIT_K = 1  # config.get("SPLIT_K")
+        GROUP_M = config.get("GROUP_SIZE_M")
+        if (matrix_instr_nonkdim > BLOCK_SIZE_M
+                or matrix_instr_nonkdim > BLOCK_SIZE_N):
+            continue
+        if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
+            continue
+        if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
+            continue
+        # Skip BLOCK_SIZE that is too large compare to M/N
+        # unless BLOCK_SIZE is already small enough
+        if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
+            continue
+        if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
+            continue
+        # skip large split_k when not necessary
+        if SPLIT_K != 1 and not need_split_k(M, N, K):
+            continue
+        # skip split_k that leads to EVEN_K = false
+        leap = SPLIT_K * BLOCK_SIZE_K
+        modv = K % leap
+        if modv != 0:
+            continue
+        # skip large GROUP_M
+        if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
+            continue
+        # out of shared memory resource
+        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
+        LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
+               BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
+        if LDS > 65536:
+            continue
+        # Skip small block sizes and num_warps for large gemm
+        # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
+        if large_gemm:
+            if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
+                continue
+            if BLOCK_SIZE_K < 64:
+                continue
+            if num_warps < 4:
+                continue
+
+        pruned_configs.append(config)
+
+    return pruned_configs
+
+
+def union_of_list_of_dicts(l1, l2):
+    result = []
+    temp_list = l1.copy()
+    temp_list.extend(l2)
+    for myDict in temp_list:
+        if myDict not in result:
+            result.append(myDict)
+
+    return result
+
+
+def need_split_k(SIZE_M, SIZE_N, SIZE_K):
+    return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
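Usage sketch, for reference (illustrative values, assuming the script is launched from benchmarks/kernels/ so that the new tuning_utils module is importable):

    python benchmark_mixtral_moe_rocm.py --model 8x7B --modelTP 8 --numGPU 8

Here --modelTP is the tensor-parallel degree the model is expected to run with, so it only affects the tuned problem shapes and the generated config filename, while --numGPU controls how many local GPUs the batch sizes are round-robined across by wrapper().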