diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py
deleted file mode 100755
index 8fab21c5fd8f5..0000000000000
--- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py
+++ /dev/null
@@ -1,454 +0,0 @@
-import argparse
-import json
-import os
-import sys
-
-import torch
-import torch.nn.functional as F
-import triton
-import triton.language as tl
-from tqdm import tqdm
-
-import vllm._moe_C as moe_kernels
-from vllm._C import ops
-from vllm.model_executor.layers.fused_moe import (get_config_file_name,
-                                                  invoke_fused_moe_kernel,
-                                                  moe_align_block_size)
-
-
-def main(args):
-    os.environ["HIP_VISIBLE_DEVICES"] = args.GPUID
-    os.environ["HIP_FORCE_DEV_KERNARG"] = "1"
-    os.environ["DEBUG_CLR_GRAPH_PACKET_CAPTURE"] = "1"
-    os.environ["OPTIMIZE_EPILOGUE"] = "1"
-
-    for bs in [
-            1,
-            2,
-            4,
-            8,
-            16,
-            24,
-            32,
-            48,
-            64,
-            96,
-            128,
-            256,
-            512,
-            1024,
-            1536,
-            2048,
-            3072,
-            4096,
-    ]:
-        run_grid(bs, model=args.model, TP=args.TP)
-
-
-## Utilize method from rocm/Triton tuning script
-def get_full_tuning_space():
-    configs = []
-
-    block_mn_range = [16, 32, 64, 128, 256]
-    block_k_range = [16, 32, 64, 128, 256]
-    # split_k_range = [1]  #, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24]
-    num_warps_range = [1, 2, 4, 8]
-    group_m_range = [1, 4, 8, 16, 32]
-    # For now we see better perf with num_stages=0 for all gemm configs we care
-    # But keep this explicit so that we do not forget we may need to set it to
-    # other values in the future
-    num_stage_range = [0]
-    waves_per_eu_range = [0]
-    matrix_instr_nonkdim_range = [16, 32]
-    kpack_range = [1, 2]
-
-    for block_m in block_mn_range:
-        for block_n in block_mn_range:
-            for block_k in block_k_range:
-                for num_warps in num_warps_range:
-                    for group_m in group_m_range:
-                        # for split_k in split_k_range:
-                        for num_stages in num_stage_range:
-                            for waves_per_eu in waves_per_eu_range:
-                                for (matrix_instr_nonkdim
-                                     ) in matrix_instr_nonkdim_range:
-                                    for kpack in kpack_range:
-                                        configs.append({
-                                            "BLOCK_SIZE_M": block_m,
-                                            "BLOCK_SIZE_N": block_n,
-                                            "BLOCK_SIZE_K": block_k,
-                                            "GROUP_SIZE_M": group_m,
-                                            "num_warps": num_warps,
-                                            "num_stages": num_stages,
-                                            "waves_per_eu": waves_per_eu,
-                                            "matrix_instr_nonkdim":
-                                            matrix_instr_nonkdim,
-                                            "kpack": kpack,
-                                        })
-
-    return configs
-
-
-## Utilize method from rocm/Triton tuning script
-def prune_configs(M, N, K, configs):
-    pruned_configs = []
-    elemBytes_a = 2  # [DV Note] Hard-coded for float16 (2 bytes)
-    elemBytes_b = 2  # [DV Note] Hard-coded for float16 (2 bytes)
-
-    mfma = 16 if M < 32 or N < 32 else 32
-
-    # TODO (zhanglx): figure out the boundary between large and small gemms
-    large_gemm = False
-    if M >= 2048 and N >= 2048:
-        large_gemm = True
-
-    for config in configs:
-        BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
-        BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
-        BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
-        num_warps = config.get("num_warps")
-        matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
-        # kpack = config.get("kpack")
-        if matrix_instr_nonkdim > mfma:
-            continue
-        if mfma == 4 and BLOCK_SIZE_K < 64:
-            continue
-        # some layouts could not work properly in case
-        # number elements per thread is less 1
-        if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
-            continue
-        SPLIT_K = 1  # config.get("SPLIT_K")
-        GROUP_M = config.get("GROUP_SIZE_M")
-        if (matrix_instr_nonkdim > BLOCK_SIZE_M
-                or matrix_instr_nonkdim > BLOCK_SIZE_N):
-            continue
-        if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
-            continue
-        if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
-            continue
-        # Skip BLOCK_SIZE that is too large compare to M/N
-        # unless BLOCK_SIZE is already small enough
-        if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
-            continue
-        if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
-            continue
-        # skip large split_k when not necessary
-        if SPLIT_K != 1 and not need_split_k(M, N, K):
-            continue
-        # skip split_k that leads to EVEN_K = false
-        leap = SPLIT_K * BLOCK_SIZE_K
-        modv = K % leap
-        if modv != 0:
-            continue
-        # skip large GROUP_M
-        if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
-            continue
-        # out of shared memory resource
-        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
-        LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
-               BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
-        if LDS > 65536:
-            continue
-        # Skip small block sizes and num_warps for large gemm
-        # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
-        if large_gemm:
-            if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
-                continue
-            if BLOCK_SIZE_K < 64:
-                continue
-            if num_warps < 4:
-                continue
-
-        pruned_configs.append(config)
-
-    return pruned_configs
-
-
-def union_of_list_of_dicts(l1, l2):
-    result = []
-    temp_list = l1.copy()
-    temp_list.extend(l2)
-    for myDict in temp_list:
-        if myDict not in result:
-            result.append(myDict)
-
-    return result
-
-
-def need_split_k(SIZE_M, SIZE_N, SIZE_K):
-    return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
-
-
-def run_grid(bs, model, TP):
-    if model == '8x7B':
-        d_model = 4096
-        model_intermediate_size = 14336
-    elif model == '8x22B':
-        d_model = 6144
-        model_intermediate_size = 16384
-    else:
-        raise ValueError(f'Unsupported Mixtral model {model}')
-
-    num_total_experts = 8
-    top_k = 2
-    tp_size = TP
-    num_calls = 100
-
-    num_warmup_trials = 1
-    num_trials = 1
-
-    full_configs = get_full_tuning_space()
-    M1 = bs * 2
-    N1 = model_intermediate_size * 2 // tp_size
-    K1 = d_model
-    prune_configs_1 = prune_configs(M1, N1, K1, full_configs)
-
-    M2 = bs * 2
-    N2 = d_model
-    K2 = model_intermediate_size // tp_size
-    prune_configs_2 = prune_configs(M2, N2, K2, full_configs)
-
-    configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2)
-    print(f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | \
-        {len(prune_configs_2)=} | {len(configs)=}")
-
-    best_config = None
-    best_time_us = 1e20
-
-    for config in tqdm(configs):
-        # warmup
-        try:
-            for _ in range(num_warmup_trials):
-                run_timing(
-                    num_calls=num_calls,
-                    bs=bs,
-                    d_model=d_model,
-                    num_total_experts=num_total_experts,
-                    top_k=top_k,
-                    tp_size=tp_size,
-                    model_intermediate_size=model_intermediate_size,
-                    config=config,
-                )
-        except triton.runtime.autotuner.OutOfResources:
-            continue
-
-        # benchmark
-        for _ in range(num_trials):
-            kernel_dur_ms = run_timing(
-                num_calls=num_calls,
-                bs=bs,
-                d_model=d_model,
-                num_total_experts=num_total_experts,
-                top_k=top_k,
-                tp_size=tp_size,
-                model_intermediate_size=model_intermediate_size,
-                config=config,
-            )
-
-            kernel_dur_us = 1000 * kernel_dur_ms
-            # model_dur_ms = kernel_dur_ms * num_layers
-
-            if kernel_dur_us < best_time_us:
-                best_config = config
-                best_time_us = kernel_dur_us
-
-            # print(f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
-            #       f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
-            #       f'{d_model=} {model_intermediate_size=} {num_layers=}')

-    # print("best_time_us", best_time_us)
-    # print("best_config", best_config)
-
-    # holds Dict[str, Dict[str, int]]
-    filename = get_config_file_name(num_total_experts,
-                                    model_intermediate_size // tp_size,
-                                    dtype=None)
-    print(f"writing config to file {filename}")
-    existing_content = {}
-    if os.path.exists(filename):
-        with open(filename) as f:
-            existing_content = json.load(f)
-    existing_content[str(bs)] = best_config
-    with open(filename, "w") as f:
-        json.dump(existing_content, f, indent=4)
-        f.write("\n")
-
-
-def run_timing(
-    num_calls: int,
-    bs: int,
-    d_model: int,
-    num_total_experts: int,
-    top_k: int,
-    tp_size: int,
-    model_intermediate_size: int,
-    config,
-) -> float:
-    shard_intermediate_size = model_intermediate_size // tp_size
-
-    hidden_states = torch.rand(
-        (bs, d_model),
-        device="cuda",
-        dtype=torch.float16,
-    )
-
-    w1 = torch.rand(
-        (num_total_experts, 2 * shard_intermediate_size, d_model),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-
-    w2 = torch.rand(
-        (num_total_experts, d_model, shard_intermediate_size),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-
-    gating_output = F.softmax(
-        torch.rand(
-            # (num_calls, bs, num_total_experts), # THIS
-            (bs, num_total_experts),
-            device=hidden_states.device,
-            dtype=torch.float32,
-        ),
-        dim=-1,
-    )
-
-    ###### Stuff from fused moe ######
-
-    assert (hidden_states.shape[0] == gating_output.shape[0]
-            ), "Number of tokens mismatch"
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
-    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
-    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16
-    ]
-    M, _ = hidden_states.shape
-    E, N, _ = w1.shape
-    topk_ = 2
-    topk_weights = torch.empty(M,
-                               topk_,
-                               dtype=torch.float32,
-                               device=hidden_states.device)
-    topk_ids = torch.empty(M,
-                           topk_,
-                           dtype=torch.int32,
-                           device=hidden_states.device)
-    token_expert_indicies = torch.empty(M,
-                                        topk_,
-                                        dtype=torch.int32,
-                                        device=hidden_states.device)
-    moe_kernels.topk_softmax(
-        topk_weights,
-        topk_ids,
-        token_expert_indicies,
-        gating_output.float(),  # TODO(woosuk): Optimize this.
-    )
-    del token_expert_indicies  # Not used. Will be used in the future.
-
-    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    intermediate_cache1 = torch.empty(
-        (M, topk_ids.shape[1], N),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache2 = torch.empty(
-        (M * topk_ids.shape[1], N // 2),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-    intermediate_cache3 = torch.empty(
-        (M, topk_ids.shape[1], w2.shape[1]),
-        device=hidden_states.device,
-        dtype=hidden_states.dtype,
-    )
-
-    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config["BLOCK_SIZE_M"], E)
-
-    ##################################
-
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
-
-    start_event.record()
-    for i in range(num_calls):
-        invoke_fused_moe_kernel(
-            hidden_states,
-            w1,
-            intermediate_cache1,
-            None,  # a1_scale
-            None,  # w1_scale
-            topk_weights,
-            topk_ids,
-            sorted_token_ids,
-            expert_ids,
-            num_tokens_post_padded,
-            False,
-            topk_ids.shape[1],
-            config,
-            compute_type=(tl.bfloat16 if hidden_states.dtype == torch.bfloat16
-                          else tl.float16),
-            use_fp8=False,
-        )
-
-        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
-
-        invoke_fused_moe_kernel(
-            intermediate_cache2,
-            w2,
-            intermediate_cache3,
-            None,  # a2_scale
-            None,  # w2_scale
-            topk_weights,
-            topk_ids,
-            sorted_token_ids,
-            expert_ids,
-            num_tokens_post_padded,
-            True,
-            1,
-            config,
-            compute_type=(tl.bfloat16 if hidden_states.dtype == torch.bfloat16
-                          else tl.float16),
-            use_fp8=False,
-        )
-
-    end_event.record()
-    end_event.synchronize()
-
-    dur_ms = start_event.elapsed_time(end_event) / num_calls
-    return dur_ms
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        prog="benchmark_mixtral_moe_rocm",
-        description="Tune the fused_moe kernel for mixtral.")
-    parser.add_argument(
-        "--TP",
-        type=int,
-        choices=[8, 4, 2, 1],
-        help="Specify the TP value that the actual model will run on",
-        required=True,
-    )
-    parser.add_argument(
-        "--GPUID",
-        type=str,
-        help="This script uses single GPU. Specify the GPU to use for tuning",
-        default="0",
-    )
-    parser.add_argument('--model',
-                        type=str,
-                        choices=['8x7B', '8x22B'],
-                        help='The Mixtral model to benchmark')
-
-    args = parser.parse_args()
-
-    print(f"Running tuning for {args.model} model")
-    print(f"TP is set to: {args.TP}")
-    print(f"GPU-ID being used for tuning: {args.GPUID}")
-    sys.exit(main(args))
diff --git a/benchmarks/kernels/moe_tune_script.sh b/benchmarks/kernels/moe_tune_script.sh
new file mode 100755
index 0000000000000..2ee1748b5535f
--- /dev/null
+++ b/benchmarks/kernels/moe_tune_script.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+
+
+## ---- Mixtral fp8 tuning example ---- ##
+python benchmark_moe.py --model /data/models/mistral-ai-models/Mixtral-8x22B-Instruct-v0.1-FP8/ --tp-size 1 --tune --dtype fp8_w8a8
+python benchmark_moe.py --model /data/models/mistral-ai-models/Mixtral-8x22B-Instruct-v0.1-FP8/ --tp-size 2 --tune --dtype fp8_w8a8
+python benchmark_moe.py --model /data/models/mistral-ai-models/Mixtral-8x22B-Instruct-v0.1-FP8/ --tp-size 4 --tune --dtype fp8_w8a8
+python benchmark_moe.py --model /data/models/mistral-ai-models/Mixtral-8x22B-Instruct-v0.1-FP8/ --tp-size 8 --tune --dtype fp8_w8a8
+
+
+## ---- Mixtral fp16 tuning example ---- ##
+# We don't need --dtype fp16; it is already the default for ROCm in this script.
+
+python benchmark_moe.py --model /data/models/mistral-ai-models/Mixtral-8x22B-v0.1/ --tp-size 1 --tune
+python benchmark_moe.py --model /data/models/mistral-ai-models/Mixtral-8x22B-v0.1/ --tp-size 2 --tune
+python benchmark_moe.py --model /data/models/mistral-ai-models/Mixtral-8x22B-v0.1/ --tp-size 4 --tune
+python benchmark_moe.py --model /data/models/mistral-ai-models/Mixtral-8x22B-v0.1/ --tp-size 8 --tune
+
+
+
+## ---- After the tuning is finished ---- ##
+# The tuning script saves the configurations in a JSON file in the directory from which you launch the script.
+# The name of the JSON file will look something like this: E=8,N=14336,device_name=AMD_Instinct_MI300X.json
+#
+# [IMPORTANT] -> Once the tuning is complete, move the tuned config file(s) to the following path:
+# vllm/vllm/model_executor/layers/fused_moe/configs/
+
+
+## ---- Notes ---- ##
+# 1. A tuned file is specific to a TP size: a config obtained with --tp-size 8 can only be used when the model runs with TP=8.
+# 2. The script uses Ray for multi-GPU tuning. Export HIP_VISIBLE_DEVICES accordingly to expose the required number of GPUs for tuning.
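+
+
+## ---- Example: installing a tuned config (illustrative) ---- ##
+# A minimal sketch of the post-tuning steps from the notes above. The JSON
+# filename is only an example (the real name depends on your GPU and the model
+# shape), and the destination path assumes this script is launched from
+# benchmarks/kernels/ inside a vllm checkout.
+#
+# export HIP_VISIBLE_DEVICES=0,1   # expose just two GPUs to Ray for tuning
+# mv E=8,N=14336,device_name=AMD_Instinct_MI300X.json \
+#    ../../vllm/model_executor/layers/fused_moe/configs/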