From fd16af5d7b5d1cdafd6a3f8d2895900d0999266b Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Fri, 16 Aug 2024 18:32:08 +0000 Subject: [PATCH 01/14] add multi-gpu tuning support with tqdm progress bar - todo: add fp8 support - todo: add comments and documentation --- .../kernels/benchmark_mixtral_moe_rocm.py | 257 ++++-------------- benchmarks/kernels/tuning_utils.py | 134 +++++++++ 2 files changed, 187 insertions(+), 204 deletions(-) create mode 100644 benchmarks/kernels/tuning_utils.py diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py index 63080eaf2f11c..399348300554a 100755 --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -8,6 +8,10 @@ import triton import triton.language as tl from tqdm import tqdm +import torch.distributed as dist +import torch.multiprocessing as mp + +from tuning_utils import (get_full_tuning_space, prune_configs, union_of_list_of_dicts) import vllm._moe_C as moe_kernels from vllm._C import ops @@ -15,171 +19,26 @@ invoke_fused_moe_kernel, moe_align_block_size) +os.environ["HIP_FORCE_DEV_KERNARG"] = "1" +os.environ["DEBUG_CLR_GRAPH_PACKET_CAPTURE"] = "1" def main(args): - os.environ["HIP_VISIBLE_DEVICES"] = args.GPUID - os.environ["HIP_FORCE_DEV_KERNARG"] = "1" - os.environ["DEBUG_CLR_GRAPH_PACKET_CAPTURE"] = "1" - os.environ["OPTIMIZE_EPILOGUE"] = "1" - - for bs in [ - 1, - 2, - 4, - 8, - 16, - 24, - 32, - 48, - 64, - 96, - 128, - 256, - 512, - 1024, - 1536, - 2048, - 3072, - 4096, - ]: - run_grid(bs, model=args.model, TP=args.TP) - - -## Utilize method from rocm/Triton tuning script -def get_full_tuning_space(): - configs = [] - - block_mn_range = [16, 32, 64, 128, 256] - block_k_range = [16, 32, 64, 128, 256] - # split_k_range = [1] #, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24] - num_warps_range = [1, 2, 4, 8] - group_m_range = [1, 4, 8, 16, 32] - # For now we see better perf with num_stages=0 for all gemm configs we care - # But keep this explicit so that we do not forget we may need to set it to - # other values in the future - num_stage_range = [0] - waves_per_eu_range = [0] - matrix_instr_nonkdim_range = [16, 32] - kpack_range = [1, 2] - - for block_m in block_mn_range: - for block_n in block_mn_range: - for block_k in block_k_range: - for num_warps in num_warps_range: - for group_m in group_m_range: - # for split_k in split_k_range: - for num_stages in num_stage_range: - for waves_per_eu in waves_per_eu_range: - for (matrix_instr_nonkdim - ) in matrix_instr_nonkdim_range: - for kpack in kpack_range: - configs.append({ - "BLOCK_SIZE_M": block_m, - "BLOCK_SIZE_N": block_n, - "BLOCK_SIZE_K": block_k, - "GROUP_SIZE_M": group_m, - "num_warps": num_warps, - "num_stages": num_stages, - "waves_per_eu": waves_per_eu, - "matrix_instr_nonkdim": - matrix_instr_nonkdim, - "kpack": kpack, - }) - - return configs - - -## Utilize method from rocm/Triton tuning script -def prune_configs(M, N, K, configs): - pruned_configs = [] - elemBytes_a = 2 # [DV Note] Hard-coded for float16 (2 bytes) - elemBytes_b = 2 # [DV Note] Hard-coded for float16 (2 bytes) - - mfma = 16 if M < 32 or N < 32 else 32 - - # TODO (zhanglx): figure out the boundary between large and small gemms - large_gemm = False - if M >= 2048 and N >= 2048: - large_gemm = True - - for config in configs: - BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") - BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") - BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") - num_warps = config.get("num_warps") - matrix_instr_nonkdim = 
config.get("matrix_instr_nonkdim") - # kpack = config.get("kpack") - if matrix_instr_nonkdim > mfma: - continue - if mfma == 4 and BLOCK_SIZE_K < 64: - continue - # some layouts could not work properly in case - # number elements per thread is less 1 - if BLOCK_SIZE_M * BLOCK_SIZE_N < 64: - continue - SPLIT_K = 1 # config.get("SPLIT_K") - GROUP_M = config.get("GROUP_SIZE_M") - if (matrix_instr_nonkdim > BLOCK_SIZE_M - or matrix_instr_nonkdim > BLOCK_SIZE_N): - continue - if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M: - continue - if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N: - continue - # Skip BLOCK_SIZE that is too large compare to M/N - # unless BLOCK_SIZE is already small enough - if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16: - continue - if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16: - continue - # skip large split_k when not necessary - if SPLIT_K != 1 and not need_split_k(M, N, K): - continue - # skip split_k that leads to EVEN_K = false - leap = SPLIT_K * BLOCK_SIZE_K - modv = K % leap - if modv != 0: - continue - # skip large GROUP_M - if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: - continue - # out of shared memory resource - # TODO (zhanglx): This does not consider the LDS usage in the epilogue - LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + - BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b) - if LDS > 65536: - continue - # Skip small block sizes and num_warps for large gemm - # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 - if large_gemm: - if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64: - continue - if BLOCK_SIZE_K < 64: - continue - if num_warps < 4: - continue - - pruned_configs.append(config) - - return pruned_configs - - -def union_of_list_of_dicts(l1, l2): - result = [] - temp_list = l1.copy() - temp_list.extend(l2) - for myDict in temp_list: - if myDict not in result: - result.append(myDict) - - return result - - -def need_split_k(SIZE_M, SIZE_N, SIZE_K): - return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 - - -def run_grid(bs, model, TP): + world_size = args.numGPU + mp.spawn(wrapper, args=(args,), nprocs=world_size, join=False) + + +def wrapper(rank, args): + dist.init_process_group("nccl", world_size=args.numGPU, rank=rank) + device_id = rank + + batches = [1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048, 3072, 4096] + for i in range(device_id, len(batches), args.numGPU): + tune_batch(batches[i], model=args.model, TP=args.modelTP) + + +def tune_batch(bs, model, TP): + device_id = torch.distributed.get_rank() + if model == '8x7B': d_model = 4096 model_intermediate_size = 14336 @@ -194,9 +53,6 @@ def run_grid(bs, model, TP): tp_size = TP num_calls = 100 - num_warmup_trials = 1 - num_trials = 1 - full_configs = get_full_tuning_space() M1 = bs * 2 N1 = model_intermediate_size * 2 // tp_size @@ -209,16 +65,17 @@ def run_grid(bs, model, TP): prune_configs_2 = prune_configs(M2, N2, K2, full_configs) configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2) - print(f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | \ - {len(prune_configs_2)=} | {len(configs)=}") best_config = None best_time_us = 1e20 - for config in tqdm(configs): - # warmup - try: - for _ in range(num_warmup_trials): + progress_bar = tqdm(total=len(configs), desc=f"bs={bs:4d} device={device_id}", position=device_id) + + with torch.cuda.device(device_id): + for config in configs: + progress_bar.update(1) + # warmup + try: run_timing( num_calls=num_calls, bs=bs, @@ -229,11 +86,10 @@ def run_grid(bs, model, TP): 
model_intermediate_size=model_intermediate_size, config=config, ) - except triton.runtime.autotuner.OutOfResources: - continue + except triton.runtime.autotuner.OutOfResources: + continue - # benchmark - for _ in range(num_trials): + # benchmark kernel_dur_ms = run_timing( num_calls=num_calls, bs=bs, @@ -246,20 +102,11 @@ def run_grid(bs, model, TP): ) kernel_dur_us = 1000 * kernel_dur_ms - # model_dur_ms = kernel_dur_ms * num_layers if kernel_dur_us < best_time_us: best_config = config best_time_us = kernel_dur_us - # print(f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}' - # f' {bs=} {tp_size=} {top_k=} {num_total_experts=} ' - # f'{d_model=} {model_intermediate_size=} {num_layers=}') - - # print("best_time_us", best_time_us) - # print("best_config", best_config) - - # holds Dict[str, Dict[str, int]] filename = get_config_file_name(num_total_experts, model_intermediate_size // tp_size, dtype=None) @@ -286,10 +133,13 @@ def run_timing( ) -> float: shard_intermediate_size = model_intermediate_size // tp_size + device_ = "cuda" + dtype_ = torch.float16 + hidden_states = torch.rand( (bs, d_model), - device="cuda", - dtype=torch.float16, + device=device_, + dtype=dtype_, ) w1 = torch.rand( @@ -306,7 +156,6 @@ def run_timing( gating_output = F.softmax( torch.rand( - # (num_calls, bs, num_total_experts), # THIS (bs, num_total_experts), device=hidden_states.device, dtype=torch.float32, @@ -345,7 +194,7 @@ def run_timing( topk_weights, topk_ids, token_expert_indicies, - gating_output.float(), # TODO(woosuk): Optimize this. + gating_output.float(), ) del token_expert_indicies # Not used. Will be used in the future. @@ -376,7 +225,7 @@ def run_timing( end_event = torch.cuda.Event(enable_timing=True) start_event.record() - for i in range(num_calls): + for _ in range(num_calls): invoke_fused_moe_kernel( hidden_states, w1, @@ -427,28 +276,28 @@ def run_timing( if __name__ == "__main__": parser = argparse.ArgumentParser( prog="benchmark_mixtral_moe_rocm", - description="Tune the fused_moe kernel for mixtral.") + description="Distributed tuning script for the fused_moe kernel.") + parser.add_argument('--model', + type=str, + choices=['8x7B', '8x22B'], + help='The Mixtral model to benchmark') parser.add_argument( - "--TP", + "--modelTP", type=int, choices=[8, 4, 2, 1], - help="Specify the TP value that the actual model will run on", + help="Specify the TP value that the model will actually run on", required=True, ) parser.add_argument( - "--GPUID", - type=str, - help="This script uses single GPU. 
Specify the GPU to use for tuning", - default="0", + "--numGPU", + type=int, + choices=[8, 4, 2, 1], + help="Total number of GPUs to use for tuning", + required=True, ) - parser.add_argument('--model', - type=str, - choices=['8x7B', '8x22B'], - help='The Mixtral model to benchmark') - args = parser.parse_args() print(f"Running tuning for {args.model} model") - print(f"TP is set to: {args.TP}") - print(f"GPU-ID being used for tuning: {args.GPUID}") + print(f"Model TP is set to: {args.modelTP}") + print(f"GPUs being used for tuning: {args.numGPU}") sys.exit(main(args)) diff --git a/benchmarks/kernels/tuning_utils.py b/benchmarks/kernels/tuning_utils.py new file mode 100644 index 0000000000000..120e9de870867 --- /dev/null +++ b/benchmarks/kernels/tuning_utils.py @@ -0,0 +1,134 @@ + + +## Utilize method from rocm/Triton tuning script +def get_full_tuning_space(): + configs = [] + + block_mn_range = [16, 32, 64, 128, 256] + block_k_range = [16, 32, 64, 128, 256] + # split_k_range = [1] #, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24] + num_warps_range = [1, 2, 4, 8] + group_m_range = [1, 4, 8, 16, 32] + # For now we see better perf with num_stages=0 for all gemm configs we care + # But keep this explicit so that we do not forget we may need to set it to + # other values in the future + num_stage_range = [0] + waves_per_eu_range = [0] + matrix_instr_nonkdim_range = [16, 32] + kpack_range = [1, 2] + + for block_m in block_mn_range: + for block_n in block_mn_range: + for block_k in block_k_range: + for num_warps in num_warps_range: + for group_m in group_m_range: + # for split_k in split_k_range: + for num_stages in num_stage_range: + for waves_per_eu in waves_per_eu_range: + for (matrix_instr_nonkdim + ) in matrix_instr_nonkdim_range: + for kpack in kpack_range: + configs.append({ + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_m, + "num_warps": num_warps, + "num_stages": num_stages, + "waves_per_eu": waves_per_eu, + "matrix_instr_nonkdim": + matrix_instr_nonkdim, + "kpack": kpack, + }) + + return configs + + +## Utilize method from rocm/Triton tuning script +def prune_configs(M, N, K, configs): + pruned_configs = [] + elemBytes_a = 2 # [DV Note] Hard-coded for float16 (2 bytes) + elemBytes_b = 2 # [DV Note] Hard-coded for float16 (2 bytes) + + mfma = 16 if M < 32 or N < 32 else 32 + + # TODO (zhanglx): figure out the boundary between large and small gemms + large_gemm = False + if M >= 2048 and N >= 2048: + large_gemm = True + + for config in configs: + BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") + BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") + BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") + num_warps = config.get("num_warps") + matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") + # kpack = config.get("kpack") + if matrix_instr_nonkdim > mfma: + continue + if mfma == 4 and BLOCK_SIZE_K < 64: + continue + # some layouts could not work properly in case + # number elements per thread is less 1 + if BLOCK_SIZE_M * BLOCK_SIZE_N < 64: + continue + SPLIT_K = 1 # config.get("SPLIT_K") + GROUP_M = config.get("GROUP_SIZE_M") + if (matrix_instr_nonkdim > BLOCK_SIZE_M + or matrix_instr_nonkdim > BLOCK_SIZE_N): + continue + if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M: + continue + if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N: + continue + # Skip BLOCK_SIZE that is too large compare to M/N + # unless BLOCK_SIZE is already small enough + if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16: + continue + if N * 2 < 
BLOCK_SIZE_N and BLOCK_SIZE_N != 16: + continue + # skip large split_k when not necessary + if SPLIT_K != 1 and not need_split_k(M, N, K): + continue + # skip split_k that leads to EVEN_K = false + leap = SPLIT_K * BLOCK_SIZE_K + modv = K % leap + if modv != 0: + continue + # skip large GROUP_M + if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: + continue + # out of shared memory resource + # TODO (zhanglx): This does not consider the LDS usage in the epilogue + LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b) + if LDS > 65536: + continue + # Skip small block sizes and num_warps for large gemm + # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 + if large_gemm: + if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64: + continue + if BLOCK_SIZE_K < 64: + continue + if num_warps < 4: + continue + + pruned_configs.append(config) + + return pruned_configs + + +def union_of_list_of_dicts(l1, l2): + result = [] + temp_list = l1.copy() + temp_list.extend(l2) + for myDict in temp_list: + if myDict not in result: + result.append(myDict) + + return result + + +def need_split_k(SIZE_M, SIZE_N, SIZE_K): + return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 From 37fc50073e4890bbb956e175a3c1ad4ee4b6a03f Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Fri, 16 Aug 2024 18:54:43 +0000 Subject: [PATCH 02/14] ruff & yapf --- .../kernels/benchmark_mixtral_moe_rocm.py | 43 ++++++++++++------- benchmarks/kernels/tuning_utils.py | 2 - 2 files changed, 27 insertions(+), 18 deletions(-) diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py index 399348300554a..5f84c7bc6c680 100755 --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -4,41 +4,50 @@ import sys import torch +import torch.distributed as dist +import torch.multiprocessing as mp import torch.nn.functional as F import triton import triton.language as tl from tqdm import tqdm -import torch.distributed as dist -import torch.multiprocessing as mp - -from tuning_utils import (get_full_tuning_space, prune_configs, union_of_list_of_dicts) +from tuning_utils import ( + get_full_tuning_space, + prune_configs, + union_of_list_of_dicts, +) import vllm._moe_C as moe_kernels from vllm._C import ops -from vllm.model_executor.layers.fused_moe import (get_config_file_name, - invoke_fused_moe_kernel, - moe_align_block_size) +from vllm.model_executor.layers.fused_moe import ( + get_config_file_name, + invoke_fused_moe_kernel, + moe_align_block_size, +) os.environ["HIP_FORCE_DEV_KERNARG"] = "1" os.environ["DEBUG_CLR_GRAPH_PACKET_CAPTURE"] = "1" + def main(args): world_size = args.numGPU - mp.spawn(wrapper, args=(args,), nprocs=world_size, join=False) + mp.spawn(wrapper, args=(args, ), nprocs=world_size, join=False) + - def wrapper(rank, args): dist.init_process_group("nccl", world_size=args.numGPU, rank=rank) device_id = rank - - batches = [1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048, 3072, 4096] + + batches = [ + 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048, + 3072, 4096 + ] for i in range(device_id, len(batches), args.numGPU): tune_batch(batches[i], model=args.model, TP=args.modelTP) -def tune_batch(bs, model, TP): +def tune_batch(bs, model, TP): device_id = torch.distributed.get_rank() - + if model == '8x7B': d_model = 4096 model_intermediate_size = 14336 @@ -69,7 +78,9 @@ def tune_batch(bs, model, TP): best_config = None best_time_us = 1e20 - progress_bar = 
tqdm(total=len(configs), desc=f"bs={bs:4d} device={device_id}", position=device_id) + progress_bar = tqdm(total=len(configs), + desc=f"bs={bs:4d} device={device_id}", + position=device_id) with torch.cuda.device(device_id): for config in configs: @@ -135,7 +146,7 @@ def run_timing( device_ = "cuda" dtype_ = torch.float16 - + hidden_states = torch.rand( (bs, d_model), device=device_, @@ -194,7 +205,7 @@ def run_timing( topk_weights, topk_ids, token_expert_indicies, - gating_output.float(), + gating_output.float(), ) del token_expert_indicies # Not used. Will be used in the future. diff --git a/benchmarks/kernels/tuning_utils.py b/benchmarks/kernels/tuning_utils.py index 120e9de870867..ca75d59fe71ba 100644 --- a/benchmarks/kernels/tuning_utils.py +++ b/benchmarks/kernels/tuning_utils.py @@ -1,5 +1,3 @@ - - ## Utilize method from rocm/Triton tuning script def get_full_tuning_space(): configs = [] From 6f45d02b699ae69c1f7380b246e76c62c91c3eda Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Tue, 17 Sep 2024 19:17:01 +0000 Subject: [PATCH 03/14] kernel api update & Torchrun usage warning --- .../kernels/benchmark_mixtral_moe_rocm.py | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py index 5f84c7bc6c680..e9f6fc2c79ca0 100755 --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -16,17 +16,13 @@ union_of_list_of_dicts, ) -import vllm._moe_C as moe_kernels -from vllm._C import ops +from vllm import _custom_ops as ops from vllm.model_executor.layers.fused_moe import ( get_config_file_name, invoke_fused_moe_kernel, moe_align_block_size, ) -os.environ["HIP_FORCE_DEV_KERNARG"] = "1" -os.environ["DEBUG_CLR_GRAPH_PACKET_CAPTURE"] = "1" - def main(args): world_size = args.numGPU @@ -201,7 +197,7 @@ def run_timing( topk_, dtype=torch.int32, device=hidden_states.device) - moe_kernels.topk_softmax( + ops.topk_softmax( topk_weights, topk_ids, token_expert_indicies, @@ -253,8 +249,8 @@ def run_timing( config, compute_type=(tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16), - use_fp8=False, - ) + use_fp8_w8a8=False, + use_int8_w8a16=False) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -274,8 +270,8 @@ def run_timing( config, compute_type=(tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16), - use_fp8=False, - ) + use_fp8_w8a8=False, + use_int8_w8a16=False) end_event.record() end_event.synchronize() @@ -307,6 +303,12 @@ def run_timing( required=True, ) args = parser.parse_args() + if "LOCAL_RANK" not in os.environ: + print("Please use torchrun to launch this multi-gpu script. 
E.g:") + print("\ttorchrun benchmark_mixtral_moe_rocm.py", + "--model 8x7B --modelTP 4 --numGPU 2") + print("Exiting...") + exit() print(f"Running tuning for {args.model} model") print(f"Model TP is set to: {args.modelTP}") From ffaffcdc40d6f9bf9a7742c604bfcca65f643360 Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Tue, 17 Sep 2024 19:33:38 +0000 Subject: [PATCH 04/14] [nit] file mode fix --- benchmarks/kernels/benchmark_mixtral_moe_rocm.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100755 => 100644 benchmarks/kernels/benchmark_mixtral_moe_rocm.py diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py old mode 100755 new mode 100644 From 0e0608f36a435e623202de51073ddec6aacf09c8 Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Tue, 17 Sep 2024 19:42:00 +0000 Subject: [PATCH 05/14] ruff ruff ruff (isort) --- benchmarks/kernels/benchmark_mixtral_moe_rocm.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py index e9f6fc2c79ca0..0dde1c5d78e45 100644 --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -10,18 +10,13 @@ import triton import triton.language as tl from tqdm import tqdm -from tuning_utils import ( - get_full_tuning_space, - prune_configs, - union_of_list_of_dicts, -) +from tuning_utils import (get_full_tuning_space, prune_configs, + union_of_list_of_dicts) from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe import ( - get_config_file_name, - invoke_fused_moe_kernel, - moe_align_block_size, -) +from vllm.model_executor.layers.fused_moe import (get_config_file_name, + invoke_fused_moe_kernel, + moe_align_block_size) def main(args): From fd6a4eef4f78fd2cc96cc024abeee319ea4e5a1c Mon Sep 17 00:00:00 2001 From: Divakar Verma <137818590+divakar-amd@users.noreply.github.com> Date: Fri, 20 Sep 2024 22:20:31 -0500 Subject: [PATCH 06/14] fix hardcoded top_k --- benchmarks/kernels/benchmark_mixtral_moe_rocm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py index 0dde1c5d78e45..a8ca8bfa29816 100644 --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -179,7 +179,7 @@ def run_timing( ] M, _ = hidden_states.shape E, N, _ = w1.shape - topk_ = 2 + topk_ = top_k topk_weights = torch.empty(M, topk_, dtype=torch.float32, From 2cb34c4a9d5bf8c4f02ede74a72683a1b81271d8 Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Thu, 26 Sep 2024 19:14:43 +0000 Subject: [PATCH 07/14] add exception handling to see silen torchrun failures --- benchmarks/kernels/benchmark_mixtral_moe_rocm.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py index a8ca8bfa29816..79c9547b57edd 100644 --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -21,7 +21,10 @@ def main(args): world_size = args.numGPU - mp.spawn(wrapper, args=(args, ), nprocs=world_size, join=False) + try: + mp.spawn(wrapper, args=(args,), nprocs=world_size, join=False) + except Exception as e: + print(f"An error occurred during multiprocessing: {e}") def wrapper(rank, args): @@ -32,8 +35,11 @@ def 
wrapper(rank, args): 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048, 3072, 4096 ] - for i in range(device_id, len(batches), args.numGPU): - tune_batch(batches[i], model=args.model, TP=args.modelTP) + try: + for i in range(device_id, len(batches), args.numGPU): + tune_batch(batches[i], model=args.model, TP=args.modelTP) + except Exception as e: + print(f"An error occurred on device {device_id}: {e}") def tune_batch(bs, model, TP): From 60860d945153c0c733ad83671ade43cc98cec8e7 Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Thu, 26 Sep 2024 21:59:03 +0000 Subject: [PATCH 08/14] yapf --- benchmarks/kernels/benchmark_mixtral_moe_rocm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py index 79c9547b57edd..aef2406fee7bb 100644 --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -22,7 +22,7 @@ def main(args): world_size = args.numGPU try: - mp.spawn(wrapper, args=(args,), nprocs=world_size, join=False) + mp.spawn(wrapper, args=(args, ), nprocs=world_size, join=False) except Exception as e: print(f"An error occurred during multiprocessing: {e}") From 618663d52f313728e84e9a035a55cff22a41fc83 Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Wed, 16 Oct 2024 20:05:52 +0000 Subject: [PATCH 09/14] use itertool.product for readability --- benchmarks/kernels/tuning_utils.py | 42 ++++++++++++++---------------- 1 file changed, 19 insertions(+), 23 deletions(-) diff --git a/benchmarks/kernels/tuning_utils.py b/benchmarks/kernels/tuning_utils.py index ca75d59fe71ba..88994fd3649d5 100644 --- a/benchmarks/kernels/tuning_utils.py +++ b/benchmarks/kernels/tuning_utils.py @@ -1,3 +1,6 @@ +from itertools import product + + ## Utilize method from rocm/Triton tuning script def get_full_tuning_space(): configs = [] @@ -15,29 +18,22 @@ def get_full_tuning_space(): matrix_instr_nonkdim_range = [16, 32] kpack_range = [1, 2] - for block_m in block_mn_range: - for block_n in block_mn_range: - for block_k in block_k_range: - for num_warps in num_warps_range: - for group_m in group_m_range: - # for split_k in split_k_range: - for num_stages in num_stage_range: - for waves_per_eu in waves_per_eu_range: - for (matrix_instr_nonkdim - ) in matrix_instr_nonkdim_range: - for kpack in kpack_range: - configs.append({ - "BLOCK_SIZE_M": block_m, - "BLOCK_SIZE_N": block_n, - "BLOCK_SIZE_K": block_k, - "GROUP_SIZE_M": group_m, - "num_warps": num_warps, - "num_stages": num_stages, - "waves_per_eu": waves_per_eu, - "matrix_instr_nonkdim": - matrix_instr_nonkdim, - "kpack": kpack, - }) + param_ranges = { + "BLOCK_SIZE_M": block_mn_range, + "BLOCK_SIZE_N": block_mn_range, + "BLOCK_SIZE_K": block_k_range, + "GROUP_SIZE_M": group_m_range, + "num_warps": num_warps_range, + "num_stages": num_stage_range, + "waves_per_eu": waves_per_eu_range, + "matrix_instr_nonkdim": matrix_instr_nonkdim_range, + "kpack": kpack_range, + } + + keys, values = zip(*param_ranges.items()) + for config_values in product(*values): + config = dict(zip(keys, config_values)) + configs.append(config) return configs From d47b89ccdd0f319d279d8c3d8fd9ec995c547cca Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Wed, 16 Oct 2024 20:20:46 +0000 Subject: [PATCH 10/14] re-use fused_topk function --- .../kernels/benchmark_mixtral_moe_rocm.py | 25 +++---------------- 1 file changed, 3 insertions(+), 22 deletions(-) diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py 
b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py index aef2406fee7bb..4b60b2e81ceb1 100644 --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -14,7 +14,8 @@ union_of_list_of_dicts) from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe import (get_config_file_name, +from vllm.model_executor.layers.fused_moe import (fused_topk, + get_config_file_name, invoke_fused_moe_kernel, moe_align_block_size) @@ -185,28 +186,8 @@ def run_timing( ] M, _ = hidden_states.shape E, N, _ = w1.shape - topk_ = top_k - topk_weights = torch.empty(M, - topk_, - dtype=torch.float32, - device=hidden_states.device) - topk_ids = torch.empty(M, - topk_, - dtype=torch.int32, - device=hidden_states.device) - token_expert_indicies = torch.empty(M, - topk_, - dtype=torch.int32, - device=hidden_states.device) - ops.topk_softmax( - topk_weights, - topk_ids, - token_expert_indicies, - gating_output.float(), - ) - del token_expert_indicies # Not used. Will be used in the future. - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, top_k, True) intermediate_cache1 = torch.empty( (M, topk_ids.shape[1], N), From 17f95332c4631f6135f835947136d5820c72d0d2 Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Wed, 16 Oct 2024 20:26:24 +0000 Subject: [PATCH 11/14] yapf --- benchmarks/kernels/benchmark_mixtral_moe_rocm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py index 4b60b2e81ceb1..f7be38d7a942f 100644 --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -187,7 +187,8 @@ def run_timing( M, _ = hidden_states.shape E, N, _ = w1.shape - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, top_k, True) + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, top_k, + True) intermediate_cache1 = torch.empty( (M, topk_ids.shape[1], N), From 7202911cf6fd9a9c00ff17b31c699b63c120371e Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Wed, 16 Oct 2024 21:12:28 +0000 Subject: [PATCH 12/14] keep the config keys sorted in multi-gpu --- benchmarks/kernels/benchmark_mixtral_moe_rocm.py | 9 +++++++++ requirements-rocm.txt | 3 ++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py index f7be38d7a942f..f595d95756f68 100644 --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -9,6 +9,7 @@ import torch.nn.functional as F import triton import triton.language as tl +from natsort import natsorted from tqdm import tqdm from tuning_utils import (get_full_tuning_space, prune_configs, union_of_list_of_dicts) @@ -125,11 +126,19 @@ def tune_batch(bs, model, TP): with open(filename, "r") as f: existing_content = json.load(f) existing_content[str(bs)] = best_config + existing_content = sort_json(existing_content) with open(filename, "w") as f: json.dump(existing_content, f, indent=4) f.write("\n") +def sort_json(json_file): + return { + k: v + for k, v in natsorted(json_file.items(), key=lambda item: item[0]) + } + + def run_timing( num_calls: int, bs: int, diff --git a/requirements-rocm.txt b/requirements-rocm.txt index 9e3c4a86cd81d..5d6998c2e3b36 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -9,4 
+9,5 @@ ray >= 2.10.0 peft pytest-asyncio tensorizer>=2.9.0 -setuptools-scm>=8 \ No newline at end of file +setuptools-scm>=8 +natsort \ No newline at end of file From c1134b63181da85cbaa8460f8f5c726cf50919f4 Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Wed, 23 Oct 2024 02:03:46 +0000 Subject: [PATCH 13/14] add fp8 tuning support --- .../kernels/benchmark_mixtral_moe_rocm.py | 61 ++++++++++++++----- benchmarks/kernels/tuning_utils.py | 44 +++++++------ 2 files changed, 71 insertions(+), 34 deletions(-) diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py index f595d95756f68..fcd086a0378aa 100644 --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -2,6 +2,7 @@ import json import os import sys +import time import torch import torch.distributed as dist @@ -23,10 +24,13 @@ def main(args): world_size = args.numGPU + start_time = time.time() try: - mp.spawn(wrapper, args=(args, ), nprocs=world_size, join=False) + mp.spawn(wrapper, args=(args, ), nprocs=world_size, join=True) except Exception as e: print(f"An error occurred during multiprocessing: {e}") + end_time = time.time() + print(f"Total time taken: {end_time - start_time:.2f} seconds") def wrapper(rank, args): @@ -39,12 +43,15 @@ def wrapper(rank, args): ] try: for i in range(device_id, len(batches), args.numGPU): - tune_batch(batches[i], model=args.model, TP=args.modelTP) + tune_batch(batches[i], args) except Exception as e: print(f"An error occurred on device {device_id}: {e}") -def tune_batch(bs, model, TP): +def tune_batch(bs, args): + model = args.model + TP = args.modelTP + use_fp8 = args.use_fp8 device_id = torch.distributed.get_rank() if model == '8x7B': @@ -61,16 +68,16 @@ def tune_batch(bs, model, TP): tp_size = TP num_calls = 100 - full_configs = get_full_tuning_space() + full_configs = get_full_tuning_space(use_fp8) M1 = bs * 2 N1 = model_intermediate_size * 2 // tp_size K1 = d_model - prune_configs_1 = prune_configs(M1, N1, K1, full_configs) + prune_configs_1 = prune_configs(M1, N1, K1, full_configs, use_fp8) M2 = bs * 2 N2 = d_model K2 = model_intermediate_size // tp_size - prune_configs_2 = prune_configs(M2, N2, K2, full_configs) + prune_configs_2 = prune_configs(M2, N2, K2, full_configs, use_fp8) configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2) @@ -87,7 +94,7 @@ def tune_batch(bs, model, TP): # warmup try: run_timing( - num_calls=num_calls, + num_calls=5, bs=bs, d_model=d_model, num_total_experts=num_total_experts, @@ -95,8 +102,10 @@ def tune_batch(bs, model, TP): tp_size=tp_size, model_intermediate_size=model_intermediate_size, config=config, + use_fp8_w8a8=use_fp8, ) - except triton.runtime.autotuner.OutOfResources: + except Exception as e: + print(f"Error during warmup: {e}") continue # benchmark @@ -109,6 +118,7 @@ def tune_batch(bs, model, TP): tp_size=tp_size, model_intermediate_size=model_intermediate_size, config=config, + use_fp8_w8a8=use_fp8, ) kernel_dur_us = 1000 * kernel_dur_ms @@ -117,9 +127,10 @@ def tune_batch(bs, model, TP): best_config = config best_time_us = kernel_dur_us + config_dtype = "fp8_w8a8" if use_fp8 else None filename = get_config_file_name(num_total_experts, model_intermediate_size // tp_size, - dtype=None) + dtype=config_dtype) print(f"writing config to file {filename}") existing_content = {} if os.path.exists(filename): @@ -148,6 +159,7 @@ def run_timing( tp_size: int, model_intermediate_size: int, config, + use_fp8_w8a8: bool, ) -> float: 
shard_intermediate_size = model_intermediate_size // tp_size @@ -180,6 +192,19 @@ def run_timing( ), dim=-1, ) + w1_scale = None + w2_scale = None + a1_scale = None + a2_scale = None + + if use_fp8_w8a8: + w1_scale = torch.randn(num_total_experts, dtype=torch.float32, device=device_) + w2_scale = torch.randn(num_total_experts, dtype=torch.float32, device=device_) + a1_scale = torch.randn(1, dtype=torch.float32, device=device_) + a2_scale = torch.randn(1, dtype=torch.float32, device=device_) + + w1 = w1.to(torch.float8_e4m3fnuz) + w2 = w2.to(torch.float8_e4m3fnuz) ###### Stuff from fused moe ###### @@ -229,8 +254,8 @@ def run_timing( hidden_states, w1, intermediate_cache1, - None, # a1_scale - None, # w1_scale + a1_scale, + w1_scale, topk_weights, topk_ids, sorted_token_ids, @@ -241,7 +266,7 @@ def run_timing( config, compute_type=(tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16), - use_fp8_w8a8=False, + use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=False) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -250,8 +275,8 @@ def run_timing( intermediate_cache2, w2, intermediate_cache3, - None, # a2_scale - None, # w2_scale + a2_scale, + w2_scale, topk_weights, topk_ids, sorted_token_ids, @@ -262,7 +287,7 @@ def run_timing( config, compute_type=(tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16), - use_fp8_w8a8=False, + use_fp8_w8a8=use_fp8_w8a8, use_int8_w8a16=False) end_event.record() @@ -294,6 +319,11 @@ def run_timing( help="Total number of GPUs to use for tuning", required=True, ) + parser.add_argument( + "--use_fp8", + action="store_true", + help="Flag to indicate whether to use FP8 tuning", + ) args = parser.parse_args() if "LOCAL_RANK" not in os.environ: print("Please use torchrun to launch this multi-gpu script. 
E.g:") @@ -305,4 +335,5 @@ def run_timing( print(f"Running tuning for {args.model} model") print(f"Model TP is set to: {args.modelTP}") print(f"GPUs being used for tuning: {args.numGPU}") + print(f"Using FP8: {args.use_fp8}") sys.exit(main(args)) diff --git a/benchmarks/kernels/tuning_utils.py b/benchmarks/kernels/tuning_utils.py index 88994fd3649d5..4055d5f5efd8e 100644 --- a/benchmarks/kernels/tuning_utils.py +++ b/benchmarks/kernels/tuning_utils.py @@ -2,11 +2,13 @@ ## Utilize method from rocm/Triton tuning script -def get_full_tuning_space(): +def get_full_tuning_space(use_fp8): configs = [] block_mn_range = [16, 32, 64, 128, 256] block_k_range = [16, 32, 64, 128, 256] + if use_fp8: + block_k_range.remove(16) # BLOCK_K=16 not supported for fp8 # split_k_range = [1] #, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24] num_warps_range = [1, 2, 4, 8] group_m_range = [1, 4, 8, 16, 32] @@ -15,8 +17,8 @@ def get_full_tuning_space(): # other values in the future num_stage_range = [0] waves_per_eu_range = [0] - matrix_instr_nonkdim_range = [16, 32] - kpack_range = [1, 2] + matrix_instr_nonkdim_range = [] if use_fp8 else [16, 32] + kpack_range = [] if use_fp8 else [1, 2] param_ranges = { "BLOCK_SIZE_M": block_mn_range, @@ -26,10 +28,12 @@ def get_full_tuning_space(): "num_warps": num_warps_range, "num_stages": num_stage_range, "waves_per_eu": waves_per_eu_range, - "matrix_instr_nonkdim": matrix_instr_nonkdim_range, - "kpack": kpack_range, } + if not use_fp8: + param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range + param_ranges["kpack"] = kpack_range + keys, values = zip(*param_ranges.items()) for config_values in product(*values): config = dict(zip(keys, config_values)) @@ -39,10 +43,10 @@ def get_full_tuning_space(): ## Utilize method from rocm/Triton tuning script -def prune_configs(M, N, K, configs): +def prune_configs(M, N, K, configs, is_fp8=False): pruned_configs = [] - elemBytes_a = 2 # [DV Note] Hard-coded for float16 (2 bytes) - elemBytes_b = 2 # [DV Note] Hard-coded for float16 (2 bytes) + elemBytes_a = 1 if is_fp8 else 2 # Assuming fp16 or fp8 cases only + elemBytes_b = 1 if is_fp8 else 2 # Assuming fp16 or fp8 cases only mfma = 16 if M < 32 or N < 32 else 32 @@ -56,10 +60,11 @@ def prune_configs(M, N, K, configs): BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") num_warps = config.get("num_warps") - matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") - # kpack = config.get("kpack") - if matrix_instr_nonkdim > mfma: - continue + + if not is_fp8: + matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") + if matrix_instr_nonkdim > mfma: + continue if mfma == 4 and BLOCK_SIZE_K < 64: continue # some layouts could not work properly in case @@ -68,13 +73,14 @@ def prune_configs(M, N, K, configs): continue SPLIT_K = 1 # config.get("SPLIT_K") GROUP_M = config.get("GROUP_SIZE_M") - if (matrix_instr_nonkdim > BLOCK_SIZE_M - or matrix_instr_nonkdim > BLOCK_SIZE_N): - continue - if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M: - continue - if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N: - continue + if not is_fp8: + if (matrix_instr_nonkdim > BLOCK_SIZE_M + or matrix_instr_nonkdim > BLOCK_SIZE_N): + continue + if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M: + continue + if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N: + continue # Skip BLOCK_SIZE that is too large compare to M/N # unless BLOCK_SIZE is already small enough if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 
16: From 0356fddedac60df03b0a0a22e49e917f3d1d7e38 Mon Sep 17 00:00:00 2001 From: Divakar Verma Date: Wed, 23 Oct 2024 02:11:45 +0000 Subject: [PATCH 14/14] ruff + yapf --- benchmarks/kernels/benchmark_mixtral_moe_rocm.py | 9 ++++++--- benchmarks/kernels/tuning_utils.py | 6 ++++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py index fcd086a0378aa..ba5f77e4463ec 100644 --- a/benchmarks/kernels/benchmark_mixtral_moe_rocm.py +++ b/benchmarks/kernels/benchmark_mixtral_moe_rocm.py @@ -8,7 +8,6 @@ import torch.distributed as dist import torch.multiprocessing as mp import torch.nn.functional as F -import triton import triton.language as tl from natsort import natsorted from tqdm import tqdm @@ -198,8 +197,12 @@ def run_timing( a2_scale = None if use_fp8_w8a8: - w1_scale = torch.randn(num_total_experts, dtype=torch.float32, device=device_) - w2_scale = torch.randn(num_total_experts, dtype=torch.float32, device=device_) + w1_scale = torch.randn(num_total_experts, + dtype=torch.float32, + device=device_) + w2_scale = torch.randn(num_total_experts, + dtype=torch.float32, + device=device_) a1_scale = torch.randn(1, dtype=torch.float32, device=device_) a2_scale = torch.randn(1, dtype=torch.float32, device=device_) diff --git a/benchmarks/kernels/tuning_utils.py b/benchmarks/kernels/tuning_utils.py index 4055d5f5efd8e..5a9c84d9ce736 100644 --- a/benchmarks/kernels/tuning_utils.py +++ b/benchmarks/kernels/tuning_utils.py @@ -77,9 +77,11 @@ def prune_configs(M, N, K, configs, is_fp8=False): if (matrix_instr_nonkdim > BLOCK_SIZE_M or matrix_instr_nonkdim > BLOCK_SIZE_N): continue - if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M: + if (matrix_instr_nonkdim >= M + and matrix_instr_nonkdim != BLOCK_SIZE_M): continue - if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N: + if (matrix_instr_nonkdim >= N + and matrix_instr_nonkdim != BLOCK_SIZE_N): continue # Skip BLOCK_SIZE that is too large compare to M/N # unless BLOCK_SIZE is already small enough
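
A minimal standalone sketch (not part of the patches above) of how the round-robin assignment in wrapper() spreads the batch-size sweep across ranks. The batch list and the stride-by-numGPU indexing are taken verbatim from the patch; the num_gpus value of 4 is only a hypothetical --numGPU setting for illustration:

# Each rank r tunes batches[r], batches[r + numGPU], batches[r + 2*numGPU], ...
batches = [
    1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048,
    3072, 4096
]
num_gpus = 4  # hypothetical --numGPU value, not taken from the PR

for rank in range(num_gpus):
    assigned = [batches[i] for i in range(rank, len(batches), num_gpus)]
    print(f"rank {rank} tunes batch sizes: {assigned}")

As the warning added in patch 03 suggests, the benchmark itself is meant to be launched with torchrun, e.g. "torchrun benchmark_mixtral_moe_rocm.py --model 8x7B --modelTP 4 --numGPU 2"; appending the --use_fp8 flag introduced in patch 13 switches the sweep to fp8 tuning (combining the flag with that exact command is an assumption, not a line quoted from the patches).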