add multi-gpu tuning support with tqdm progress bar
- todo: add fp8 support
- todo: add comments and documentation
divakar-amd committed Sep 17, 2024
1 parent 6bd99d2 commit fd16af5
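
For context, the core of this change is a process-per-GPU launcher: torch.multiprocessing spawns one worker per GPU, the batch-size list is split round-robin across ranks, and each rank drives its own tqdm bar pinned to a separate terminal row via position=rank. The snippet below is only a minimal sketch of that pattern, not the committed code: tune_batch is a hypothetical stand-in, and the real script additionally initializes an NCCL process group and sweeps Triton kernel configs for each batch size.

import torch.multiprocessing as mp
from tqdm import tqdm

BATCHES = [1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128,
           256, 512, 1024, 1536, 2048, 3072, 4096]

def tune_batch(bs: int, rank: int) -> None:
    # Hypothetical placeholder for the per-batch kernel-config sweep
    # performed in the real script.
    pass

def worker(rank: int, num_gpus: int) -> None:
    # Each rank handles every num_gpus-th batch size (strided, round-robin split).
    my_batches = BATCHES[rank::num_gpus]
    bar = tqdm(total=len(my_batches), desc=f"device={rank}", position=rank)
    for bs in my_batches:
        tune_batch(bs, rank)
        bar.update(1)
    bar.close()

if __name__ == "__main__":
    num_gpus = 2  # illustrative; the real script takes this via --numGPU
    mp.spawn(worker, args=(num_gpus,), nprocs=num_gpus, join=True)
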
Showing 2 changed files with 187 additions and 204 deletions.
257 changes: 53 additions & 204 deletions benchmarks/kernels/benchmark_mixtral_moe_rocm.py
@@ -8,178 +8,37 @@
import triton
import triton.language as tl
from tqdm import tqdm
import torch.distributed as dist
import torch.multiprocessing as mp

from tuning_utils import (get_full_tuning_space, prune_configs, union_of_list_of_dicts)

import vllm._moe_C as moe_kernels
from vllm._C import ops
from vllm.model_executor.layers.fused_moe import (get_config_file_name,
invoke_fused_moe_kernel,
moe_align_block_size)

os.environ["HIP_FORCE_DEV_KERNARG"] = "1"
os.environ["DEBUG_CLR_GRAPH_PACKET_CAPTURE"] = "1"

def main(args):
os.environ["HIP_VISIBLE_DEVICES"] = args.GPUID
os.environ["HIP_FORCE_DEV_KERNARG"] = "1"
os.environ["DEBUG_CLR_GRAPH_PACKET_CAPTURE"] = "1"
os.environ["OPTIMIZE_EPILOGUE"] = "1"

for bs in [
1,
2,
4,
8,
16,
24,
32,
48,
64,
96,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
]:
run_grid(bs, model=args.model, TP=args.TP)


## Utilizes the method from the ROCm/Triton tuning script
def get_full_tuning_space():
configs = []

block_mn_range = [16, 32, 64, 128, 256]
block_k_range = [16, 32, 64, 128, 256]
# split_k_range = [1] #, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24]
num_warps_range = [1, 2, 4, 8]
group_m_range = [1, 4, 8, 16, 32]
# For now we see better perf with num_stages=0 for all gemm configs we care
# about, but keep this explicit so that we do not forget we may need to set
# it to other values in the future.
num_stage_range = [0]
waves_per_eu_range = [0]
matrix_instr_nonkdim_range = [16, 32]
kpack_range = [1, 2]

for block_m in block_mn_range:
for block_n in block_mn_range:
for block_k in block_k_range:
for num_warps in num_warps_range:
for group_m in group_m_range:
# for split_k in split_k_range:
for num_stages in num_stage_range:
for waves_per_eu in waves_per_eu_range:
for (matrix_instr_nonkdim
) in matrix_instr_nonkdim_range:
for kpack in kpack_range:
configs.append({
"BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": group_m,
"num_warps": num_warps,
"num_stages": num_stages,
"waves_per_eu": waves_per_eu,
"matrix_instr_nonkdim":
matrix_instr_nonkdim,
"kpack": kpack,
})

return configs


## Utilizes the method from the ROCm/Triton tuning script
def prune_configs(M, N, K, configs):
pruned_configs = []
elemBytes_a = 2 # [DV Note] Hard-coded for float16 (2 bytes)
elemBytes_b = 2 # [DV Note] Hard-coded for float16 (2 bytes)

mfma = 16 if M < 32 or N < 32 else 32

# TODO (zhanglx): figure out the boundary between large and small gemms
large_gemm = False
if M >= 2048 and N >= 2048:
large_gemm = True

for config in configs:
BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
num_warps = config.get("num_warps")
matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
# kpack = config.get("kpack")
if matrix_instr_nonkdim > mfma:
continue
if mfma == 4 and BLOCK_SIZE_K < 64:
continue
# some layouts do not work properly when the
# number of elements per thread is less than 1
if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
continue
SPLIT_K = 1 # config.get("SPLIT_K")
GROUP_M = config.get("GROUP_SIZE_M")
if (matrix_instr_nonkdim > BLOCK_SIZE_M
or matrix_instr_nonkdim > BLOCK_SIZE_N):
continue
if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
continue
if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
continue
# Skip BLOCK_SIZE that is too large compared to M/N
# unless BLOCK_SIZE is already small enough
if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
continue
if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
continue
# skip large split_k when not necessary
if SPLIT_K != 1 and not need_split_k(M, N, K):
continue
# skip split_k that leads to EVEN_K = false
leap = SPLIT_K * BLOCK_SIZE_K
modv = K % leap
if modv != 0:
continue
# skip large GROUP_M
if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
continue
# out of shared memory resource
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
if LDS > 65536:
continue
# Skip small block sizes and num_warps for large gemm
# For fp16 and fp8, we want to only use BLOCK_SIZE >= 64
if large_gemm:
if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
continue
if BLOCK_SIZE_K < 64:
continue
if num_warps < 4:
continue

pruned_configs.append(config)

return pruned_configs


def union_of_list_of_dicts(l1, l2):
result = []
temp_list = l1.copy()
temp_list.extend(l2)
for myDict in temp_list:
if myDict not in result:
result.append(myDict)

return result


def need_split_k(SIZE_M, SIZE_N, SIZE_K):
return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024


def run_grid(bs, model, TP):
world_size = args.numGPU
mp.spawn(wrapper, args=(args,), nprocs=world_size, join=False)


def wrapper(rank, args):
dist.init_process_group("nccl", world_size=args.numGPU, rank=rank)
device_id = rank

batches = [1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, 2048, 3072, 4096]
for i in range(device_id, len(batches), args.numGPU):
tune_batch(batches[i], model=args.model, TP=args.modelTP)


def tune_batch(bs, model, TP):
device_id = torch.distributed.get_rank()

if model == '8x7B':
d_model = 4096
model_intermediate_size = 14336
@@ -194,9 +53,6 @@ def run_grid(bs, model, TP):
tp_size = TP
num_calls = 100

num_warmup_trials = 1
num_trials = 1

full_configs = get_full_tuning_space()
M1 = bs * 2
N1 = model_intermediate_size * 2 // tp_size
@@ -209,16 +65,17 @@
prune_configs_2 = prune_configs(M2, N2, K2, full_configs)

configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2)
print(f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | \
{len(prune_configs_2)=} | {len(configs)=}")

best_config = None
best_time_us = 1e20

for config in tqdm(configs):
# warmup
try:
for _ in range(num_warmup_trials):
progress_bar = tqdm(total=len(configs), desc=f"bs={bs:4d} device={device_id}", position=device_id)

with torch.cuda.device(device_id):
for config in configs:
progress_bar.update(1)
# warmup
try:
run_timing(
num_calls=num_calls,
bs=bs,
@@ -229,11 +86,10 @@
model_intermediate_size=model_intermediate_size,
config=config,
)
except triton.runtime.autotuner.OutOfResources:
continue
except triton.runtime.autotuner.OutOfResources:
continue

# benchmark
for _ in range(num_trials):
# benchmark
kernel_dur_ms = run_timing(
num_calls=num_calls,
bs=bs,
@@ -246,20 +102,11 @@
)

kernel_dur_us = 1000 * kernel_dur_ms
# model_dur_ms = kernel_dur_ms * num_layers

if kernel_dur_us < best_time_us:
best_config = config
best_time_us = kernel_dur_us

# print(f'{kernel_dur_us=:.1f} {model_dur_ms=:.1f}'
# f' {bs=} {tp_size=} {top_k=} {num_total_experts=} '
# f'{d_model=} {model_intermediate_size=} {num_layers=}')

# print("best_time_us", best_time_us)
# print("best_config", best_config)

# holds Dict[str, Dict[str, int]]
filename = get_config_file_name(num_total_experts,
model_intermediate_size // tp_size,
dtype=None)
@@ -286,10 +133,13 @@ def run_timing(
) -> float:
shard_intermediate_size = model_intermediate_size // tp_size

device_ = "cuda"
dtype_ = torch.float16

hidden_states = torch.rand(
(bs, d_model),
device="cuda",
dtype=torch.float16,
device=device_,
dtype=dtype_,
)

w1 = torch.rand(
@@ -306,7 +156,6 @@

gating_output = F.softmax(
torch.rand(
# (num_calls, bs, num_total_experts), # THIS
(bs, num_total_experts),
device=hidden_states.device,
dtype=torch.float32,
@@ -345,7 +194,7 @@ def run_timing(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
gating_output.float(),
)
del token_expert_indicies # Not used. Will be used in the future.

@@ -376,7 +225,7 @@ def run_timing(
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
for i in range(num_calls):
for _ in range(num_calls):
invoke_fused_moe_kernel(
hidden_states,
w1,
@@ -427,28 +276,28 @@ def run_timing(
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="benchmark_mixtral_moe_rocm",
description="Tune the fused_moe kernel for mixtral.")
description="Distributed tuning script for the fused_moe kernel.")
parser.add_argument('--model',
type=str,
choices=['8x7B', '8x22B'],
help='The Mixtral model to benchmark')
parser.add_argument(
"--TP",
"--modelTP",
type=int,
choices=[8, 4, 2, 1],
help="Specify the TP value that the actual model will run on",
help="Specify the TP value that the model will actually run on",
required=True,
)
parser.add_argument(
"--GPUID",
type=str,
help="This script uses single GPU. Specify the GPU to use for tuning",
default="0",
"--numGPU",
type=int,
choices=[8, 4, 2, 1],
help="Total number of GPUs to use for tuning",
required=True,
)
parser.add_argument('--model',
type=str,
choices=['8x7B', '8x22B'],
help='The Mixtral model to benchmark')

args = parser.parse_args()

print(f"Running tuning for {args.model} model")
print(f"TP is set to: {args.TP}")
print(f"GPU-ID being used for tuning: {args.GPUID}")
print(f"Model TP is set to: {args.modelTP}")
print(f"GPUs being used for tuning: {args.numGPU}")
sys.exit(main(args))
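
For reference, with the new flags defined above, a tuning run might be launched roughly as follows (the GPU count and TP value shown here are illustrative):

    python benchmarks/kernels/benchmark_mixtral_moe_rocm.py --model 8x7B --modelTP 8 --numGPU 8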