From 52c92c942569001cf00afc45ead6bb051feb57a0 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Mon, 18 Nov 2024 14:59:29 -0500 Subject: [PATCH] [Kernel] Initial Machete W4A8 support + Refactors (#9855) Signed-off-by: Lucas Wilkinson --- benchmarks/kernels/benchmark_machete.py | 519 +++++++++---- benchmarks/kernels/graph_machete_bench.py | 5 +- benchmarks/kernels/weight_shapes.py | 6 + csrc/cutlass_extensions/cute_utils.cuh | 4 +- .../epilogue}/broadcast_load_epilogue_c2x.hpp | 1 + .../epilogue}/broadcast_load_epilogue_c3x.hpp | 0 .../epilogue/scaled_mm_epilogues_c2x.hpp | 317 ++++++++ .../epilogue/scaled_mm_epilogues_c3x.hpp | 315 ++++++++ .../vllm_cutlass_library_extension.py | 29 + .../vllm_numeric_conversion.cuh | 239 +++++- csrc/cutlass_extensions/vllm_type_utils.cuh | 42 + .../cutlass_w8a8/scaled_mm_c2x.cu | 53 +- .../cutlass_w8a8/scaled_mm_c2x.cuh | 302 -------- .../cutlass_w8a8/scaled_mm_c3x.cu | 312 +------- csrc/quantization/machete/generate.py | 732 ++++++++++-------- .../quantization/machete/machete_mainloop.cuh | 25 +- .../machete/machete_mm_kernel.cuh | 206 +++-- .../machete/machete_mm_launcher.cuh | 90 +-- .../machete/machete_prepack_kernel.cuh | 63 +- .../machete/machete_prepack_launcher.cuh | 15 +- .../machete/machete_prepacked_layout.cuh | 54 +- csrc/quantization/machete/machete_pytorch.cu | 120 ++- csrc/torch_bindings.cpp | 35 +- tests/kernels/test_machete_gemm.py | 284 ------- tests/kernels/test_machete_mm.py | 406 ++++++++++ vllm/_custom_ops.py | 75 +- .../layers/quantization/kernels/machete.py | 16 +- .../layers/quantization/utils/quant_utils.py | 45 +- 28 files changed, 2616 insertions(+), 1694 deletions(-) rename csrc/{quantization/cutlass_w8a8 => cutlass_extensions/epilogue}/broadcast_load_epilogue_c2x.hpp (99%) rename csrc/{quantization/cutlass_w8a8 => cutlass_extensions/epilogue}/broadcast_load_epilogue_c3x.hpp (100%) create mode 100644 csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp create mode 100644 csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp create mode 100644 csrc/cutlass_extensions/vllm_type_utils.cuh delete mode 100644 tests/kernels/test_machete_gemm.py create mode 100644 tests/kernels/test_machete_mm.py diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index 665b50bf18cf0..a0342d08f1db8 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -2,8 +2,10 @@ import copy import itertools import math +import os import pickle as pkl import time +from dataclasses import dataclass from itertools import product from typing import Callable, Iterable, List, Optional, Tuple @@ -15,11 +17,12 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales) + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales, + marlin_zero_points) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( MarlinWorkspace) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - gptq_pack, pack_rows, quantize_weights) + pack_rows, quantize_weights) from vllm.scalar_type import ScalarType, scalar_types from vllm.utils import FlexibleArgumentParser @@ -27,149 +30,349 @@ DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024] DEFAULT_TP_SIZES = [1] +NVTX_PROFILE = os.environ.get("NVTX_PROFILE", False) + +if NVTX_PROFILE: + import nvtx + + +def terse_type_name(dt): + return { + torch.bfloat16: "bf16", + torch.float16: "fp16", + torch.int8: "int8", + torch.float8_e4m3fn: "fp8", + torch.bfloat16: "bf16", + torch.float: "float", + torch.int: "int", + }[dt] + + +@dataclass +class BenchmarkTensors: + w_ref: torch.Tensor + a: torch.Tensor + + w_q: torch.Tensor + group_size: Optional[int] + wtype: ScalarType + w_g_s: torch.Tensor + w_g_zp: Optional[torch.Tensor] + w_ch_s: Optional[torch.Tensor] + w_tok_s: Optional[torch.Tensor] + + +@dataclass +class TypeConfig: + act_type: torch.dtype + weight_type: ScalarType + output_type: Optional[torch.dtype] + group_scale_type: Optional[torch.dtype] + group_zero_type: Optional[torch.dtype] + channel_scale_type: Optional[torch.dtype] + token_scale_type: Optional[torch.dtype] + + +def rand_data(shape, dtype=torch.float16, scale=1): + if dtype.is_floating_point: + return (scale * torch.rand(shape, device="cuda") - 0.3).to(dtype) + else: + return torch.randint(-15, 15, shape, dtype=dtype, device="cuda") + + +def quantize_and_pack(atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: Optional[torch.dtype], + group_size: Optional[int], + zero_points: bool = False): + assert wtype.is_integer(), "TODO: support floating point weights" + + w_ref, w_q, w_s, w_zp = quantize_weights( + w, + wtype, + group_size=group_size, + zero_points=zero_points, + # to match how the kernel applies zps + ref_zero_points_after_scales=True) -def machete_pack_weights(w_q: torch.tensor, wtype: ScalarType) -> torch.tensor: w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) - w_q = w_q.t().contiguous().t() # make col major - return ops.machete_prepack_B(w_q, wtype) + return w_ref, w_q, w_s, w_zp -def make_bench_tensors( - atype: torch.dtype, wtype: ScalarType, group_size: int, m: int, n: int, - k: int -) -> Tuple[torch.tensor, List[Tuple[torch.tensor, torch.tensor, torch.tensor, - torch.tensor]]]: - assert wtype.is_integer(), "TODO: support floating point weights" +def create_bench_tensors(shape: Tuple[int, int, int], types: TypeConfig, + group_size: Optional[int]) -> List[BenchmarkTensors]: + m, n, k = shape # we want to make sure that weights don't fit into L2 cache between runs so # we construct enough weights to exceed L2 cache, which is 50mb on a H100 # so we target total weight size > 2*50mb - num_weights = math.ceil(2 * 50 * 1024**2 * 8 / (k * n * wtype.size_bits)) - - a = torch.randn((m, k), device="cuda", dtype=atype) * 5 - weights = [ - torch.randn((k, n), device="cuda", dtype=atype) - for _ in range(num_weights) - ] - quanitized_weights = [ - quantize_weights(w, wtype, group_size) for w in weights - ] - - return a, quanitized_weights + num_weights = math.ceil(2 * 50 * 1024**2 * 8 / + (k * n * types.weight_type.size_bits)) + + a = rand_data((m, k), types.act_type, scale=5) + + benchmark_tensors: List[BenchmarkTensors] = [] + for _ in range(num_weights): + w = rand_data((k, n), types.act_type, scale=5) + + if types.group_scale_type is not None: + w = w.to(types.group_scale_type) + if w.dtype.itemsize == 1: + w = w.to(torch.float16) + + w_ref, w_q_packed, w_s, w_zp = quantize_and_pack( + a.dtype, w, types.weight_type, types.group_scale_type, group_size, + types.group_zero_type is not None) + + if not a.dtype.is_floating_point: + aiinfo = torch.iinfo(a.dtype) + w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max) + + w_ref = w_ref.to(torch.float32) + + w_ch_s = None if types.channel_scale_type is None else\ + rand_data((n,), types.channel_scale_type) + w_tok_s = None if types.token_scale_type is None else\ + rand_data((m,), types.token_scale_type) + + benchmark_tensors.append( + BenchmarkTensors(w_ref=w_ref, + a=a, + w_q=w_q_packed, + wtype=types.weight_type, + w_g_s=w_s, + w_g_zp=w_zp, + group_size=group_size, + w_ch_s=w_ch_s, + w_tok_s=w_tok_s)) + + return benchmark_tensors + + +def torch_matmul_f16_create_bench_fn(bt: BenchmarkTensors) -> Callable: + a = bt.a + w = bt.w_ref.to(bt.a.dtype) # use float reference tensor + if a.dtype not in [torch.float16, torch.bfloat16]: + a = a.to(torch.float16) + w = w.to(torch.float16) + return lambda: torch.matmul(a, w) + + +def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable: + if bt.w_ch_s is not None and bt.w_tok_s is not None: + scale_a = bt.w_tok_s.to(torch.float32) + scale_b = bt.w_ch_s.to(torch.float32) + else: + scale_a = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device) + scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device) + w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t() + return lambda: ops.cutlass_scaled_mm( + bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16) + + +def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: + device = bt.a.device + + workspace = MarlinWorkspace(bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) + + if bt.w_g_zp is None: + w_zp = torch.empty(0, dtype=torch.int, device=device) + else: + w_zp = marlin_zero_points(bt.w_g_zp, bt.w_ref.shape[0], + bt.w_ref.shape[1], bt.wtype.size_bits) + + if bt.group_size is None: + w_s = torch.tensor([], device="cuda", dtype=torch.half) + else: + w_s = marlin_permute_scales(bt.w_g_s, bt.w_ref.shape[0], + bt.w_ref.shape[1], bt.group_size) + + sort_indices = torch.empty(0, dtype=torch.int, device=device) + g_idx = torch.empty(0, dtype=torch.int, device=device) + w_q = ops.gptq_marlin_repack(bt.w_q, sort_indices, bt.w_ref.shape[0], + bt.w_ref.shape[1], bt.wtype.size_bits) + + if bt.a.dtype.is_floating_point: + assert bt.w_ch_s is None + assert bt.w_tok_s is None + assert bt.group_size is not None + + fn = lambda: ops.gptq_marlin_gemm(a=bt.a, + b_q_weight=w_q, + b_scales=w_s, + b_zeros=w_zp, + g_idx=g_idx, + perm=sort_indices, + workspace=workspace.scratch, + b_q_type=bt.wtype, + size_m=bt.a.shape[0], + size_n=bt.w_ref.shape[1], + size_k=bt.w_ref.shape[0], + is_k_full=True) + else: + assert bt.a.dtype == torch.int8 + assert bt.wtype == scalar_types.uint4b8 + + if bt.w_ch_s is not None: + s_ch = bt.w_ch_s.to(torch.float32) + else: + s_ch = torch.ones(bt.w_ref.shape[1], + dtype=torch.float32, + device=device) + + if bt.w_tok_s is not None: + s_tok = bt.w_tok_s.to(torch.float32) + else: + s_tok = torch.ones(bt.a.shape[0], + dtype=torch.float32, + device=device) + + fn = lambda: ops.marlin_qqq_gemm(a=bt.a, + b_q_weight=w_q, + s_group=w_s, + s_tok=s_tok, + s_ch=s_ch, + workspace=workspace.scratch, + size_m=bt.a.shape[0], + size_n=bt.w_ref.shape[1], + size_k=bt.w_ref.shape[0]) + + return fn + + +def machete_create_bench_fn(bt: BenchmarkTensors, + out_type=torch.dtype, + schedule=None) -> Callable: + w_q = bt.w_q.t().contiguous().t() # make col major + w_q = ops.machete_prepack_B(w_q, bt.a.dtype, bt.wtype, + None if bt.w_g_s is None else bt.w_g_s.dtype) + + w_g_zp = bt.w_g_zp + if w_g_zp is not None: + w_g_zp = -1 * bt.w_g_s * (w_g_zp.to(bt.w_g_s.dtype)) + + return lambda: ops.machete_mm( + a=bt.a, + b_q=bt.w_q, + b_type=bt.wtype, + b_group_scales=bt.w_g_s, + b_group_zeros=w_g_zp, + b_group_size=bt.group_size, + b_channel_scales=bt.w_ch_s, + a_token_scales=bt.w_tok_s, + out_type=out_type, + schedule=schedule, + ) # impl - # bench -def bench_fn(label: str, sub_label: str, description: str, - fn: Callable) -> TMeasurement: - min_run_time = 1 - return TBenchmark.Timer( - stmt="fn()", + +def bench_fns(label: str, sub_label: str, description: str, + fns: List[Callable]): + + min_run_time = 1 if not NVTX_PROFILE else 0.1 + res = TBenchmark.Timer( + stmt=""" + for fn in fns: + fn() + """, globals={ - "fn": fn + "fns": fns }, label=label, sub_label=sub_label, description=description, ).blocked_autorange(min_run_time=min_run_time) + if NVTX_PROFILE: + with nvtx.annotate("mm-bench"), nvtx.annotate( + f"{label}|{sub_label}|{description}"): + fns[0]() -def loop_over_weights( - a: torch.tensor, weights: List[Tuple[torch.tensor, torch.tensor, - torch.tensor, torch.tensor]], - fn: Callable[[torch.tensor, torch.tensor, torch.tensor, torch.tensor], - None]): - for w_ref, w_q, w_s, _ in weights: - fn(a, w_ref, w_q, w_s) + return res _SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None _SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None -def bench(atype: torch.dtype, - wtype: ScalarType, +def bench(types: TypeConfig, group_size: int, m: int, k: int, n: int, label: str, sub_label: str, - benchmark_marlinv1: bool = True, - sweep_schedules: bool = True) -> Iterable[TMeasurement]: - global _SWEEP_SCHEDULES_RESULTS - - a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k) - sub_label += f", L={len(weights)}" - - weights_machete = [(w_ref, machete_pack_weights(w_q, wtype), w_s, w_zp) - for w_ref, w_q, w_s, w_zp in weights] + sweep_schedules: bool = True) -> List[TMeasurement]: + benchmark_tensors = create_bench_tensors((m, n, k), types, group_size) + sub_label += f", L={len(benchmark_tensors)}" + + name_type_string = f"W{types.weight_type}"+\ + f"-A{terse_type_name(types.act_type)}" + if types.group_scale_type is not None: + name_type_string += f"-GS{terse_type_name(types.group_scale_type)}" + if types.group_zero_type is not None: + name_type_string += f"-GZ{terse_type_name(types.group_zero_type)}" + if group_size is not None: + name_type_string += f"-G{group_size}" + if types.channel_scale_type is not None: + name_type_string += f"-CS{terse_type_name(types.channel_scale_type)}" + if types.token_scale_type is not None: + name_type_string += f"-TS{terse_type_name(types.token_scale_type)}" timers = [] # pytorch impl timers.append( - bench_fn( - label, sub_label, "torch.matmul", lambda: loop_over_weights( - a, - weights, - lambda a, w_ref, w_q, w_s: torch.matmul(a, w_ref), - ))) + bench_fns( + label, sub_label, "torch.matmul (fp16)", + [torch_matmul_f16_create_bench_fn(bt) + for bt in benchmark_tensors])) - if benchmark_marlinv1: - w_ref = weights[0][0] - - w_zp_empty = torch.empty(0, dtype=torch.int, device=w_ref.device) - sort_indices = torch.empty(0, dtype=torch.int, device=w_ref.device) - g_idx = torch.empty(0, dtype=torch.int, device=w_ref.device) - - def marlinv1_pack_weights(w_q: torch.tensor) -> torch.tensor: - w_q_gptq = gptq_pack(w_q, wtype.size_bits, *w_ref.shape) - return ops.gptq_marlin_repack(w_q_gptq, sort_indices, *w_ref.shape, - wtype.size_bits) - - def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor: - return marlin_permute_scales(w_s, *w_ref.shape, group_size) - - weights_marlinv1 = [(w_ref, marlinv1_pack_weights(w_q), - marlinv1_permute_scales(w_s), w_zp) - for w_ref, w_q, w_s, w_zp in weights] - - workspace = MarlinWorkspace(w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) - - # marlinv1 + if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn: + timers.append( + bench_fns( + label, sub_label, + f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", [ + cutlass_scaled_mm_create_bench_fn(bt) + for bt in benchmark_tensors + ])) + + if types.act_type != torch.float8_e4m3fn: timers.append( - bench_fn( - label, sub_label, "marlin_orig", lambda: loop_over_weights( - a, weights_marlinv1, lambda a, w_ref, w_q, w_s: ops. - gptq_marlin_gemm(a, - w_q, - w_s, - w_zp_empty, - g_idx, - sort_indices, - workspace.scratch, - wtype, - size_m=a.shape[0], - size_n=w_ref.shape[1], - size_k=w_ref.shape[0], - is_k_full=True)))) + bench_fns(label, sub_label, f"marlin ({name_type_string})", + [marlin_create_bench_fn(bt) + for bt in benchmark_tensors])) # machete timers.append( - bench_fn( - label, sub_label, "machete_heuristic", lambda: loop_over_weights( - a, weights_machete, lambda a, _, w_q, w_s: ops.machete_gemm( - a, w_q, wtype, b_scales=w_s, b_group_size=group_size)))) + bench_fns(label, sub_label, f"machete ({name_type_string})", [ + machete_create_bench_fn(bt, out_type=types.output_type) + for bt in benchmark_tensors + ])) if sweep_schedules: + global _SWEEP_SCHEDULES_RESULTS + print("Finding best schedule for machete") best = None best_schedule = None - schedules = ops.machete_supported_schedules(wtype) + schedules = ops.machete_supported_schedules( + a_type=types.act_type, + b_type=types.weight_type, + group_scales_type=types.group_scale_type, + group_zeros_type=types.group_zero_type, + token_scales_type=types.token_scale_type, + channel_scales_type=types.channel_scale_type, + out_type=types.output_type) + + if schedules is None or len(schedules) == 0: + raise ValueError("No schedules found to sweep") + for schedule in reversed(schedules): schedule_M = int(schedule.split("_")[0].split("x")[1]) @@ -177,16 +380,11 @@ def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor: if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4: continue - def run(a, _, w_q, w_s, schedule=schedule): - ops.machete_gemm(a, - w_q, - wtype, - w_s, - b_group_size=group_size, - schedule=schedule) - - res = bench_fn(label, sub_label, "machete_best", - lambda: loop_over_weights(a, weights_machete, run)) + res = bench_fns(label, sub_label, "machete_best", [ + machete_create_bench_fn( + bt, out_type=types.output_type, schedule=schedule) + for bt in benchmark_tensors + ]) results_row = { "M": m, @@ -213,25 +411,33 @@ def run(a, _, w_q, w_s, schedule=schedule): # runner -def print_timers(timers: Iterable[TMeasurement]): +def print_timers(timers: List[TMeasurement]): compare = TBenchmark.Compare(timers) compare.print() -def run(dtype: torch.dtype, sweep_schedules: bool, - MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: +def run(args, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + types = TypeConfig( + act_type=args.act_type, + weight_type=scalar_types.uint4b8 if args.group_zero_type is None \ + else scalar_types.uint4, + output_type=args.out_type, + group_scale_type=args.group_scale_type, + group_zero_type=args.group_zero_type, + channel_scale_type=args.channel_scale_type, + token_scale_type=args.token_scale_type, + ) - results = [] + results: List[TMeasurement] = [] for m, k, n in MKNs: - timers = bench(dtype, - scalar_types.uint4b8, - 128, + timers = bench(types, + args.group_size, m, k, n, - f"{dtype}-gemm", + f"{args.act_type}-gemm", f"MKN=({m}x{k}x{n})", - sweep_schedules=sweep_schedules) + sweep_schedules=args.sweep_schedules) print_timers(timers) results.extend(timers) @@ -240,7 +446,7 @@ def run(dtype: torch.dtype, sweep_schedules: bool, # output makers def make_output( - data: Iterable[TMeasurement], + data: List[TMeasurement], MKNs: Iterable[Tuple[int, int, int]], base_description: str, timestamp=None, @@ -262,7 +468,6 @@ def run_square_bench(args): dim_sizes = list( range(args.dim_start, args.dim_end + 1, args.dim_increment)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) - data = run(args.dtype, args.sweep_schedules, MKNs) make_output(data, MKNs, f"square_bench-{args.dtype}") @@ -306,33 +511,49 @@ def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: for k, n in KNs: MKNs.append((m, k, n)) - data = run(args.dtype, args.sweep_schedules, MKNs) + data = run(args, MKNs) model_bench_data.append(data) + type_string = f"{args.act_type}" + # Print all results for data, model_tp in zip(model_bench_data, models_tps): model, tp_size = model_tp - print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print(f"== Results {type_string} {model}-TP{tp_size} ====") print_timers(data) - timestamp = int(time.time()) + timestr = time.strftime("%Y%m%d-%H%M%S") - all_data = [] + all_results = [] for d in model_bench_data: - all_data.extend(d) + all_results.extend(d) + # pickle all data - with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: - pkl.dump(all_data, f) + with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f: + args_dict = vars(args) + args_dict.pop("func") + pkl.dump({ + "args": args_dict, + "results": all_results, + }, f) if __name__ == "__main__": def to_torch_dtype(dt): - if dt == "bfloat16": - return torch.bfloat16 - if dt == "float16": - return torch.float16 - raise ValueError("unsupported dtype") + return { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "int8": torch.int8, + "float8_e4m3fn": torch.float8_e4m3fn, + "int": torch.int, + "float": torch.float, + }[dt] + + class ToTorchDtype(argparse.Action): + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, to_torch_dtype(values)) parser = FlexibleArgumentParser( description=""" @@ -352,12 +573,42 @@ def to_torch_dtype(dt): """, # noqa: E501 formatter_class=argparse.RawTextHelpFormatter, ) - parser.add_argument( - "--dtype", - type=to_torch_dtype, + "--act-type", + action=ToTorchDtype, required=True, - help="Available options are ['bfloat16', 'float16']", + choices=['bfloat16', 'float16', 'int8', 'float8_e4m3fn'], + ) + parser.add_argument( + "--group-scale-type", + action=ToTorchDtype, + choices=['bfloat16', 'float16'], + ) + parser.add_argument( + "--group-zero-type", + type=to_torch_dtype, + choices=['bfloat16', 'float16'], + ) + parser.add_argument( + "--channel-scale-type", + action=ToTorchDtype, + choices=['float'], + ) + parser.add_argument( + "--token-scale-type", + action=ToTorchDtype, + choices=['float'], + ) + parser.add_argument( + "--out-type", + action=ToTorchDtype, + choices=['bfloat16', 'float16'], + ) + parser.add_argument( + "--group-size", + type=int, + help="Available options are ['None', '-1', '128'], default=128", + default=128, ) parser.add_argument( "--sweep-schedules", diff --git a/benchmarks/kernels/graph_machete_bench.py b/benchmarks/kernels/graph_machete_bench.py index de608fd05af70..7d0bd84150a27 100644 --- a/benchmarks/kernels/graph_machete_bench.py +++ b/benchmarks/kernels/graph_machete_bench.py @@ -20,10 +20,11 @@ args = parser.parse_args() with open(args.filename, 'rb') as f: - data: List[TMeasurement] = pickle.load(f) + data = pickle.load(f) + raw_results: List[TMeasurement] = data["results"] results = defaultdict(lambda: list()) - for v in data: + for v in raw_results: result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label) if result is not None: KN = result.group(1) diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py index 25ec9d6028627..51f24f3ba1774 100644 --- a/benchmarks/kernels/weight_shapes.py +++ b/benchmarks/kernels/weight_shapes.py @@ -40,4 +40,10 @@ ([8192, 57344], 1), ([28672, 8192], 0), ], + "meta-llama/Llama-3.1-405b-hf": [ + ([16384, 18432], 1), + ([16384, 16384], 0), + ([16384, 106496], 1), + ([53248, 16384], 0), + ], } diff --git a/csrc/cutlass_extensions/cute_utils.cuh b/csrc/cutlass_extensions/cute_utils.cuh index 1842fab8b2cac..f61fe3ceb978a 100644 --- a/csrc/cutlass_extensions/cute_utils.cuh +++ b/csrc/cutlass_extensions/cute_utils.cuh @@ -20,9 +20,9 @@ CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) { // is the layout f(x) = x template CUTE_HOST_DEVICE static constexpr bool is_identity_layout() { - if constexpr (std::is_same_v) + if constexpr (std::is_same_v) { return true; - else { + } else { constexpr auto coalesced_layout = coalesce(Layout{}); if constexpr (rank(coalesced_layout) == 1 && stride<0>(coalesced_layout) == 1) { diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp similarity index 99% rename from csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp rename to csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp index d407d66ab2aa6..7aa87feb4cce2 100644 --- a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp +++ b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp @@ -52,6 +52,7 @@ // clang-format off #include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp" +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" #include "cute/tensor.hpp" namespace cutlass::epilogue::threadblock { diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp b/csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp similarity index 100% rename from csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c3x.hpp rename to csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp new file mode 100644 index 0000000000000..c69e87999ae71 --- /dev/null +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp @@ -0,0 +1,317 @@ +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp" + +/* + This file defines custom epilogues for fusing channel scales, token scales, + bias, and activation zero-points onto a GEMM operation using the + CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs. + + Epilogues must contain a public type named EVTCompute of type Sm80EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. +*/ + +namespace vllm::c2x { + +using namespace cute; + +/* + * This class provides the common load descriptors for the + * ScaledEpilogue[...] classes + */ +template +struct ScaledEpilogueBase { + protected: + using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + + template + using ColOrScalarLoad = + cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< + OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoad = + cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + template + using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast< + OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; + + template + using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + template + using RowOrZeroLoad = + cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast< + OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; + + // This utility function constructs the arguments for the load descriptors + // from a tensor. It can handle both row and column, as well as row/column or + // scalar cases. + template + static auto args_from_tensor(torch::Tensor const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = static_cast(tensor.data_ptr()); + if constexpr (std::is_same_v> || + std::is_same_v>) { + return Arguments{data_ptr, tensor.numel() != 1}; + } else { + // it would technically work but no use case as data_ptr is never nullptr + static_assert(!std::is_same_v>); + return Arguments{data_ptr}; + } + } + + // This overload handles the case where there might not be a tensor, in which + // case a nullptr is passed and a constant (0) is used. + template + static auto args_from_tensor(c10::optional const& tensor) { + static_assert(std::is_same_v>); + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; + return Arguments{data_ptr}; + } +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + torch._scaled_mm. + + A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or + per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. + * This bias can also be used in the per-tensor azp case, where the activation + * zero point (azp) is used to compute an azp correction term, + * which is folded into the bias. + * + * The bias tensor must be per-output channel. + * ScaleA and ScaleB can be per-tensor or per-token/per-channel. + */ +template +struct ScaledEpilogueBias + : protected ScaledEpilogueBase { + protected: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::threadblock::Sm80EVT; + + using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args, bias_args}; + } +}; + +/* + * This epilogue directly supports per-tensor azp in int32 form. + * As opposed to the per-token epilogue below, this epilogue only has an azp_adj + * term, which should already be multiplied with the scalar azp. + * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzp + : protected ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowOrZeroLoad; + + // This is the full AZP term, azp * J @ B, shape (1,n) + using AzpWithAdj = typename SUPER::template RowLoad; + + // Compute float(accum - azp_adj), both operands are int32_t + using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +/* + * This epilogue supports per-token azp by computing and applying + * the correction term using a rank-1 update. If the term were materialized, + * it would require O(m*n) space, and this way it only requires O(m+n) space. + * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero + * point for each row of A. + * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzpToken + : protected ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowOrZeroLoad; + + // Per-token azp term, shape (m,1) + using Azp = typename SUPER::template ColLoad; + + // This is the AZP adjustment term, J @ B, shape (1,n) + using AzpAdj = typename SUPER::template RowLoad; + + // Compute azp * azp_adj + using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, int32_t, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::threadblock::Sm80EVT; + + // Compute float(accum - azp*azp_adj), all operands are int32_t + using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAcc = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::threadblock::Sm80EVT; + + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + torch::Tensor const& azp, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_args = SUPER::template args_from_tensor(azp); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +}; // namespace vllm::c2x \ No newline at end of file diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp new file mode 100644 index 0000000000000..95764ecddc79f --- /dev/null +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -0,0 +1,315 @@ +#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" + +/* + This file defines custom epilogues for fusing channel scales, token scales, + bias, and activation zero-points onto a GEMM operation using the + CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper) or later. + + Epilogues must contain a public type named EVTCompute of type Sm90EVT, + as well as a static prepare_args function that constructs an + EVTCompute::Arguments struct. +*/ + +namespace vllm::c3x { + +using namespace cute; + +/* + * This class provides the common load descriptors for the + * ScaledEpilogue[...] classes + */ +template +struct ScaledEpilogueBase { + protected: + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + template + using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<0>, Int<0>>>; + + template + using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<1>, Int<0>>>; + + // Don't want to support nullptr by default + template + using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + + // Don't want to support nullptr by default + template + using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< + 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, + Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + + // This utility function constructs the arguments for the load descriptors + // from a tensor. It can handle both row and column, as well as row/column or + // scalar cases. + template + static auto args_from_tensor(torch::Tensor const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = static_cast(tensor.data_ptr()); + if constexpr (std::is_same_v> || + std::is_same_v>) { + return Arguments{data_ptr, tensor.numel() != 1}; + } else { + static_assert(!std::is_same_v> && + !std::is_same_v>); + return Arguments{data_ptr}; + } + } + + // This overload handles the case where there might not be a tensor, in which + // case a nullptr is passed and a constant (0) is used. + template + static auto args_from_tensor(c10::optional const& tensor) { + using Arguments = typename Descriptor::Arguments; + auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; + static_assert(std::is_same_v> || + std::is_same_v>); + return Arguments{data_ptr}; + } +}; + +/* + This epilogue function defines a quantized GEMM operation similar to + torch.scaled_mm_. + + A and B may be both either int8 or fp8_e4m3. A can be + quantized per-tensor or per-row. B can be quantized per-tensor or per-column. + Any combination of per-tensor and per-row or column is supported. + A and B must have symmetric quantization (zero point == 0). + + So the GEMM operation is D = (a_scales * A) (b_scales * B), where the + scales are applied elementwise with numpy-style broadcasting. + + ScaleA and ScaleB define the epilogue functions that apply the scales for + the A and B operands respectively. These scales may be either per-tensor or + per row or column. +*/ +template +struct ScaledEpilogue + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args}; + } +}; + +/* + * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. + * This bias can also be used in the per-tensor azp case, where the activation + * zero point (azp) is used to compute an azp correction term, + * which is folded into the bias. + * + * The bias tensor must be per-output channel. + * ScaleA and ScaleB can be per-tensor or per-token/per-channel. + */ +template +struct ScaledEpilogueBias + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTCompute0 = + cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + + using ArgumentType = typename EVTCompute::Arguments; + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + + typename EVTCompute0::Arguments evt0_args{b_args}; + return ArgumentType{a_args, evt0_args, bias_args}; + } +}; + +/* + * This epilogue directly supports per-tensor azp in int32 form. + * As opposed to the per-token epilogue below, this epilogue only has an azp_adj + * term, which should already be multiplied with the scalar azp. + * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzp + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // This is the full AZP term, azp * J @ B, shape (1,n) + using AzpWithAdj = typename SUPER::template RowLoad; + + // Compute float(accum - azp_adj), both operands are int32_t + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +/* + * This epilogue supports per-token azp by computing and applying + * the correction term using a rank-1 update. If the term were materialized, + * it would require O(m*n) space, and this way it only requires O(m+n) space. + * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero + * point for each row of A. + * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. + * + * This epilogue also supports bias, which remains per-channel. + */ +template +struct ScaledEpilogueBiasAzpToken + : private ScaledEpilogueBase { + private: + using SUPER = ScaledEpilogueBase; + using Accum = typename SUPER::Accum; + using ScaleA = typename SUPER::template ColOrScalarLoad; + using ScaleB = typename SUPER::template RowOrScalarLoad; + using Bias = typename SUPER::template RowLoad; + + // Per-token azp term, shape (m,1) + using Azp = typename SUPER::template ColLoad; + + // This is the AZP adjustment term, J @ B, shape (1,n) + using AzpAdj = typename SUPER::template RowLoad; + + // Compute azp * azp_adj + using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, int32_t, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAzp = + cutlass::epilogue::fusion::Sm90EVT; + + // Compute float(accum - azp*azp_adj), all operands are int32_t + using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute< + cutlass::minus, float, int32_t, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeAcc = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiplies, float, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + using EVTComputeScaleB = + cutlass::epilogue::fusion::Sm90EVT; + + using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< + cutlass::multiply_add, ElementD, float, + cutlass::FloatRoundStyle::round_to_nearest>; + + public: + using EVTCompute = + cutlass::epilogue::fusion::Sm90EVT; + using ArgumentType = typename EVTCompute::Arguments; + + static ArgumentType prepare_args(torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + torch::Tensor const& azp, + c10::optional const& bias) { + auto a_args = SUPER::template args_from_tensor(a_scales); + auto b_args = SUPER::template args_from_tensor(b_scales); + auto bias_args = SUPER::template args_from_tensor(bias); + auto azp_args = SUPER::template args_from_tensor(azp); + auto azp_adj_args = + SUPER::template args_from_tensor(azp_adj); + + typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; + typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; + typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; + return ArgumentType{a_args, evt_scale_b_args, bias_args}; + } +}; + +}; // namespace vllm::c3x \ No newline at end of file diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py index 4fcfcd311aa91..a5beea1a35e49 100644 --- a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py +++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -35,6 +35,35 @@ class MixedInputKernelScheduleType(enum.Enum): } } +VLLMDataTypeSize: Dict[Union[VLLMDataType, DataType], int] = { + **DataTypeSize, # type: ignore + **{ + VLLMDataType.u4b8: 4, + VLLMDataType.u8b128: 8, + } +} + +VLLMDataTypeVLLMScalarTypeTag: Dict[Union[VLLMDataType, DataType], str] = { + VLLMDataType.u4b8: "vllm::kU4B8", + VLLMDataType.u8b128: "vllm::kU8B128", + DataType.u4: "vllm::kU4", + DataType.u8: "vllm::kU8", + DataType.s4: "vllm::kS4", + DataType.s8: "vllm::kS8", + DataType.f16: "vllm::kFloat16", + DataType.bf16: "vllm::kBfloat16", +} + +VLLMDataTypeTorchDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = { + DataType.u8: "at::ScalarType::Byte", + DataType.s8: "at::ScalarType::Char", + DataType.e4m3: "at::ScalarType::Float8_e4m3fn", + DataType.s32: "at::ScalarType::Int", + DataType.f16: "at::ScalarType::Half", + DataType.bf16: "at::ScalarType::BFloat16", + DataType.f32: "at::ScalarType::Float", +} + VLLMKernelScheduleTag: Dict[Union[ MixedInputKernelScheduleType, KernelScheduleType], str] = { **KernelScheduleTag, # type: ignore diff --git a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh index 2ad914f8e9868..90f226cf64c0a 100644 --- a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh +++ b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh @@ -3,6 +3,7 @@ #include "cutlass/numeric_conversion.h" #include "cutlass_extensions/vllm_custom_types.cuh" #include "cutlass_extensions/cute_utils.cuh" +#include "cutlass_extensions/vllm_type_utils.cuh" // this file extends: // https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h @@ -28,8 +29,19 @@ struct InterleavedNumericArrayConverter { CUTLASS_DEVICE static result_type convert(source_type const& source) { - CUTE_INVALID_CONTROL_PATH( - "InterleavedNumericArrayConverter not implemented\n"); + if (cute::elect_one_sync()) { + if constexpr (std::is_same_v) { + printf( + "Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n", + nameof_v, nameof_v, N); + } else { + printf( + "Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not " + "implemented\n", + nameof_v, nameof_v, N, size(IlvBlkLayout{})); + } + __brkpt(); + } return {}; } @@ -56,11 +68,6 @@ struct InterleavedNumericArrayConverter< result_type operator()(source_type const& s) const { return convert(s); } }; -// TODO (LucasWilkinson): Implement -// for Array <= Array - -// .... - template struct ArrayConverterPacked32Bit { using result_type = Array; @@ -86,14 +93,16 @@ struct ArrayConverterPacked32Bit { using ScalarConverter = NumericConverter; template - CUTLASS_DEVICE static uint32_t to_reg(PackedSrc const& source) { + CUTLASS_DEVICE static auto to_regs(PackedSrc const& src) { if constexpr (sizeof(PackedSrc) == 1) { - return static_cast(reinterpret_cast(source)); + return Array{reinterpret_cast(src)}; } else if constexpr (sizeof(PackedSrc) == 2) { - return static_cast(reinterpret_cast(source)); + return Array{reinterpret_cast(src)}; + } else if constexpr (sizeof(PackedSrc) == 4) { + return Array{reinterpret_cast(src)}; } else { - static_assert(sizeof(PackedSrc) == 4); - return reinterpret_cast(source); + static_assert(sizeof(PackedSrc) == 8); + return reinterpret_cast const&>(src); } } @@ -110,7 +119,7 @@ struct ArrayConverterPacked32Bit { static_assert(std::is_same_v); static_assert(std::is_same_v); - return RegConvert32bit::template convert(to_reg(source)); + return RegConvert32bit::template convert(to_regs(source)); } friend class detail::VectorizedConverter; @@ -140,6 +149,131 @@ struct ArrayConverterPacked32Bit { } }; +// Convert 8 4bit values packed into a 32bit register to 8 8bit values packed +// into 2 32bit register. +template +CUTLASS_DEVICE cutlass::AlignedArray lut_4bit_to_8bit_convert( + uint32_t src) { + cutlass::AlignedArray r; + // Determines if the value is in the top half of the LUT if set or + // (i.e. LUT[8:15]) in the bottom half (i.e. LUT[0:7]) if not set. Then move + // into bit position 0x4 of each nibble so when or'd with final_prmt_base it + // selects the correct candidate. When elements in final_prmt_base + // are >= 0x4, the high candidate is selected (i.e. LUT[8:15]), when elements + // are < 0x4, the low candidate is selected (i.e. LUT[0:7]) + uint32_t high_bit = (src & 0x88888888) >> 1; + + // `high_bit` is OR'd with 0x31203120 to find the correct value in the LUT + // (selects correct high or low candidate) + const uint32_t final_prmt_base = 0x32103210; + + // Ignore the high bit when indexing into LUT, for each 4bit value + // we index into both the high and low candidates then use + // high_bit | final_prmt_base to select the correct candidate + uint32_t lut_idx = (src & 0x77777777); + + auto pack = [](uint8_t a, uint8_t b, uint8_t c, uint8_t d) { + return uint32_t(a) | (uint32_t(b) << 8) | (uint32_t(c) << 16) | + (uint32_t(d) << 24); + }; + + static constexpr uint32_t LOW_0 = pack(LUT0, LUT1, LUT2, LUT3); + static constexpr uint32_t LOW_1 = pack(LUT4, LUT5, LUT6, LUT7); + static constexpr uint32_t HIGH_0 = pack(LUT8, LUT9, LUT10, LUT11); + static constexpr uint32_t HIGH_1 = pack(LUT12, LUT13, LUT14, LUT15); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < 2; ++ii, lut_idx >>= 16, high_bit >>= 16) { + uint32_t final_prmt_idx = final_prmt_base | high_bit; + + // This uses a look up table to convert packed int4s to packed int8s, + // using the int4 value as the index to prmt. It first select both the + // high and low candidates, then uses the high bit (i.e. `high_bit`) to + // select the correct candidate. + asm volatile( + "{\n" + " .reg .b32 low, high;\n" + " prmt.b32 low, %1, %2, %5;\n" + " prmt.b32 high, %3, %4, %5;\n" + " prmt.b32 %0, low, high, %6;\n" + "}\n" + : "=r"(r[ii]) + : "n"(LOW_0), "n"(LOW_1), "n"(HIGH_0), "n"(HIGH_1), "r"(lut_idx), + "r"(final_prmt_idx)); + } + + return r; +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + // [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as int8s + auto r = lut_4bit_to_8bit_convert<0xF8, 0xF9, 0xFA, 0xFB, // + 0xFC, 0xFD, 0xFE, 0xFF, // + 0x00, 0x01, 0x02, 0x03, // + 0x04, 0x05, 0x06, 0x07>(src_[0]); + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + // [-8, -7, -6, -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5, 6, 7] as fp8s + auto r = lut_4bit_to_8bit_convert<0xD0, 0xCE, 0xCC, 0xCA, // + 0xC8, 0xC4, 0xC0, 0xB8, // + 0x00, 0x38, 0x40, 0x44, // + 0x48, 0x4A, 0x4C, 0x4E>(src_[0]); + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + // for Array <= Array template struct NumericArrayConverter { @@ -148,7 +282,8 @@ struct NumericArrayConverter { struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; using RegArray = cutlass::AlignedArray; @@ -249,7 +384,8 @@ struct InterleavedNumericArrayConverter, Stride<_4, _1>>, private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; using RegArray = cutlass::AlignedArray; @@ -338,7 +474,8 @@ struct InterleavedNumericArrayConverter, Stride<_4, _1>>, private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; using RegArray = cutlass::AlignedArray; @@ -417,7 +554,8 @@ struct NumericArrayConverter { struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; // Hold output FP16s in reg. We need 1 reg for every 2 elements using RegArray = cutlass::AlignedArray { private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; PackedResultType r; // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of @@ -513,7 +652,8 @@ struct NumericArrayConverter { private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src_reg) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src_reg = src_[0]; // Hold output BF16s in reg. We need 1 reg for every 2 elements using RegArray = cutlass::AlignedArray, Stride<_4, _1>>, private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; using RegArray = cutlass::AlignedArray; @@ -671,7 +812,8 @@ struct InterleavedNumericArrayConverter, Stride<_4, _1>>, private: struct RegConvert { template - CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + CUTLASS_DEVICE static PackedResultType convert(Array src_) { + uint32_t src = src_[0]; using RegArray = cutlass::AlignedArray; @@ -788,6 +930,61 @@ struct NumericArrayConverter { #endif +// for Array <= Array +// FastFP16toINT8 from https://arxiv.org/pdf/2406.09904 +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + // FastFP16toINT8 from https://arxiv.org/pdf/2406.09904 + template + CUTLASS_DEVICE static PackedResultType convert( + Array src) { + // Hold output int8s in reg. We need 1 reg for every 4 elements + using RegArray = cutlass::AlignedArray< + uint32_t, std::max(PackedResultType::kElements / 4, size_t(1))>; + RegArray r; + + static constexpr uint32_t MAGIC_BIAS_ = 0x64806480; + auto MAGIC_BIAS = *reinterpret_cast(&MAGIC_BIAS_); + + *reinterpret_cast(&src[0]) = + __hadd2(*reinterpret_cast(&src[0]), MAGIC_BIAS); + + if constexpr (src_regs > 1) { + *reinterpret_cast(&src[1]) = + __hadd2(*reinterpret_cast(&src[1]), MAGIC_BIAS); + } + + static_assert(PackedResultType::kElements <= 4); + uint32_t uint8s; + static constexpr uint32_t MASK_0246 = 0x6420; + static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(uint8s) + : "r"(src[0]), "r"((src_regs > 1) ? src[1] : src[0]), + "n"(MASK_0246)); + + uint32_t int8s = (uint8s ^ UINT8s_TO_INT8s_MASK); + + return reinterpret_cast(int8s); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/csrc/cutlass_extensions/vllm_type_utils.cuh b/csrc/cutlass_extensions/vllm_type_utils.cuh new file mode 100644 index 0000000000000..500ed508c8303 --- /dev/null +++ b/csrc/cutlass_extensions/vllm_type_utils.cuh @@ -0,0 +1,42 @@ +#include "cutlass/bfloat16.h" +#include "cutlass/half.h" +#include "cuda_bf16.h" + +#include "cutlass_extensions/vllm_custom_types.cuh" + +namespace cutlass { + +template +struct nameof { + static constexpr char const* value = "unknown"; +}; + +template +inline constexpr auto nameof_v = nameof::value; + +#define NAMEOF_TYPE(T) \ + template <> \ + struct nameof { \ + static constexpr char const* value = #T; \ + }; + +NAMEOF_TYPE(float_e4m3_t) +NAMEOF_TYPE(float_e5m2_t) +NAMEOF_TYPE(half_t) +NAMEOF_TYPE(nv_bfloat16) +NAMEOF_TYPE(bfloat16_t) +NAMEOF_TYPE(float) + +NAMEOF_TYPE(int4b_t) +NAMEOF_TYPE(int8_t) +NAMEOF_TYPE(int32_t) +NAMEOF_TYPE(int64_t) + +NAMEOF_TYPE(vllm_uint4b8_t) +NAMEOF_TYPE(uint4b_t) +NAMEOF_TYPE(uint8_t) +NAMEOF_TYPE(vllm_uint8b128_t) +NAMEOF_TYPE(uint32_t) +NAMEOF_TYPE(uint64_t) + +}; // namespace cutlass \ No newline at end of file diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu index ee801e16573d4..dbb72e8bbd3f5 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -8,6 +8,10 @@ #include "scaled_mm_c2x_sm89_fp8_dispatch.cuh" #include "scaled_mm_c2x_sm89_int8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp" + +using namespace vllm; + /* This file defines quantized GEMM operations using the CUTLASS 2.x API, for NVIDIA GPUs with SM versions prior to sm90 (Hopper). @@ -22,12 +26,11 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kInt8); if (out.dtype() == torch::kBFloat16) { - return vllm::cutlass_gemm_sm75_dispatch( + return cutlass_gemm_sm75_dispatch( out, a, b, std::forward(epilogue_args)...); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - return vllm::cutlass_gemm_sm75_dispatch( + return cutlass_gemm_sm75_dispatch( out, a, b, std::forward(epilogue_args)...); } } @@ -42,10 +45,10 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == out.dtype(), "currently bias dtype must match output dtype ", out.dtype()); - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales); } } @@ -61,10 +64,10 @@ void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm75_epilogue( + return cutlass_scaled_mm_sm75_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } @@ -78,12 +81,11 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kInt8); if (out.dtype() == torch::kBFloat16) { - return vllm::cutlass_gemm_sm80_dispatch( + return cutlass_gemm_sm80_dispatch( out, a, b, std::forward(epilogue_args)...); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - return vllm::cutlass_gemm_sm80_dispatch( + return cutlass_gemm_sm80_dispatch( out, a, b, std::forward(epilogue_args)...); } } @@ -98,10 +100,10 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == out.dtype(), "currently bias dtype must match output dtype ", out.dtype()); - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales); } } @@ -117,10 +119,10 @@ void cutlass_scaled_mm_azp_sm80(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm80_epilogue( + return cutlass_scaled_mm_sm80_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } @@ -134,13 +136,12 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kInt8); if (out.dtype() == torch::kBFloat16) { - return vllm::cutlass_gemm_sm89_int8_dispatch( + return cutlass_gemm_sm89_int8_dispatch( out, a, b, std::forward(epilogue_args)...); } else { assert(out.dtype() == torch::kFloat16); - return vllm::cutlass_gemm_sm89_int8_dispatch( + return cutlass_gemm_sm89_int8_dispatch( out, a, b, std::forward(epilogue_args)...); } } else { @@ -148,13 +149,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); if (out.dtype() == torch::kBFloat16) { - return vllm::cutlass_gemm_sm89_fp8_dispatch< - cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>( + return cutlass_gemm_sm89_fp8_dispatch( out, a, b, std::forward(epilogue_args)...); } else { TORCH_CHECK(out.dtype() == torch::kFloat16); - return vllm::cutlass_gemm_sm89_fp8_dispatch( + return cutlass_gemm_sm89_fp8_dispatch( out, a, b, std::forward(epilogue_args)...); } } @@ -170,10 +171,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == out.dtype(), "currently bias dtype must match output dtype ", out.dtype()); - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales); } } @@ -189,10 +190,10 @@ void cutlass_scaled_mm_azp_sm89(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm89_epilogue( + return cutlass_scaled_mm_sm89_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh index 6329ff63623e2..d03242f44ab1d 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh @@ -21,7 +21,6 @@ #include "cutlass/epilogue/threadblock/fusion/visitors.hpp" #include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" -#include "broadcast_load_epilogue_c2x.hpp" #include "common.hpp" // clang-format on @@ -71,307 +70,6 @@ struct enable_sm89_to_sm90 : Kernel { #endif } }; - -/* - * This class provides the common load descriptors for the - * ScaledEpilogue[...] classes - */ -template -struct ScaledEpilogueBase { - protected: - using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; - - template - using ColOrScalarLoad = - cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast< - OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; - - template - using RowOrScalarLoad = - cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast< - OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; - - template - using ColLoad = cutlass::epilogue::threadblock::VisitorColBroadcast< - OutputTileThreadMap, T, Stride, Int<0>, Int<0>>>; - - template - using RowLoad = cutlass::epilogue::threadblock::VisitorRowBroadcast< - OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; - - template - using RowOrZeroLoad = - cutlass::epilogue::threadblock::VisitorRowOrZeroBroadcast< - OutputTileThreadMap, T, Stride, Int<1>, Int<0>>>; - - // This utility function constructs the arguments for the load descriptors - // from a tensor. It can handle both row and column, as well as row/column or - // scalar cases. - template - static auto args_from_tensor(torch::Tensor const& tensor) { - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = static_cast(tensor.data_ptr()); - if constexpr (std::is_same_v> || - std::is_same_v>) { - return Arguments{data_ptr, tensor.numel() != 1}; - } else { - // it would technically work but no use case as data_ptr is never nullptr - static_assert(!std::is_same_v>); - return Arguments{data_ptr}; - } - } - - // This overload handles the case where there might not be a tensor, in which - // case a nullptr is passed and a constant (0) is used. - template - static auto args_from_tensor(c10::optional const& tensor) { - static_assert(std::is_same_v>); - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; - return Arguments{data_ptr}; - } -}; - -/* - This epilogue function defines a quantized GEMM operation similar to - torch._scaled_mm. - - A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or - per-row. B can be quantized per-tensor or per-column. - Any combination of per-tensor and per-row or column is supported. - A and B must have symmetric quantization (zero point == 0). - - So the GEMM operation is D = (a_scales * A) (b_scales * B), where the - scales are applied elementwise with numpy-style broadcasting. - - ScaleA and ScaleB define the epilogue functions that apply the scales for - the A and B operands respectively. These scales may be either per-tensor or - per row or column. -*/ -template -struct ScaledEpilogue - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - - using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::threadblock::Sm80EVT; - - using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args}; - } -}; - -/* - * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. - * This bias can also be used in the per-tensor azp case, where the activation - * zero point (azp) is used to compute an azp correction term, - * which is folded into the bias. - * - * The bias tensor must be per-output channel. - * ScaleA and ScaleB can be per-tensor or per-token/per-channel. - */ -template -struct ScaledEpilogueBias - : protected ScaledEpilogueBase { - protected: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::threadblock::Sm80EVT; - - using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; - using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args, bias_args}; - } -}; - -/* - * This epilogue directly supports per-tensor azp in int32 form. - * As opposed to the per-token epilogue below, this epilogue only has an azp_adj - * term, which should already be multiplied with the scalar azp. - * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. - * - * This epilogue also supports bias, which remains per-channel. - */ -template -struct ScaledEpilogueBiasAzp - : protected ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowOrZeroLoad; - - // This is the full AZP term, azp * J @ B, shape (1,n) - using AzpWithAdj = typename SUPER::template RowLoad; - - // Compute float(accum - azp_adj), both operands are int32_t - using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::threadblock::Sm80EVT; - - using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::threadblock::Sm80EVT; - - using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; - - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - -/* - * This epilogue supports per-token azp by computing and applying - * the correction term using a rank-1 update. If the term were materialized, - * it would require O(m*n) space, and this way it only requires O(m+n) space. - * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero - * point for each row of A. - * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. - * - * This epilogue also supports bias, which remains per-channel. - */ -template -struct ScaledEpilogueBiasAzpToken - : protected ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowOrZeroLoad; - - // Per-token azp term, shape (m,1) - using Azp = typename SUPER::template ColLoad; - - // This is the AZP adjustment term, J @ B, shape (1,n) - using AzpAdj = typename SUPER::template RowLoad; - - // Compute azp * azp_adj - using ComputeAzp = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, int32_t, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::threadblock::Sm80EVT; - - // Compute float(accum - azp*azp_adj), all operands are int32_t - using ComputeAcc = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAcc = - cutlass::epilogue::threadblock::Sm80EVT; - - using ComputeScaleB = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::threadblock::Sm80EVT; - - using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::threadblock::Sm80EVT; - - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - torch::Tensor const& azp, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_args = SUPER::template args_from_tensor(azp); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; - typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - template typename ArchGuard, typename ElementAB_, typename ElementD_, template typename Epilogue_, typename TileShape, diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu index 292c9e4b34e1c..33581a63d4c3d 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu @@ -23,11 +23,12 @@ #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" -#include "broadcast_load_epilogue_c3x.hpp" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" #include "common.hpp" // clang-format on using namespace cute; +using namespace vllm; /* This file defines quantized GEMM operations using the CUTLASS 3.x API, for @@ -56,305 +57,6 @@ struct enable_sm90_or_later : Kernel { #endif } }; - -/* - * This class provides the common load descriptors for the - * ScaledEpilogue[...] classes - */ -template -struct ScaledEpilogueBase { - protected: - using Accum = cutlass::epilogue::fusion::Sm90AccFetch; - - template - using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<0>, Int<0>>>; - - template - using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<1>, Int<0>>>; - - // Don't want to support nullptr by default - template - using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; - - // Don't want to support nullptr by default - template - using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; - - // This utility function constructs the arguments for the load descriptors - // from a tensor. It can handle both row and column, as well as row/column or - // scalar cases. - template - static auto args_from_tensor(torch::Tensor const& tensor) { - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = static_cast(tensor.data_ptr()); - if constexpr (std::is_same_v> || - std::is_same_v>) { - return Arguments{data_ptr, tensor.numel() != 1}; - } else { - static_assert(!std::is_same_v> && - !std::is_same_v>); - return Arguments{data_ptr}; - } - } - - // This overload handles the case where there might not be a tensor, in which - // case a nullptr is passed and a constant (0) is used. - template - static auto args_from_tensor(c10::optional const& tensor) { - using Arguments = typename Descriptor::Arguments; - auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; - static_assert(std::is_same_v> || - std::is_same_v>); - return Arguments{data_ptr}; - } -}; - -/* - This epilogue function defines a quantized GEMM operation similar to - torch.scaled_mm_. - - A and B may be both either int8 or fp8_e4m3. A can be - quantized per-tensor or per-row. B can be quantized per-tensor or per-column. - Any combination of per-tensor and per-row or column is supported. - A and B must have symmetric quantization (zero point == 0). - - So the GEMM operation is D = (a_scales * A) (b_scales * B), where the - scales are applied elementwise with numpy-style broadcasting. - - ScaleA and ScaleB define the epilogue functions that apply the scales for - the A and B operands respectively. These scales may be either per-tensor or - per row or column. -*/ -template -struct ScaledEpilogue - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - - using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::fusion::Sm90EVT; - - using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args}; - } -}; - -/* - * This epilogue performs the same operation as ScaledEpilogue, but adds a bias. - * This bias can also be used in the per-tensor azp case, where the activation - * zero point (azp) is used to compute an azp correction term, - * which is folded into the bias. - * - * The bias tensor must be per-output channel. - * ScaleA and ScaleB can be per-tensor or per-token/per-channel. - */ -template -struct ScaledEpilogueBias - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - - using Compute0 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTCompute0 = - cutlass::epilogue::fusion::Sm90EVT; - - using Compute1 = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - - using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - - typename EVTCompute0::Arguments evt0_args{b_args}; - return ArgumentType{a_args, evt0_args, bias_args}; - } -}; - -/* - * This epilogue directly supports per-tensor azp in int32 form. - * As opposed to the per-token epilogue below, this epilogue only has an azp_adj - * term, which should already be multiplied with the scalar azp. - * The azp_adj term is a 1D tensor of shape (1,n), computed as azp * J @ B. - * - * This epilogue also supports bias, which remains per-channel. - */ -template -struct ScaledEpilogueBiasAzp - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - - // This is the full AZP term, azp * J @ B, shape (1,n) - using AzpWithAdj = typename SUPER::template RowLoad; - - // Compute float(accum - azp_adj), both operands are int32_t - using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{{}, azp_adj_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_azp_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - -/* - * This epilogue supports per-token azp by computing and applying - * the correction term using a rank-1 update. If the term were materialized, - * it would require O(m*n) space, and this way it only requires O(m+n) space. - * The azp term is a 1D tensor of shape (m,1), and represents the unscaled zero - * point for each row of A. - * The azp_adj term is a 1D tensor of shape (1,n), computed as J @ B. - * - * This epilogue also supports bias, which remains per-channel. - */ -template -struct ScaledEpilogueBiasAzpToken - : private ScaledEpilogueBase { - private: - using SUPER = ScaledEpilogueBase; - using Accum = typename SUPER::Accum; - using ScaleA = typename SUPER::template ColOrScalarLoad; - using ScaleB = typename SUPER::template RowOrScalarLoad; - using Bias = typename SUPER::template RowLoad; - - // Per-token azp term, shape (m,1) - using Azp = typename SUPER::template ColLoad; - - // This is the AZP adjustment term, J @ B, shape (1,n) - using AzpAdj = typename SUPER::template RowLoad; - - // Compute azp * azp_adj - using ComputeAzp = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, int32_t, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAzp = - cutlass::epilogue::fusion::Sm90EVT; - - // Compute float(accum - azp*azp_adj), all operands are int32_t - using ComputeAcc = cutlass::epilogue::fusion::Sm90Compute< - cutlass::minus, float, int32_t, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeAcc = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleB = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiplies, float, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - using EVTComputeScaleB = - cutlass::epilogue::fusion::Sm90EVT; - - using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute< - cutlass::multiply_add, ElementD, float, - cutlass::FloatRoundStyle::round_to_nearest>; - - public: - using EVTCompute = - cutlass::epilogue::fusion::Sm90EVT; - using ArgumentType = typename EVTCompute::Arguments; - - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - torch::Tensor const& azp, - c10::optional const& bias) { - auto a_args = SUPER::template args_from_tensor(a_scales); - auto b_args = SUPER::template args_from_tensor(b_scales); - auto bias_args = SUPER::template args_from_tensor(bias); - auto azp_args = SUPER::template args_from_tensor(azp); - auto azp_adj_args = - SUPER::template args_from_tensor(azp_adj); - - typename EVTComputeAzp::Arguments evt_azp_args{azp_args, azp_adj_args}; - typename EVTComputeAcc::Arguments evt_acc_args{{}, evt_azp_args}; - typename EVTComputeScaleB::Arguments evt_scale_b_args{b_args, evt_acc_args}; - return ArgumentType{a_args, evt_scale_b_args, bias_args}; - } -}; - template typename Epilogue_, typename TileShape, typename ClusterShape, typename KernelSchedule, @@ -721,11 +423,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, if (bias) { TORCH_CHECK(bias->dtype() == c.dtype(), "currently bias dtype must match output dtype ", c.dtype()); - return cutlass_scaled_mm_sm90_epilogue( + return cutlass_scaled_mm_sm90_epilogue( c, a, b, a_scales, b_scales, *bias); } else { - return cutlass_scaled_mm_sm90_epilogue(c, a, b, a_scales, - b_scales); + return cutlass_scaled_mm_sm90_epilogue( + c, a, b, a_scales, b_scales); } } @@ -740,10 +442,10 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a, TORCH_CHECK(b_scales.dtype() == torch::kFloat32); if (azp) { - return cutlass_scaled_mm_sm90_epilogue( + return cutlass_scaled_mm_sm90_epilogue( out, a, b, a_scales, b_scales, azp_adj, *azp, bias); } else { - return cutlass_scaled_mm_sm90_epilogue( + return cutlass_scaled_mm_sm90_epilogue( out, a, b, a_scales, b_scales, azp_adj, bias); } } diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index d126af1849024..ac63afe79a255 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -3,8 +3,10 @@ import os import shutil from collections.abc import Iterable -from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from copy import deepcopy +from dataclasses import dataclass, fields +from functools import reduce +from typing import Dict, List, Optional, Tuple, Union import jinja2 # yapf conflicts with isort for this block @@ -14,7 +16,10 @@ MixedInputKernelScheduleType, TileSchedulerTag, TileSchedulerType, VLLMDataType, - VLLMDataTypeNames, VLLMDataTypeTag, + VLLMDataTypeNames, + VLLMDataTypeSize, VLLMDataTypeTag, + VLLMDataTypeTorchDataTypeTag, + VLLMDataTypeVLLMScalarTypeTag, VLLMKernelScheduleTag) # yapf: enable @@ -27,49 +32,125 @@ #include "../machete_mm_launcher.cuh" namespace machete { -using GemmDispatcher_ = GemmDispatcher< - {{DataTypeTag[type_config.element_a]}}, // ElementA - {{DataTypeTag[type_config.element_b]}}, // ElementB - {{DataTypeTag[type_config.element_d]}}, // ElementD - {{DataTypeTag[type_config.accumulator]}}, // Accumulator - {{DataTypeTag[type_config.element_b_scale]}}, // Scales - {{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints - -{% for s in schedules %}extern torch::Tensor -impl_{{type_name}}_sch_{{ gen_sch_name(s) }}(PyTorchArguments args); -{% endfor %} -template <> -torch::Tensor GemmDispatcher_::dispatch(PyTorchArguments args) { + +{% for impl_config in impl_configs %} +{% set type_sig = gen_type_sig(impl_config.types) -%} +{% for s in impl_config.schedules %} +extern torch::Tensor impl_{{type_sig}}_sch_{{gen_sch_sig(s)}}(MMArgs); +{%- endfor %} + +torch::Tensor mm_dispatch_{{type_sig}}(MMArgs args) { [[maybe_unused]] auto M = args.A.size(0); [[maybe_unused]] auto N = args.B.size(1); [[maybe_unused]] auto K = args.A.size(1); - if (!args.schedule) { - {%- for cond, s in heuristic %} + if (!args.maybe_schedule) { + {%- for cond, s in impl_config.heuristic %} {%if cond is not none%}if ({{cond}}) {%- else %}else {%- endif %} - return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);{% endfor %} + return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args);{% endfor %} } - {% for s in schedules %} - if (*args.schedule == "{{ gen_sch_name(s) }}") { - return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args); - } - {% endfor %} + {%- for s in impl_config.schedules %} + if (*args.maybe_schedule == "{{ gen_sch_sig(s) }}") + return impl_{{type_sig}}_sch_{{ gen_sch_sig(s) }}(args); + {%- endfor %} TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for " - "schedule = ", *args.schedule); + "schedule = ", *args.maybe_schedule); } +{%- endfor %} + -template <> -std::vector GemmDispatcher_::supported_schedules() { - return { - {% for s in schedules -%} - "{{ gen_sch_name(s) }}"{{ ", - " if not loop.last }}{%- endfor %} - }; +static inline std::optional maybe_scalartype( + c10::optional const& t) { + if (!t) { + return std::nullopt; + } else { + return t->scalar_type(); + }; +} + +torch::Tensor mm_dispatch(MMArgs args) { + auto out_type = args.maybe_out_type.value_or(args.A.scalar_type()); + auto a_type = args.A.scalar_type(); + auto maybe_g_scales_type = maybe_scalartype(args.maybe_group_scales); + auto maybe_g_zeros_type = maybe_scalartype(args.maybe_group_zeros); + auto maybe_ch_scales_type = maybe_scalartype(args.maybe_channel_scales); + auto maybe_tok_scales_type = maybe_scalartype(args.maybe_token_scales); + + {% for impl_config in impl_configs %} + {% set t = impl_config.types -%} + {% set type_sig = gen_type_sig(t) -%} + if (args.b_type == {{VLLMScalarTypeTag[t.b]}} + && a_type == {{TorchTypeTag[t.a]}} + && out_type == {{TorchTypeTag[t.out]}} + && {%if t.b_group_scale != void -%} + maybe_g_scales_type == {{TorchTypeTag[t.b_group_scale]}} + {%- else %}!maybe_g_scales_type{%endif%} + && {%if t.b_group_zeropoint != void -%} + maybe_g_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}} + {%- else %}!maybe_g_zeros_type{%endif%} + && {%if t.b_channel_scale != void -%} + maybe_ch_scales_type == {{TorchTypeTag[t.b_channel_scale]}} + {%- else %}!maybe_ch_scales_type{%endif%} + && {%if t.a_token_scale != void -%} + maybe_tok_scales_type == {{TorchTypeTag[t.a_token_scale]}} + {%- else %}!maybe_tok_scales_type{%endif%} + ) { + return mm_dispatch_{{type_sig}}(args); + } + {%- endfor %} + + TORCH_CHECK_NOT_IMPLEMENTED( + false, "machete_mm(..) is not implemented for " + "a_type=", args.A.scalar_type(), + ", b_type=", args.b_type.str(), + ", out_type=", out_type, + ", with_group_scale_type=", maybe_g_scales_type + ? toString(*maybe_g_scales_type) : "None", + ", with_group_zeropoint_type=", maybe_g_zeros_type + ? toString(*maybe_g_zeros_type) : "None", + ", with_channel_scale_type=", maybe_ch_scales_type + ? toString(*maybe_ch_scales_type) : "None", + ", with_token_scale_type=", maybe_tok_scales_type + ? toString(*maybe_tok_scales_type) : "None", + "; implemented types are: \\n", + {%- for impl_config in impl_configs %} + {% set t = impl_config.types -%} + "\\t{{gen_type_option_name(t)}}\\n", + {%- endfor %} + ""); } +std::vector supported_schedules_dispatch( + SupportedSchedulesArgs args) { + auto out_type = args.maybe_out_type.value_or(args.a_type); + + {% for impl_config in impl_configs %} + {% set t = impl_config.types -%} + {% set schs = impl_config.schedules -%} + if (args.b_type == {{VLLMScalarTypeTag[t.b]}} + && args.a_type == {{TorchTypeTag[t.a]}} + && out_type == {{TorchTypeTag[t.out]}} + && {%if t.b_group_scale != void -%} + args.maybe_group_scales_type == {{TorchTypeTag[t.b_group_scale]}} + {%- else %}!args.maybe_group_scales_type{%endif%} + && {%if t.b_group_zeropoint != void-%} + args.maybe_group_zeros_type == {{TorchTypeTag[t.b_group_zeropoint]}} + {%- else %}!args.maybe_group_zeros_type{%endif%} + ) { + return { + {%- for s in impl_config.schedules %} + "{{gen_sch_sig(s)}}"{% if not loop.last %},{% endif %} + {%- endfor %} + }; + } + {%- endfor %} + + return {}; +}; + }; // namespace machete """ @@ -77,20 +158,10 @@ #include "../machete_mm_launcher.cuh" namespace machete { -template -using Kernel = MacheteKernelTemplate< - {{DataTypeTag[type_config.element_a]}}, // ElementA - {{DataTypeTag[type_config.element_b]}}, // ElementB - {{DataTypeTag[type_config.element_d]}}, // ElementD - {{DataTypeTag[type_config.accumulator]}}, // Accumulator - {{DataTypeTag[type_config.element_b_scale]}}, // Scales - {{DataTypeTag[type_config.element_b_zeropoint]}}, // Zeropoints - cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, - Config, with_C, with_scales, with_zeropoints>; - -{% for sch in schedules %} -{% set schedule_name = gen_sch_name(sch) -%} -struct sch_{{schedule_name}} { + +{% for sch in unique_schedules(impl_configs) %} +{% set sch_sig = gen_sch_sig(sch) -%} +struct sch_{{sch_sig}} { using TileShapeNM = Shape<{{ to_cute_constant(sch.tile_shape_mn)|join(', ')}}>; using ClusterShape = Shape<{{ @@ -101,27 +172,34 @@ using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}}; using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; }; - +{% endfor %} + +{% for impl_config in impl_configs %} +{% set t = impl_config.types -%} +{% set schs = impl_config.schedules -%} +{% set type_sig = gen_type_sig(t) -%} + +template +using Kernel_{{type_sig}} = MacheteKernelTemplate< + {{DataTypeTag[t.a]}}, // ElementA + {{DataTypeTag[t.b]}}, // ElementB + {{DataTypeTag[t.out]}}, // ElementD + {{DataTypeTag[t.accumulator]}}, // Accumulator + {{DataTypeTag[t.b_group_scale]}}, // GroupScaleT + {{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT + {{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT + {{DataTypeTag[t.a_token_scale]}}, // TokenScaleT + cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, + Sch>; + +{% for sch in schs %} +{% set sch_sig = gen_sch_sig(sch) -%} torch::Tensor -impl_{{type_name}}_sch_{{schedule_name}}(PyTorchArguments args) { - bool with_C = args.C.has_value(), with_scales = args.scales.has_value(), - with_zeropoints = args.zeros.has_value(); - - {% for s in specializations %} - if (with_C == {{s.with_C|lower}} - && with_zeropoints == {{s.with_zeropoints|lower}} - && with_scales == {{s.with_scales|lower}}) { - return run_impl>(args); - }{% endfor %} - - TORCH_CHECK_NOT_IMPLEMENTED( - false, "for the sake of compile times and binary size machete_mm(..) is " - " not implemented for with_C=", with_C, ", with_scales=", with_scales, - ", with_zeropoints=", with_zeropoints, - " (for {{type_name}}_sch_{{schedule_name}})"); +impl_{{type_sig}}_sch_{{sch_sig}}(MMArgs args) { + return run_impl>(args); } -{% endfor %} +{%- endfor %} +{%- endfor %} }; // namespace machete """ @@ -130,26 +208,34 @@ #include "../machete_prepack_launcher.cuh" namespace machete { -using PrepackBDispatcher_ = PrepackBDispatcher< - {{DataTypeTag[type_config.element_a]}}, // ElementA - {{DataTypeTag[type_config.element_b]}}, // ElementB - {{DataTypeTag[type_config.element_d]}}, // ElementD - {{DataTypeTag[type_config.accumulator]}}, // Accumulator - {{DataTypeTag[type_config.element_b_scale]}}, // Scales - {{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints - -using PrepackedLayoutB = PrepackedLayoutBTemplate< - {{DataTypeTag[type_config.element_a]}}, // ElementA - {{DataTypeTag[type_config.element_b]}}, // ElementB - {{DataTypeTag[type_config.element_d]}}, // ElementD - {{DataTypeTag[type_config.accumulator]}}, // Accumulator - cutlass::layout::ColumnMajor, - cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>; - -template <> -torch::Tensor PrepackBDispatcher_::dispatch(torch::Tensor B) { - return prepack_impl(B); + +torch::Tensor prepack_B_dispatch(PrepackBArgs args) { + auto convert_type = args.maybe_group_scales_type.value_or(args.a_type); + {%- for t in types %} + {% set b_type = unsigned_type_with_bitwidth(t.b_num_bits) %} + if (args.a_type == {{TorchTypeTag[t.a]}} + && args.b_type.size_bits() == {{t.b_num_bits}} + && convert_type == {{TorchTypeTag[t.convert]}}) { + return prepack_impl< + PrepackedLayoutBTemplate< + {{DataTypeTag[t.a]}}, // ElementA + {{DataTypeTag[b_type]}}, // ElementB + {{DataTypeTag[t.convert]}}, // ElementConvert + {{DataTypeTag[t.accumulator]}}, // Accumulator + cutlass::layout::ColumnMajor, + cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput> + >(args.B); + } + {%- endfor %} + + TORCH_CHECK_NOT_IMPLEMENTED(false, + "prepack_B_dispatch(..) is not implemented for " + "atype = ", args.a_type, + ", b_type = ", args.b_type.str(), + ", with_group_scales_type= ", args.maybe_group_scales_type ? + toString(*args.maybe_group_scales_type) : "None"); } + }; // namespace machete """ @@ -166,32 +252,34 @@ class ScheduleConfig: tile_scheduler: TileSchedulerType -@dataclass +@dataclass(frozen=True) class TypeConfig: - element_a: DataType - element_b: Union[DataType, VLLMDataType] - element_b_scale: DataType - element_b_zeropoint: DataType - element_d: DataType + a: DataType + b: Union[DataType, VLLMDataType] + b_group_scale: DataType + b_group_zeropoint: DataType + b_channel_scale: DataType + a_token_scale: DataType + out: DataType accumulator: DataType -@dataclass -class Specialization: - with_C: bool - with_zeropoints: bool - with_scales: bool +@dataclass(frozen=True) +class PrepackTypeConfig: + a: DataType + b_num_bits: int + convert: DataType + accumulator: DataType @dataclass class ImplConfig: - type_config: TypeConfig - schedule_configs: List[ScheduleConfig] - specializations: List[Specialization] + types: TypeConfig + schedules: List[ScheduleConfig] heuristic: List[Tuple[Optional[str], ScheduleConfig]] -def generate_schedule_name(schedule_config: ScheduleConfig) -> str: +def generate_sch_sig(schedule_config: ScheduleConfig) -> str: tile_shape = ( f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}" ) @@ -209,40 +297,34 @@ def generate_schedule_name(schedule_config: ScheduleConfig) -> str: f"_{epilogue_schedule}_{tile_scheduler}") -# mostly unique shorter schedule_name -def generate_terse_schedule_name(schedule_config: ScheduleConfig) -> str: +# mostly unique shorter sch_sig +def generate_terse_sch_sig(schedule_config: ScheduleConfig) -> str: kernel_terse_names_replace = { "KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_", "TmaWarpSpecializedCooperative_": "TmaCoop_", "StreamKScheduler": "streamK", } - schedule_name = generate_schedule_name(schedule_config) + sch_sig = generate_sch_sig(schedule_config) for orig, terse in kernel_terse_names_replace.items(): - schedule_name = schedule_name.replace(orig, terse) - return schedule_name + sch_sig = sch_sig.replace(orig, terse) + return sch_sig # unique type_name -def generate_type_signature(kernel_type_config: TypeConfig): - element_a = VLLMDataTypeNames[kernel_type_config.element_a] - element_b = VLLMDataTypeNames[kernel_type_config.element_b] - element_d = VLLMDataTypeNames[kernel_type_config.element_d] - accumulator = VLLMDataTypeNames[kernel_type_config.accumulator] - element_scale = VLLMDataTypeNames[kernel_type_config.element_b_scale] - element_zeropoint = VLLMDataTypeNames[ - kernel_type_config.element_b_zeropoint] - - return (f"{element_a}{element_b}{element_d}" - f"{accumulator}{element_scale}{element_zeropoint}") - +def generate_type_signature(kernel_types: TypeConfig): + return str("".join([ + VLLMDataTypeNames[getattr(kernel_types, field.name)] + for field in fields(TypeConfig) + ])) -# non-unique shorter type_name -def generate_terse_type_signature(kernel_type_config: TypeConfig): - element_a = VLLMDataTypeNames[kernel_type_config.element_a] - element_b = VLLMDataTypeNames[kernel_type_config.element_b] - return f"{element_a}{element_b}" +def generate_type_option_name(kernel_types: TypeConfig): + return ", ".join([ + f"{field.name.replace('b_', 'with_')+'_type'}=" + + VLLMDataTypeNames[getattr(kernel_types, field.name)] + for field in fields(TypeConfig) + ]) def is_power_of_two(n): @@ -263,13 +345,36 @@ def _to_cute_constant(value: int): return _to_cute_constant(value) +def unique_schedules(impl_configs: List[ImplConfig]): + return list( + set(sch for impl_config in impl_configs + for sch in impl_config.schedules)) + + +def unsigned_type_with_bitwidth(num_bits): + return { + 4: DataType.u4, + 8: DataType.u8, + 16: DataType.u16, + 32: DataType.u32, + 64: DataType.u64, + }[num_bits] + + template_globals = { + "void": DataType.void, "DataTypeTag": VLLMDataTypeTag, + "VLLMScalarTypeTag": VLLMDataTypeVLLMScalarTypeTag, + "TorchTypeTag": VLLMDataTypeTorchDataTypeTag, "KernelScheduleTag": VLLMKernelScheduleTag, "EpilogueScheduleTag": EpilogueScheduleTag, "TileSchedulerTag": TileSchedulerTag, "to_cute_constant": to_cute_constant, - "gen_sch_name": generate_terse_schedule_name, + "gen_sch_sig": generate_terse_sch_sig, + "gen_type_sig": generate_type_signature, + "unique_schedules": unique_schedules, + "unsigned_type_with_bitwidth": unsigned_type_with_bitwidth, + "gen_type_option_name": generate_type_option_name } @@ -284,42 +389,82 @@ def create_template(template_str): prepack_dispatch_template = create_template(PREPACK_TEMPLATE) -def create_sources(impl_config: ImplConfig, num_impl_files=1): +def create_sources(impl_configs: List[ImplConfig], num_impl_files=8): sources = [] - type_name = generate_type_signature(impl_config.type_config) - terse_type_name = generate_terse_type_signature(impl_config.type_config) - sources.append(( - f"machete_mm_{terse_type_name}", - mm_dispatch_template.render(type_name=type_name, - type_config=impl_config.type_config, - schedules=impl_config.schedule_configs, - heuristic=impl_config.heuristic), + "machete_mm_dispatch", + mm_dispatch_template.render(impl_configs=impl_configs), )) + prepack_types = [] + for impl_config in impl_configs: + convert_type = impl_config.types.a \ + if impl_config.types.b_group_scale == DataType.void \ + else impl_config.types.b_group_scale + prepack_types.append( + PrepackTypeConfig( + a=impl_config.types.a, + b_num_bits=VLLMDataTypeSize[impl_config.types.b], + convert=convert_type, + accumulator=impl_config.types.accumulator, + )) + + def prepacked_type_key(prepack_type: PrepackTypeConfig): + # For now we we can just use the first accumulator type seen since + # the tensor core shapes/layouts don't vary based on accumulator + # type so we can generate less code this way + return (prepack_type.a, prepack_type.b_num_bits, prepack_type.convert) + + unique_prepack_types = [] + prepack_types_seen = set() + for prepack_type in prepack_types: + key = prepacked_type_key(prepack_type) + if key not in prepack_types_seen: + unique_prepack_types.append(prepack_type) + prepack_types_seen.add(key) + sources.append(( - f"machete_prepack_{terse_type_name}", - prepack_dispatch_template.render( - type_name=type_name, - type_config=impl_config.type_config, - ), + "machete_prepack", + prepack_dispatch_template.render(types=unique_prepack_types, ), )) - num_schedules = len(impl_config.schedule_configs) - schedules_per_file = math.ceil(num_schedules / num_impl_files) - for part, i in enumerate(range(0, num_schedules, schedules_per_file)): - file_schedules = impl_config.schedule_configs[i:i + schedules_per_file] + # Split up impls across files + num_impls = reduce(lambda x, y: x + len(y.schedules), impl_configs, 0) + num_impls_per_file = math.ceil(num_impls / num_impl_files) + + files_impls: List[List[ImplConfig]] = [[]] + + curr_num_impls_assigned = 0 + curr_impl_in_file = 0 + curr_impl_configs = deepcopy(list(reversed(impl_configs))) + + while curr_num_impls_assigned < num_impls: + room_left_in_file = num_impls_per_file - curr_impl_in_file + if room_left_in_file == 0: + files_impls.append([]) + room_left_in_file = num_impls_per_file + curr_impl_in_file = 0 + + curr_ic = curr_impl_configs[-1] + if len(curr_ic.schedules) >= room_left_in_file: + # Break apart the current impl config + tmp_ic = deepcopy(curr_ic) + tmp_ic.schedules = curr_ic.schedules[:room_left_in_file] + curr_ic.schedules = curr_ic.schedules[room_left_in_file:] + files_impls[-1].append(tmp_ic) + else: + files_impls[-1].append(curr_ic) + curr_impl_configs.pop() + curr_num_impls_assigned += len(files_impls[-1][-1].schedules) + curr_impl_in_file += len(files_impls[-1][-1].schedules) + for part, file_impls in enumerate(files_impls): sources.append(( - f"machete_mm_{terse_type_name}_impl_part{part}", - mm_impl_template.render( - type_name=type_name, - type_config=impl_config.type_config, - schedules=file_schedules, - specializations=impl_config.specializations, - ), + f"machete_mm_impl_part{part+1}", + mm_impl_template.render(impl_configs=file_impls), )) + return sources @@ -328,187 +473,169 @@ def generate(): # about how this works SCRIPT_DIR = os.path.dirname(__file__) - schedule_common_params = dict( + sch_common_params = dict( kernel_schedule=TmaMI, epilogue_schedule=TmaCoop, tile_scheduler=TileSchedulerType.StreamK, ) - # For now we use the same heuristic for all types - # Heuristic is currently tuned for H100s - default_heuristic = [ + # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) + default_tile_heuristic_config = { #### M = 257+ - ( - "M > 256 && K <= 16384 && N <= 4096", - ScheduleConfig( - tile_shape_mn=(128, 128), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 256", - ScheduleConfig( - tile_shape_mn=(128, 256), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), + "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), + "M > 256": ((128, 256), (2, 1, 1)), #### M = 129-256 - ( - "M > 128 && K <= 4096 && N <= 4096", - ScheduleConfig( - tile_shape_mn=(128, 64), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 128 && K <= 8192 && N <= 8192", - ScheduleConfig( - tile_shape_mn=(128, 128), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 128", - ScheduleConfig( - tile_shape_mn=(128, 256), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), + "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)), + "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), + "M > 128": ((128, 256), (2, 1, 1)), #### M = 65-128 - ( - "M > 64 && K <= 4069 && N <= 4069", - ScheduleConfig( - tile_shape_mn=(128, 32), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 64 && K <= 4069 && N <= 8192", - ScheduleConfig( - tile_shape_mn=(128, 64), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 64 && K >= 8192 && N >= 12288", - ScheduleConfig( - tile_shape_mn=(256, 128), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 64", - ScheduleConfig( - tile_shape_mn=(128, 128), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), + "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), + "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), + "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), + "M > 64": ((128, 128), (2, 1, 1)), #### M = 33-64 - ( - "M > 32 && K <= 6144 && N <= 6144", - ScheduleConfig( - tile_shape_mn=(128, 16), - cluster_shape_mnk=(1, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 32 && K >= 16384 && N >= 12288", - ScheduleConfig( - tile_shape_mn=(256, 64), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 32", - ScheduleConfig( - tile_shape_mn=(128, 64), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), + "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), + "M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), + "M > 32": ((128, 64), (2, 1, 1)), #### M = 17-32 - ( - "M > 16 && K <= 12288 && N <= 8192", - ScheduleConfig( - tile_shape_mn=(128, 32), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), - ( - "M > 16", - ScheduleConfig( - tile_shape_mn=(256, 32), - cluster_shape_mnk=(2, 1, 1), - **schedule_common_params # type: ignore - )), + "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), + "M > 16": ((256, 32), (2, 1, 1)), #### M = 1-16 - ( - "N >= 26624", - ScheduleConfig( - tile_shape_mn=(256, 16), - cluster_shape_mnk=(1, 1, 1), - **schedule_common_params # type: ignore - )), - ( - None, - ScheduleConfig( - tile_shape_mn=(128, 16), - cluster_shape_mnk=(1, 1, 1), - **schedule_common_params # type: ignore - )), + "N >= 26624": ((256, 16), (1, 1, 1)), + None: ((128, 16), (1, 1, 1)), + } + + # For now we use the same heuristic for all types + # Heuristic is currently tuned for H100s + default_heuristic = [ + (cond, ScheduleConfig(*tile_config, + **sch_common_params)) # type: ignore + for cond, tile_config in default_tile_heuristic_config.items() ] - # Do not use schedules = list(set(...)) because we need to make sure - # the output list is deterministic; otherwise the generated kernel file - # will be non-deterministic and causes ccache miss. - schedules = [] - for _, schedule_config in default_heuristic: - if schedule_config not in schedules: - schedules.append(schedule_config) + def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]): + # Do not use schedules = list(set(...)) because we need to make sure + # the output list is deterministic; otherwise the generated kernel file + # will be non-deterministic and causes ccache miss. + schedules = [] + for _, schedule_config in heuristic: + if schedule_config not in schedules: + schedules.append(schedule_config) + return schedules impl_configs = [] GPTQ_kernel_type_configs = list( TypeConfig( - element_a=element_a, - element_b=element_b, - element_b_scale=element_a, - element_b_zeropoint=element_a, - element_d=element_a, + a=a, + b=b, + b_group_scale=a, + b_group_zeropoint=DataType.void, + b_channel_scale=DataType.void, + a_token_scale=DataType.void, + out=a, accumulator=DataType.f32, - ) for element_b in (VLLMDataType.u4b8, VLLMDataType.u8b128) - for element_a in (DataType.f16, DataType.bf16)) - - GPTQ_kernel_specializations = [ - Specialization(with_C=False, with_zeropoints=False, with_scales=True) - ] + ) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128) + for a in (DataType.f16, DataType.bf16)) impl_configs += [ - ImplConfig(x[0], x[1], x[2], x[3]) - for x in zip(GPTQ_kernel_type_configs, itertools.repeat(schedules), - itertools.repeat(GPTQ_kernel_specializations), + ImplConfig(x[0], x[1], x[2]) + for x in zip(GPTQ_kernel_type_configs, + itertools.repeat(get_unique_schedules(default_heuristic)), itertools.repeat(default_heuristic)) ] AWQ_kernel_type_configs = list( TypeConfig( - element_a=element_a, - element_b=element_b, - element_b_scale=element_a, - element_b_zeropoint=element_a, - element_d=element_a, + a=a, + b=b, + b_group_scale=a, + b_group_zeropoint=a, + b_channel_scale=DataType.void, + a_token_scale=DataType.void, + out=a, accumulator=DataType.f32, - ) for element_b in (DataType.u4, DataType.u8) - for element_a in (DataType.f16, DataType.bf16)) + ) for b in (DataType.u4, DataType.u8) + for a in (DataType.f16, DataType.bf16)) + + impl_configs += [ + ImplConfig(x[0], x[1], x[2]) + for x in zip(AWQ_kernel_type_configs, + itertools.repeat(get_unique_schedules(default_heuristic)), + itertools.repeat(default_heuristic)) + ] - AWQ_kernel_specializations = [ - Specialization(with_C=False, with_zeropoints=True, with_scales=True) + # Stored as "condition": ((tile_shape_mn), (cluster_shape_mnk)) + # TODO (LucasWilkinson): Further tuning required + qqq_tile_heuristic_config = { + #### M = 257+ + # ((128, 256), (2, 1, 1)) Broken for QQQ types + # TODO (LucasWilkinson): Investigate further + # "M > 256 && K <= 16384 && N <= 4096": ((128, 128), (2, 1, 1)), + # "M > 256": ((128, 256), (2, 1, 1)), + "M > 256": ((128, 128), (2, 1, 1)), + #### M = 129-256 + "M > 128 && K <= 4096 && N <= 4096": ((128, 64), (2, 1, 1)), + "M > 128 && K <= 8192 && N <= 8192": ((128, 128), (2, 1, 1)), + # ((128, 256), (2, 1, 1)) Broken for QQQ types + # TODO (LucasWilkinson): Investigate further + # "M > 128": ((128, 256), (2, 1, 1)), + "M > 128": ((128, 128), (2, 1, 1)), + #### M = 65-128 + "M > 64 && K <= 4069 && N <= 4069": ((128, 32), (2, 1, 1)), + "M > 64 && K <= 4069 && N <= 8192": ((128, 64), (2, 1, 1)), + "M > 64 && K >= 8192 && N >= 12288": ((256, 128), (2, 1, 1)), + "M > 64": ((128, 128), (2, 1, 1)), + #### M = 33-64 + "M > 32 && K <= 6144 && N <= 6144": ((128, 16), (1, 1, 1)), + # Broken for QQQ types + # TODO (LucasWilkinson): Investigate further + #"M > 32 && K >= 16384 && N >= 12288": ((256, 64), (2, 1, 1)), + "M > 32": ((128, 64), (2, 1, 1)), + #### M = 17-32 + "M > 16 && K <= 12288 && N <= 8192": ((128, 32), (2, 1, 1)), + "M > 16": ((256, 32), (2, 1, 1)), + #### M = 1-16 + "N >= 26624": ((256, 16), (1, 1, 1)), + None: ((128, 16), (1, 1, 1)), + } + + # For now we use the same heuristic for all types + # Heuristic is currently tuned for H100s + qqq_heuristic = [ + (cond, ScheduleConfig(*tile_config, + **sch_common_params)) # type: ignore + for cond, tile_config in qqq_tile_heuristic_config.items() + ] + + QQQ_kernel_types = [ + *(TypeConfig( + a=DataType.s8, + b=VLLMDataType.u4b8, + b_group_scale=b_group_scale, + b_group_zeropoint=DataType.void, + b_channel_scale=DataType.f32, + a_token_scale=DataType.f32, + out=DataType.f16, + accumulator=DataType.s32, + ) for b_group_scale in (DataType.f16, DataType.void)), + *(TypeConfig( + a=DataType.e4m3, + b=VLLMDataType.u4b8, + b_group_scale=b_group_scale, + b_group_zeropoint=DataType.void, + b_channel_scale=DataType.f32, + a_token_scale=DataType.f32, + out=DataType.f16, + accumulator=DataType.f32, + ) for b_group_scale in (DataType.f16, DataType.void)), ] impl_configs += [ - ImplConfig(x[0], x[1], x[2], x[3]) - for x in zip(AWQ_kernel_type_configs, itertools.repeat(schedules), - itertools.repeat(AWQ_kernel_specializations), - itertools.repeat(default_heuristic)) + ImplConfig(x[0], x[1], x[2]) + for x in zip(QQQ_kernel_types, + itertools.repeat(get_unique_schedules(qqq_heuristic)), + itertools.repeat(qqq_heuristic)) ] output_dir = os.path.join(SCRIPT_DIR, "generated") @@ -521,12 +648,11 @@ def generate(): os.makedirs(output_dir) # Render each group of configurations into separate files - for impl_config in impl_configs: - for filename, code in create_sources(impl_config): - filepath = os.path.join(output_dir, f"{filename}.cu") - with open(filepath, "w") as output_file: - output_file.write(code) - print(f"Rendered template to {filepath}") + for filename, code in create_sources(impl_configs): + filepath = os.path.join(output_dir, f"{filename}.cu") + with open(filepath, "w") as output_file: + output_file.write(code) + print(f"Rendered template to {filepath}") if __name__ == "__main__": diff --git a/csrc/quantization/machete/machete_mainloop.cuh b/csrc/quantization/machete/machete_mainloop.cuh index e8e7b14de0da1..816f33a1078e5 100644 --- a/csrc/quantization/machete/machete_mainloop.cuh +++ b/csrc/quantization/machete/machete_mainloop.cuh @@ -171,6 +171,10 @@ struct MacheteCollectiveMma { make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}), Int{}))); + using SmemLayoutACopy = decltype(GmemLayoutA::TVbNbKL_to_offset_copy( + make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}), + Int{}))); + using SmemLayoutAtomARowMajor = decltype(rs_smem_selector(TileShape_MNK{})), @@ -288,14 +292,7 @@ struct MacheteCollectiveMma { static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0, "SmemLayoutAtomScale must evenly divide tile k shape."); - // Tile along modes in a way that maximizes the TMA box size. - using SmemLayoutACopy = decltype(tile_to_shape( - SmemLayoutAtomARowMajor{}, - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), - Int{}), - conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), - Step<_2, _1, _3>, Step<_1, _2, _3>>{})); - + // Tile along modes in a way that maximizes the TMA box size using SmemLayoutB = decltype(tile_to_shape( SmemLayoutAtomB{}, make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), @@ -428,12 +425,12 @@ struct MacheteCollectiveMma { // clang-format on // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) - using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset( + using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset_copy( make_shape(int32_t(0), int32_t(0), int32_t(0))))); using ATensor = decltype(make_tensor( get_logical_ptr(static_cast(nullptr)), - shape(GmemLayoutA::TVbNbKL_to_offset( + shape(GmemLayoutA::TVbNbKL_to_offset_copy( make_shape(int32_t(0), int32_t(0), int32_t(0)))), PrepackedStrideA{})); @@ -450,8 +447,8 @@ struct MacheteCollectiveMma { static constexpr auto make_tma_copy_A(ATensor tensor_a = ATensor{}) { return make_tma_copy( - GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}), - shape(SmemLayoutA{}(_, _, cute::Int<0>{})), + GmemTiledCopyA{}, tensor_a, SmemLayoutACopy{}(_, _, cute::Int<0>{}), + shape(SmemLayoutACopy{}(_, _, cute::Int<0>{})), size<1>(ClusterShape{})); // mcast along N mode for this M load, if any } @@ -584,7 +581,7 @@ struct MacheteCollectiveMma { typename Params::TMA_Scale tma_load_scale; typename Params::TMA_Zero tma_load_zero; - auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L)); + auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L)); tma_load_a = make_tma_copy_A( make_logical_tensor(ptr_A, shape(layout), stride(layout))); @@ -722,7 +719,7 @@ struct MacheteCollectiveMma { // (TILE_V,TILE_B,m,k,l) auto make_gA_mkl = [&]() { // ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx) - auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L)); + auto layout = GmemLayoutA::TVbNbKL_to_offset_copy(make_shape(M, K, L)); Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(layout)); return local_tile(mA_mkl, make_shape(size<0>(layout), PPBlocksPerTile_MK{}), diff --git a/csrc/quantization/machete/machete_mm_kernel.cuh b/csrc/quantization/machete/machete_mm_kernel.cuh index 4d41b8d291484..d4d19ae5deec7 100644 --- a/csrc/quantization/machete/machete_mm_kernel.cuh +++ b/csrc/quantization/machete/machete_mm_kernel.cuh @@ -21,6 +21,8 @@ #include "cutlass_extensions/cute_utils.cuh" #include "cutlass_extensions/vllm_numeric_conversion.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" +#include "cutlass_extensions/torch_utils.hpp" #include "machete_collective_builder.cuh" #include "machete_prepacked_layout.cuh" #include "machete_interleaving_utils.cuh" @@ -37,27 +39,42 @@ using namespace cute; // W is quantized, in this situation or right-hand operand is quantized so // we compute the transpose to move it to the left-hand side. template + typename AccumulatorT, typename GroupScaleT, typename GroupZeroT, + typename ChannelScaleT, typename TokenScaleT, class KernelSchedule, + typename ScheduleConfig> struct MacheteKernelTemplate { + static constexpr bool with_C = false; // not ever used + static constexpr bool with_group_scales = !std::is_same_v; + static constexpr bool with_group_zeropoints = + !std::is_same_v; + static constexpr bool with_channel_scales = + !std::is_same_v; + static constexpr bool with_token_scales = !std::is_same_v; + using MmaType = ElementA_; using ElementA = ElementA_; using ElementB = ElementB_; using ElementD = ElementD_; using ElementC = cute::conditional_t; - using ElementZ = ZeroT; - using ElementS = ScaleT; - - using ElementAccumulator = - AccumulatorT; // Element type for internal accumulation + using ElementAccumulator = AccumulatorT; using ElementCompute = AccumulatorT; // For Epilogue + // Use dummy values when we don't have scales or zeropoints + using ElementZGroup = + cute::conditional_t; + using ElementSGroup = + cute::conditional_t; + using ElementConvertGroup = + cute::conditional_t; + using ElementSChannel = + cute::conditional_t; + using ElementSToken = + cute::conditional_t; using BTypeTuple = cute::conditional_t< - with_scales, - cute::conditional_t, - cute::tuple>, + with_group_scales, + cute::conditional_t, + cute::tuple>, ElementB>; using LayoutA = cutlass::layout::RowMajor; @@ -71,8 +88,8 @@ struct MacheteKernelTemplate { using StrideA = cutlass::detail::TagToStrideA_t; using StrideC = cutlass::detail::TagToStrideA_t; using StrideD = cutlass::detail::TagToStrideA_t; - using StrideS = cutlass::detail::TagToStrideA_t; - using StrideZ = StrideS; + using StrideSGroup = cutlass::detail::TagToStrideA_t; + using StrideZGroup = StrideSGroup; using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; @@ -85,8 +102,8 @@ struct MacheteKernelTemplate { using OperatorClass = cutlass::arch::OpClassTensorOp; using PrepackedLayoutB = - PrepackedLayoutBTemplate; + PrepackedLayoutBTemplate; static int constexpr TileShapeK = 128 * 8 / cutlass::sizeof_bits::value; @@ -103,12 +120,42 @@ struct MacheteKernelTemplate { using EpilogueTileType = typename ScheduleConfig::EpilogueTileType; using TileScheduler = typename ScheduleConfig::TileScheduler; + static_assert( + (!with_channel_scales && !with_token_scales) || + ((with_channel_scales && with_token_scales) && + std::is_same_v), + "Currently token and channel scales (if present) must be the same type"); + + using EpilogueDescriptor = + cutlass::epilogue::collective::detail::EpilogueDescriptor< + TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, + ElementD, EpilogueSchedule>; + + // Currently only supports float scales + using ChTokScalesEpilogue = + typename vllm::c3x::ScaledEpilogue; + static_assert((with_channel_scales || with_token_scales) || + (std::is_same_v && + std::is_same_v), + "Currently token and channel scales (if present) must be float " + "(and if one is present the other must be too)"); + + using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90AccFetch>; + + using EVTCompute = + std::conditional_t; + + // EVTCompute using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType, - ElementAccumulator, ElementAccumulator, ElementC, LayoutC_Transpose, - AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, - EpilogueSchedule>::CollectiveOp; + ElementAccumulator, ElementSChannel, ElementC, LayoutC_Transpose, + AlignmentC, ElementD, LayoutD_Transpose, AlignmentD, EpilogueSchedule, + EVTCompute>::CollectiveOp; using CollectiveMainloop = typename cutlass::gemm::collective::VLLMCollectiveBuilder< @@ -131,26 +178,44 @@ struct MacheteKernelTemplate { using MainloopArguments = typename GemmKernel::MainloopArguments; using EpilogueArguments = typename GemmKernel::EpilogueArguments; - template static Arguments create_arguments( cudaStream_t stream, - ElementA const* A_ptr, // A is an MxK matrix - Layout const& layout_A, - ElementB const* B_ptr, // B is an KxN prepacked matrix - ElementD* D_ptr, // D is an MxN matrix - Layout const& layout_D, - ElementC const* C_ptr, // C is an MxN matrix - std::optional> const& layout_C, - ElementS const* S_ptr, // S is an scale_KxN matrix - std::optional> const& layout_S, - ElementZ const* Z_ptr, // Z is an scale_KxN matrix - std::optional> const& layout_Z, - ElementCompute alpha, ElementCompute beta, - std::optional maybe_group_size) { - static_assert(!with_zeropoints || with_scales); - - int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A); + torch::Tensor const& A, // MxK matrix + torch::Tensor const& B, // KxN prepacked matrix + torch::Tensor& D, // MxN matrix + c10::optional const& maybe_g_scales, // scale_KxN matrix + c10::optional const& maybe_g_zeros, // scale_KxN matrix + c10::optional maybe_group_size, + c10::optional const& maybe_ch_scales, // len N vector + c10::optional const& maybe_tok_scales) // len M vector + { + static_assert(!with_group_zeropoints || with_group_scales); + + int M = A.size(0), N = B.size(1), K = A.size(1); + TORCH_CHECK(D.size(0) == M && D.size(1) == N); + + auto layout_A = make_cute_layout(A, "A"); + auto layout_D = make_cute_layout(D, "D"); + auto layout_S_group = + maybe_make_cute_layout(maybe_g_scales, "group_scales"); + auto layout_Z_group = + maybe_make_cute_layout(maybe_g_zeros, "group_zeros"); + int64_t numel_S_channel = maybe_ch_scales ? maybe_ch_scales->numel() : 0; + int64_t numel_S_token = maybe_tok_scales ? maybe_tok_scales->numel() : 0; + + auto unwrap = [](auto const& t) { + return t ? t->const_data_ptr() : nullptr; + }; + auto A_ptr = static_cast(A.const_data_ptr()); + auto B_ptr = static_cast(B.const_data_ptr()); + auto D_ptr = static_cast(D.mutable_data_ptr()); + auto S_group_ptr = + static_cast(unwrap(maybe_g_scales)); + auto Z_group_ptr = static_cast(unwrap(maybe_g_zeros)); + auto S_channel_ptr = + static_cast(unwrap(maybe_ch_scales)); + auto S_token_ptr = + static_cast(unwrap(maybe_tok_scales)); int const group_size = maybe_group_size == -1 ? K : maybe_group_size.value_or(K); @@ -159,26 +224,28 @@ struct MacheteKernelTemplate { TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K); TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N); - if constexpr (with_C) { - TORCH_CHECK(C_ptr && layout_C); + if constexpr (with_group_scales) { + TORCH_CHECK(S_group_ptr && layout_S_group); + TORCH_CHECK((size<0>(*layout_S_group) == scale_k && + size<1>(*layout_S_group) == N)); } else { - TORCH_CHECK(!C_ptr, "C not supported"); + TORCH_CHECK(!S_group_ptr, "Scales not supported"); } - if constexpr (with_scales) { - TORCH_CHECK(S_ptr && layout_S); - TORCH_CHECK((size<0>(*layout_S) == scale_k && size<1>(*layout_S) == N)); + if constexpr (with_group_zeropoints) { + TORCH_CHECK(Z_group_ptr && layout_Z_group); + TORCH_CHECK((size<0>(*layout_Z_group) == scale_k && + size<1>(*layout_Z_group) == N)); + TORCH_CHECK(layout_S_group && *layout_Z_group == *layout_S_group, + "Scales and zeros must have the same layout"); } else { - TORCH_CHECK(!S_ptr, "Scales not supported"); + TORCH_CHECK(!Z_group_ptr, "Zeropoints not supported"); } - if constexpr (with_zeropoints) { - TORCH_CHECK(Z_ptr && layout_Z); - TORCH_CHECK((size<0>(*layout_Z) == scale_k && size<1>(*layout_Z) == N)); - TORCH_CHECK(layout_S && *layout_Z == *layout_S, - "Scales and zeros must have the same layout"); - } else { - TORCH_CHECK(!Z_ptr, "Zeropoints not supported"); + if constexpr (with_channel_scales || with_token_scales) { + TORCH_CHECK( + (maybe_ch_scales->numel() == N || maybe_ch_scales->numel() == 1) && + (maybe_tok_scales->numel() == M || maybe_tok_scales->numel() == 1)); } // Transpose A and D @@ -186,24 +253,33 @@ struct MacheteKernelTemplate { // for B (which is At) auto stride_At = layout_A.stride(); auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride(); - auto stride_Ct = stride_Dt; - if (layout_C) { - stride_Ct = permute_layout<1, 0, 2>(*layout_C).stride(); - } MainloopArguments mainloop_arguments{}; - EpilogueArguments epilogue_arguments{ - {alpha, beta}, C_ptr, stride_Ct, D_ptr, stride_Dt}; + // {Accum, C, C_layout, D, D} + EpilogueArguments epilogue_arguments{}; + + if constexpr (with_channel_scales || with_token_scales) { + epilogue_arguments = + EpilogueArguments{ChTokScalesEpilogue::prepare_args( + *maybe_ch_scales, *maybe_tok_scales), + nullptr, + {}, + D_ptr, + stride_Dt}; + } else { + epilogue_arguments = EpilogueArguments{{}, nullptr, {}, D_ptr, stride_Dt}; + } - if constexpr (with_scales && with_zeropoints) { - auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride(); - mainloop_arguments = - MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At, - S_ptr, stride_S, group_size, Z_ptr}; - } else if constexpr (with_scales) { - auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride(); + if constexpr (with_group_scales && with_group_zeropoints) { + auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride(); mainloop_arguments = MainloopArguments{ - B_ptr, _StrideB{}, A_ptr, stride_At, S_ptr, stride_S, group_size}; + B_ptr, _StrideB{}, A_ptr, stride_At, + S_group_ptr, stride_S_group, group_size, Z_group_ptr}; + } else if constexpr (with_group_scales) { + auto stride_S_group = permute_layout<1, 0, 2>(*layout_S_group).stride(); + mainloop_arguments = + MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At, + S_group_ptr, stride_S_group, group_size}; } else { mainloop_arguments = MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At}; diff --git a/csrc/quantization/machete/machete_mm_launcher.cuh b/csrc/quantization/machete/machete_mm_launcher.cuh index 60a4ed60535b7..4b0da5b303e0c 100644 --- a/csrc/quantization/machete/machete_mm_launcher.cuh +++ b/csrc/quantization/machete/machete_mm_launcher.cuh @@ -5,73 +5,61 @@ #include "machete_mm_kernel.cuh" #include "cutlass_extensions/torch_utils.hpp" +#include "core/scalar_type.hpp" namespace machete { -struct PyTorchArguments { +struct MMArgs { torch::Tensor const& A; torch::Tensor const& B; - c10::optional const& scales; - c10::optional const& zeros; - c10::optional group_size; - c10::optional const& C; - c10::optional alpha; - c10::optional beta; - c10::optional schedule; + vllm::ScalarType const& b_type; + c10::optional const& maybe_out_type; + c10::optional const& maybe_group_scales; + c10::optional const& maybe_group_zeros; + c10::optional maybe_group_size; + c10::optional const& maybe_channel_scales; + c10::optional const& maybe_token_scales; + c10::optional maybe_schedule; }; +struct SupportedSchedulesArgs { + at::ScalarType a_type; + vllm::ScalarType b_type; + c10::optional maybe_group_scales_type; + c10::optional maybe_group_zeros_type; + c10::optional maybe_channel_scales_type; + c10::optional maybe_token_scales_type; + c10::optional maybe_out_type; +}; + +torch::Tensor mm_dispatch(MMArgs args); + +std::vector supported_schedules_dispatch( + SupportedSchedulesArgs args); + template -torch::Tensor run_impl(PyTorchArguments args) { +torch::Tensor run_impl(MMArgs args) { const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A)); auto device = args.A.device(); auto stream = at::cuda::getCurrentCUDAStream(device.index()); - using EleA = typename MacheteKernel::ElementA; - using EleB = typename MacheteKernel::ElementB; - using EleC = typename MacheteKernel::ElementC; - using EleD = typename MacheteKernel::ElementD; - using EleScale = typename MacheteKernel::ElementS; - using EleZero = typename MacheteKernel::ElementZ; - - using StrideA = typename MacheteKernel::StrideA; - using StrideC = typename MacheteKernel::StrideC; - using StrideD = typename MacheteKernel::StrideD; - using StrideS = typename MacheteKernel::StrideS; - using StrideZ = typename MacheteKernel::StrideZ; - int M = args.A.size(0); int N = args.B.size(1); int K = args.A.size(1); // Allocate output - torch::Tensor D = - torch::empty({M, N}, torch::TensorOptions() - .dtype(equivalent_scalar_type_v) - .device(device)); - - auto const &A = args.A, &B = args.B; - auto const &C = args.C, &scales = args.scales, &zeros = args.zeros; - - auto layout_A = make_cute_layout(A, "A"); - auto layout_D = make_cute_layout(D, "D"); - auto layout_C = maybe_make_cute_layout(C, "C"); - auto layout_S = maybe_make_cute_layout(scales, "scales"); - auto layout_Z = maybe_make_cute_layout(zeros, "zeros"); - - auto A_ptr = static_cast(A.const_data_ptr()); - auto B_ptr = static_cast(B.const_data_ptr()); - auto D_ptr = static_cast(D.mutable_data_ptr()); - auto C_ptr = static_cast(C ? C->const_data_ptr() : nullptr); - auto S_ptr = - static_cast(scales ? scales->const_data_ptr() : nullptr); - auto Z_ptr = - static_cast(zeros ? zeros->const_data_ptr() : nullptr); + torch::Tensor D = torch::empty( + {M, N}, + torch::TensorOptions() + .dtype(equivalent_scalar_type_v) + .device(device)); auto arguments = MacheteKernel::create_arguments( - stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr, - layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0), - args.group_size); + stream, // + args.A, args.B, D, args.maybe_group_scales, args.maybe_group_zeros, + args.maybe_group_size, args.maybe_channel_scales, + args.maybe_token_scales); TORCH_CHECK(MacheteKernel::can_implement(arguments), "Machete kernel cannot be run with these arguments"); @@ -84,12 +72,4 @@ torch::Tensor run_impl(PyTorchArguments args) { return D; }; -template -struct GemmDispatcher { - static torch::Tensor dispatch(PyTorchArguments args); - static std::vector supported_schedules(); -}; - }; // namespace machete \ No newline at end of file diff --git a/csrc/quantization/machete/machete_prepack_kernel.cuh b/csrc/quantization/machete/machete_prepack_kernel.cuh index f23483f928b47..d002355ca49d6 100644 --- a/csrc/quantization/machete/machete_prepack_kernel.cuh +++ b/csrc/quantization/machete/machete_prepack_kernel.cuh @@ -6,31 +6,49 @@ namespace machete { -template -static __global__ void prepack_B_kernel(BInTensor B_in, - BTiledOutTensor B_tiled_out) { - auto tB_in = local_tile(B_in, TileShapeNKL{}, - make_coord(blockIdx.x, blockIdx.y, blockIdx.z)); - auto tB_out = B_tiled_out(make_coord(_, _), - make_coord(blockIdx.x, blockIdx.y), blockIdx.z); +template +static __global__ void prepack_B_kernel(BInTensor B_in, ElementB* B_out_ptr) { + auto constexpr block_size = + Int{}; + auto constexpr eles_per_thread = Int{}; + static_assert(block_size % threads == 0, + "block_size must be divisible by the number of threads"); - auto tiled_copy = make_tiled_copy(Copy_Atom{}, - Layout, Stride<_32, _1>>{}, - Layout>{}); + // Which pre-packed are we responsible for + auto blk_coord = make_coord(blockIdx.x, blockIdx.y, blockIdx.z); + auto tB_in = local_tile( + B_in, append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}), + blk_coord); - auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x); + // Find the start offset in the output for this pre-packed block + auto bNbKL_to_offset = PrepackedLayoutB::bNbKL_to_offset(shape(B_in)); - Tensor thr_tile_S = thr_copy.partition_S(tB_in); - Tensor thr_tile_D = thr_copy.partition_D(tB_out); + // Tensor representing a 1:1 mapping to the output space in 1D + auto tB_out_linear = + make_tensor(get_logical_ptr(B_out_ptr) + bNbKL_to_offset(blk_coord), + make_layout(make_shape(block_size))); + // Mapping from output space (1D) to input space + auto tB_in_linear = make_tensor( + tB_in.data(), + tB_in.layout() + .compose(right_inverse(PrepackedLayoutB::ppblock_ilvd_NK_to_offset())) + .with_shape(make_shape(block_size))); + + // Tile for this specific thread (could have used a TiledCopy but these work + // best with 2d layouts, this is a simple 1d layout so local_tile is enough, + // we are also not that concerned with performance for this kernel) + auto thr_tB_in_linear = + local_tile(tB_in_linear, make_shape(eles_per_thread), threadIdx.x); + auto thr_tB_out_linear = + local_tile(tB_out_linear, make_shape(eles_per_thread), threadIdx.x); // Construct a register-backed Tensor with the same shape as each thread's // partition - auto fragment = make_tensor(shape(thr_tile_D)); + auto fragment = make_tensor(shape(thr_tB_in_linear)); - // Copy from GMEM to RMEM and from RMEM to GMEM - copy(tiled_copy, thr_tile_S, fragment); - copy(Copy_Atom{}, fragment, thr_tile_D); + copy(thr_tB_in_linear, fragment); + copy(Copy_Atom{}, fragment, thr_tB_out_linear); } template @@ -44,18 +62,15 @@ static void prepack_B_template( TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0); TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0); - TORCH_CHECK(size<2>(B_layout) % size<2>(TileShapeNKL{}) == 0); auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{}); auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{}); - auto L_tiles = size<2>(B_layout) / size<2>(TileShapeNKL{}); + auto L_tiles = size<2>(B_layout); auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout); - auto B_tiled_out = - make_tensor(get_logical_ptr(B_out_ptr), ilvd_NKbNbKL_to_offset); - prepack_B_kernel - <<>>(B_in, B_tiled_out); + prepack_B_kernel<128, PrepackedLayoutB> + <<>>(B_in, B_out_ptr); } }; // namespace machete \ No newline at end of file diff --git a/csrc/quantization/machete/machete_prepack_launcher.cuh b/csrc/quantization/machete/machete_prepack_launcher.cuh index a33d8f9484cfe..3486d28be2126 100644 --- a/csrc/quantization/machete/machete_prepack_launcher.cuh +++ b/csrc/quantization/machete/machete_prepack_launcher.cuh @@ -2,9 +2,17 @@ #include "machete_prepack_kernel.cuh" #include "cutlass_extensions/torch_utils.hpp" +#include "core/scalar_type.hpp" namespace machete { +struct PrepackBArgs { + torch::Tensor const& B; + at::ScalarType a_type; + vllm::ScalarType b_type; + c10::optional maybe_group_scales_type; +}; + template torch::Tensor prepack_impl(torch::Tensor const B) { const at::cuda::OptionalCUDAGuard device_guard(device_of(B)); @@ -61,11 +69,6 @@ torch::Tensor prepack_impl(torch::Tensor const B) { return D; }; -template -struct PrepackBDispatcher { - static torch::Tensor dispatch(torch::Tensor B); -}; +torch::Tensor prepack_B_dispatch(PrepackBArgs args); }; // namespace machete \ No newline at end of file diff --git a/csrc/quantization/machete/machete_prepacked_layout.cuh b/csrc/quantization/machete/machete_prepacked_layout.cuh index 78e2cc5eec7d8..680a858a893c1 100644 --- a/csrc/quantization/machete/machete_prepacked_layout.cuh +++ b/csrc/quantization/machete/machete_prepacked_layout.cuh @@ -41,7 +41,7 @@ struct IlvBlkLayoutAuto {}; // The contract here is that the `TiledMma` determined below matches the one // ultimately used in the kernel. (this is also why the other element types are // required along with the kernel schedule) -template // clang-format on @@ -49,20 +49,27 @@ struct PrepackedLayoutBTemplate { using MmaType = ElementA_; using ElementA = ElementA_; using ElementB = ElementB_; - using ElementD = ElementD_; - using ElementAccumulator = - AccumulatorT; // Element type for internal accumulation + using ElementAccumulator = AccumulatorT; using ElementMma = MmaType; - // Only use interleaved layouts for subbyte weights, prmt instructions makes - // non-interleaved layouts for 8bit+ weights efficient enough we don't need - // iterleaved layouts + // Interleave for 4bit bit types when we are not upconverting to fp8 or int8, + // in those cases case we use a LUT using prmt instructions to upconvert and + // is more efficient if the data is not interleaved For 8bit+ prmt + // instructions makes non-interleaved layouts efficient enough we don't need + // iterleaved layouts (and can reuse more of the existing cutlass converts) + static constexpr bool should_interleave = + sizeof_bits_v <= 4 && + !std::is_same_v && + !std::is_same_v; + + // Only use interleaved layouts for subbyte weights, using IlvdBlkLayout = std::conditional_t< std::is_same_v, - std::conditional_t <= 4, - decltype(get_interleaved_blk_layout< - ElementB, sizeof_bits_v, 32>()), - void>, + std::conditional_t< + should_interleave, + decltype(get_interleaved_blk_layout< + ElementB, sizeof_bits_v, 32>()), + void>, IlvBlkLayout_>; // TODO (LucasWilkinson): compare the performance for other sizes @@ -135,7 +142,8 @@ struct PrepackedLayoutBTemplate { // then ((IlvBlk), FrgB) is {A, C, B, D, C, G, D, H} auto frgV = get<1, 0>(layout_no_interleave); auto ilvdBlk = IlvdBlkLayout{}; - static_assert(size(frgV) % 4 == 0, "FrgV must be divisible by 4"); + static_assert(size(frgV) % size(ilvdBlk) == 0, + "FrgV must be divisible by size(ilvdBlk)"); auto ilvd_FrgV = make_layout( make_shape(shape(ilvdBlk), Int{}), make_stride(stride(ilvdBlk), size(ilvdBlk))); @@ -175,6 +183,15 @@ struct PrepackedLayoutBTemplate { return group<1, 3>(result(_, repeat(result)>(_))); } + // ((athrid_val), (BlocksN, BlocksK, L)) -> (N, K, L) + template + CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset_copy( + Shape_NKL shape_mkl) { + auto layout = TVbNbKL_to_offset(shape_mkl); + return make_layout(coalesce(get<0>(layout)), get<1>(layout), + get<2>(layout)); + } + // ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx) template CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset( @@ -197,6 +214,19 @@ struct PrepackedLayoutBTemplate { return group<1, 3>(result(_, repeat(result)>(_))); } + // (BlocksN, BlocksK, L) -> (storage_idx) + template + CUTE_HOST_DEVICE static constexpr auto bNbKL_to_offset(Shape_NKL shape_mkl) { + // (BlocksN, BlocksK, L) + auto blocks_shape = + cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}), + [](auto x, auto y) { return x / y; }); + auto stride = size(PPBlockShape_NK{}); + + // (BlocksN, BlocksK, L) -> (storage_idx) + return make_layout(blocks_shape, compact_col_major(blocks_shape, stride)); + } + // ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L) template CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) { diff --git a/csrc/quantization/machete/machete_pytorch.cu b/csrc/quantization/machete/machete_pytorch.cu index 9f9073ded6191..da2c2fb0d3e77 100644 --- a/csrc/quantization/machete/machete_pytorch.cu +++ b/csrc/quantization/machete/machete_pytorch.cu @@ -8,89 +8,61 @@ namespace machete { using namespace vllm; -// -// Utils (type dispatching) -// - -template -static auto scalar_type_dispatch(ScalarType const& type, Fn fn) { - if (type == vllm::kU4) { - return fn(cutlass::uint4b_t{}); - } else if (type == vllm::kU8) { - return fn(cutlass::uint8_t{}); - } else if (type == vllm::kU4B8) { - return fn(cutlass::vllm_uint4b8_t{}); - } else if (type == vllm::kU8B128) { - return fn(cutlass::vllm_uint8b128_t{}); - } else { - TORCH_CHECK(false, "Unsupported type ", type.str()); - } -} - -#define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \ - AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__) - -#define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \ - AT_DISPATCH_SWITCH(TYPE, NAME, \ - AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__)) - -// -// Interface -// - -std::vector supported_schedules(ScalarTypeId const btype_id) { -#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12 - vllm::ScalarType b_type = ScalarType::from_id(btype_id); - return scalar_type_dispatch(b_type, [&](auto BType) { - return GemmDispatcher::supported_schedules(); +std::vector supported_schedules( + at::ScalarType a_type, int64_t b_type_id, + c10::optional maybe_group_scales_type, + c10::optional maybe_group_zeros_type, + c10::optional maybe_channel_scales_type, + c10::optional maybe_token_scales_type, + c10::optional maybe_out_type) { + ScalarType const b_type = ScalarType::from_id(b_type_id); + return supported_schedules_dispatch({ + .a_type = a_type, + .b_type = b_type, + .maybe_group_scales_type = maybe_group_scales_type, + .maybe_group_zeros_type = maybe_group_zeros_type, + .maybe_channel_scales_type = maybe_channel_scales_type, + .maybe_token_scales_type = maybe_token_scales_type, + .maybe_out_type = maybe_out_type, }); -#else - TORCH_CHECK(false, "Machete requires CUDA 12.0 or later"); -#endif } -torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, - ScalarTypeId const btype_id, - c10::optional const& scales, - c10::optional const& zeros, - c10::optional group_size, - c10::optional const& C, - c10::optional alpha, c10::optional beta, - c10::optional schedule) { -#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12 - ScalarType const btype = ScalarType::from_id(btype_id); - auto args = PyTorchArguments{.A = A, - .B = B, - .scales = scales, - .zeros = zeros, - .group_size = group_size, - .C = C, - .alpha = alpha, - .beta = beta, - .schedule = schedule}; - - return scalar_type_dispatch(btype, [&](auto BType) { - return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES( - A.scalar_type(), "machete_gemm", [&] { - using ComputeType = equivalent_cutlass_type_t; - return GemmDispatcher::dispatch(args); - }); - }); -#else - TORCH_CHECK(false, "Machete requires CUDA 12.0 or later"); -#endif +torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B, + int64_t b_type_id, + c10::optional const& maybe_out_type, + c10::optional const& maybe_group_scales, + c10::optional const& maybe_group_zeros, + c10::optional maybe_group_size, + c10::optional const& maybe_channel_scales, + c10::optional const& maybe_token_scales, + c10::optional maybe_schedule) { + ScalarType const b_type = ScalarType::from_id(b_type_id); + return mm_dispatch({.A = A, + .B = B, + .b_type = b_type, + .maybe_out_type = maybe_out_type, + .maybe_group_scales = maybe_group_scales, + .maybe_group_zeros = maybe_group_zeros, + .maybe_group_size = maybe_group_size, + .maybe_channel_scales = maybe_channel_scales, + .maybe_token_scales = maybe_token_scales, + .maybe_schedule = maybe_schedule}); } -torch::Tensor prepack_B(torch::Tensor const& B, ScalarTypeId const btype_id) { - ScalarType const btype = ScalarType::from_id(btype_id); - return scalar_type_dispatch(btype, [&](auto BType) { - return PrepackBDispatcher::dispatch(B); - }); +torch::Tensor prepack_B( + torch::Tensor const& B, at::ScalarType const& a_type, int64_t b_type_id, + c10::optional const& maybe_group_scales_type) { + ScalarType const b_type = ScalarType::from_id(b_type_id); + return prepack_B_dispatch( + {.B = B, + .a_type = a_type, + .b_type = b_type, + .maybe_group_scales_type = maybe_group_scales_type}); } TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { m.impl("machete_prepack_B", &prepack_B); - m.impl("machete_gemm", &gemm); + m.impl("machete_mm", &mm); } // use CatchAll since supported_schedules has no tensor arguments diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 229fd554d3eee..e4cc7ec951848 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -203,13 +203,36 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // conditionally compiled so impl in source file // Machete (Dense) Optimized Mixed Precision GEMM for Hopper. - ops.def("machete_supported_schedules(int btype) -> str[]"); ops.def( - "machete_gemm(Tensor A, Tensor B, int btype, " - " Tensor? scales, Tensor? zeros, int? group_size, " - " Tensor? C, float? alpha, float? beta, str? schedule)" - "-> Tensor"); - ops.def("machete_prepack_B(Tensor B, int btype) -> Tensor"); + "machete_supported_schedules(" + " ScalarType a_type," + " int b_type," + " ScalarType? maybe_group_scales_type," + " ScalarType? maybe_group_zeros_type," + " ScalarType? maybe_channel_scales_type," + " ScalarType? maybe_token_scales_type," + " ScalarType? maybe_out_type" + ") -> str[]"); + ops.def( + "machete_mm(" + " Tensor A," + " Tensor B," + " int b_type," + " ScalarType? out_type," + " Tensor? group_scales," + " Tensor? group_zeros," + " int? group_size," + " Tensor? channel_scales," + " Tensor? token_scales," + " str? schedule" + ") -> Tensor"); + ops.def( + "machete_prepack_B(" + " Tensor B," + " ScalarType a_type," + " int b_type," + " ScalarType? group_scales_type" + ") -> Tensor"); // conditionally compiled so impl registration is in source file ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor"); diff --git a/tests/kernels/test_machete_gemm.py b/tests/kernels/test_machete_gemm.py deleted file mode 100644 index 59c0a24753c3b..0000000000000 --- a/tests/kernels/test_machete_gemm.py +++ /dev/null @@ -1,284 +0,0 @@ -"""Tests for the machete kernel. - -Run `pytest tests/kernels/test_machete_gemm.py`. -""" - -import math -from typing import Optional, Tuple - -import pytest -import torch - -from tests.kernels.utils import opcheck -from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.utils.quant_utils import ( - pack_rows, quantize_weights) -from vllm.platforms import current_platform -from vllm.scalar_type import ScalarType, scalar_types - -CUDA_DEVICES = [ - f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) -] - -MNK_SHAPES = [ - (1, 128, 128), - (1, 512, 1024), - (1, 4096, 4096), - (1, 8192, 28672), - (13, 8192, 4096), - (26, 4096, 8192), - (64, 4096, 4096), - (64, 8192, 28672), - (257, 128, 4096), - (257, 4224, 4160), - (257, 4096, 4096), - (1024, 4096, 8192), - (1024, 8192, 4096), -] - -ACT_TYPES = [torch.float16, torch.bfloat16] -WTYPE_ZEROPOINTS = [ - # GPTQ style - (scalar_types.uint4b8, False), - (scalar_types.uint8b128, False), - # AWQ style - (scalar_types.uint4, True), - (scalar_types.uint8, True), -] - -# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel -# unit tests to a common utility function. Currently the use of -# `is_quant_method_supported` conflates kernels with quantization methods -# an assumption which is breaking down as quantizations methods can have -# have kernels and some kernels support multiple quantization methods. -IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90) - - -def rand_data(shape, dtype=torch.float16): - return 10 * (torch.rand(shape, dtype=dtype, device="cuda") - 0.3) - - -def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): - return zps if zps is None else -1 * s * (zps.to(s.dtype)) - - -def machete_quantize_and_pack(w: torch.Tensor, - wtype: ScalarType, - group_size: int, - zero_points: bool = False): - assert wtype.is_integer(), "TODO: support floating point weights" - - w_ref, w_q, w_s, w_zp = quantize_weights( - w, - wtype, - group_size, - zero_points=zero_points, - # to match how the kernel applies zps - ref_zero_points_after_scales=True) - - w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) - w_q = w_q.t().contiguous().t() # convert to col major - w_q_machete = ops.machete_prepack_B(w_q, wtype) - - opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype.id)) - - return w_ref, w_q_machete, w_s, w_zp - - -def machete_gemm_test_helper(a: torch.Tensor, b: torch.Tensor, - wtype: ScalarType, group_size: int, - zero_points: bool): - w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( - b, wtype, group_size, zero_points) - - output_ref = torch.matmul(a, w_ref) - - output = ops.machete_gemm( - a=a, - b_q=w_q_packed, - b_type=wtype, - b_scales=w_s, - b_zeros=maybe_convert_zeropoints(w_zp, w_s), - b_group_size=group_size, - ) - - # Relax atol as our reduction dim becomes larger (more rounding error) - # Relax atol when we have zeropoints since the way machete applies - # zeropoints (after scales) causes noise around 0 - atol = 1 if zero_points else min(5e-2 * math.sqrt(a.shape[1]), 1) - torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol) - - -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -@pytest.mark.parametrize("shape", - MNK_SHAPES, - ids=lambda x: "x".join(str(v) for v in x)) -@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x)) -@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS) -@pytest.mark.parametrize("group_size", [128, None]) -def test_machete_all_schedules(shape, atype: torch.dtype, - wtype_zeropoints: Tuple[ScalarType, bool], - group_size: Optional[int]): - m, n, k = shape - wtype, zero_points = wtype_zeropoints - - if group_size is not None and k % group_size != 0: - return - - print(f"MNK = {m} {n} {k}") - - # Normalize group_size - if group_size is None: - group_size = k - assert group_size <= k - - a = rand_data((m, k), atype) - w = rand_data((k, n), atype) - - w_ref, w_q_machete, w_s, w_zp = machete_quantize_and_pack( - w, wtype, group_size, zero_points) - - output_ref = torch.matmul(a, w_ref) - - for schedule in ops.machete_supported_schedules(wtype): - print(f"Testing schedule {schedule}") - output = ops.machete_gemm( - a, - b_q=w_q_machete, - b_type=wtype, - b_scales=w_s, - b_zeros=maybe_convert_zeropoints(w_zp, w_s), - b_group_size=group_size, - schedule=schedule, - ) - - opcheck( - torch.ops._C.machete_gemm, - (a, w_q_machete, wtype.id, w_s, maybe_convert_zeropoints( - w_zp, w_s), group_size, None, None, None, schedule)) - - # Relax atol as our reduction dim becomes larger (more rounding error) - # Relax atol when we have zeropoints since the way machete applies - # zeropoints (after scales) causes noise around 0 - atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1) - torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol),\ - f"Schedule failed {schedule}" - - -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -@pytest.mark.parametrize("shape", - MNK_SHAPES, - ids=lambda x: "x".join(str(v) for v in x)) -@pytest.mark.parametrize("atype", ACT_TYPES, ids=lambda x: str(x)) -@pytest.mark.parametrize("wtype_zeropoints", WTYPE_ZEROPOINTS) -@pytest.mark.parametrize("group_size", [128, None]) -def test_machete_heuristic(shape, atype: torch.dtype, - wtype_zeropoints: Tuple[ScalarType, bool], - group_size: Optional[int]): - m, n, k = shape - wtype, zero_points = wtype_zeropoints - - if group_size is not None and k % group_size != 0: - return - - # Normalize group_size - if group_size is None: - group_size = k - assert group_size <= k - - a = rand_data((m, k), atype) - b = rand_data((k, n), atype) - - machete_gemm_test_helper(a, b, wtype, group_size, zero_points) - - -# Test working on other devices -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_machete_devices(device: str): - m, n, k = 512, 4096, 4096 - wtype = scalar_types.uint4b8 - group_size = 128 - zero_points = False - - print(f"MNK = {m} {n} {k}, device = {device}") - - a = rand_data((m, k), torch.float16).to(device) - b = rand_data((k, n), torch.float16).to(device) - - machete_gemm_test_helper(a, b, wtype, group_size, zero_points) - - -# Test working with a subset of A and B -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -def test_machete_subset(): - big_m, big_n, big_k = 1024, 1024, 1024 - m, n, k = 512, 512, 512 - wtype = scalar_types.uint4b8 - group_size = 128 - zero_points = False - - whole_a = rand_data((big_m, big_k), torch.float16) - whole_b = rand_data((big_k, big_n), torch.float16) - - a = whole_a[0:m, 0:k] - b = whole_b[0:k, 0:n] - - machete_gemm_test_helper(a, b, wtype, group_size, zero_points) - - -# Test to make sure cuda graphs work -class MacheteLayer(torch.nn.Module): - - def __init__(self, **kwargs): - super().__init__() - self.kwargs = kwargs - - def forward(self, a): - return ops.machete_gemm(**self.kwargs) - - -@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, - reason="Machete is not supported on this GPU type.") -def test_machete_cuda_graph(): - m, n, k = 512, 4096, 4096 - - a = rand_data((m, k), torch.float16) - b = rand_data((k, n), torch.float16) - wtype = scalar_types.uint4b8 - group_size = 128 - zero_points = False - - w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( - b, wtype, group_size, zero_points) - - # Construct a trivial model with a single layer that calls a machete kernel - model = MacheteLayer( - a=a, - b_q=w_q_packed, - b_type=wtype, - b_scales=w_s, - b_zeros=maybe_convert_zeropoints(w_zp, w_s), - b_group_size=group_size, - ) - - output_ref = torch.matmul(a, w_ref) - - # Run the model with a cuda graph - stream = torch.cuda.Stream() - with torch.cuda.stream(stream): - g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g): - output = model(a) - output.zero_() - g.replay() - - # Relax atol as our reduction dim becomes larger (more rounding error) - # Relax atol when we have zeropoints since the way machete applies - # zeropoints (after scales) causes noise around 0 - atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1) - torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol) diff --git a/tests/kernels/test_machete_mm.py b/tests/kernels/test_machete_mm.py new file mode 100644 index 0000000000000..1c6eb2dd9a228 --- /dev/null +++ b/tests/kernels/test_machete_mm.py @@ -0,0 +1,406 @@ +"""Tests for the machete kernel. + +Run `pytest tests/kernels/test_machete_mm.py`. +""" + +import math +from dataclasses import dataclass, fields +from typing import List, Optional, Tuple + +import pytest +import torch + +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_rows, quantize_weights) +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. Currently the use of +# `is_quant_method_supported` conflates kernels with quantization methods +# an assumption which is breaking down as quantizations methods can have +# have kernels and some kernels support multiple quantization methods. +IS_SUPPORTED_BY_GPU = current_platform.get_device_capability()[0] >= 9 + +MNK_SHAPES = [ + (1, 128, 128), + (1, 512, 1024), + (1, 4096, 4096), + (1, 8192, 28672), + (13, 8192, 4096), + (26, 4096, 8192), + (64, 4096, 4096), + (64, 8192, 28672), + (257, 128, 4096), + (257, 4224, 4160), + (257, 4096, 4096), + (1024, 4096, 8192), + (1024, 8192, 4096), +] + +GROUP_SIZES_TO_TEST: List[Optional[int]] = [128, -1] + + +@dataclass +class TypeConfig: + act_type: torch.dtype + weight_type: ScalarType + output_type: Optional[torch.dtype] + group_scale_type: Optional[torch.dtype] + group_zero_type: Optional[torch.dtype] + channel_scale_type: Optional[torch.dtype] + token_scale_type: Optional[torch.dtype] + + +@dataclass +class Tensors: + w_ref: torch.Tensor + a_ref: torch.Tensor + a: torch.Tensor + w_q: torch.Tensor + w_g_s: Optional[torch.Tensor] + w_g_zp: Optional[torch.Tensor] + w_ch_s: Optional[torch.Tensor] + w_tok_s: Optional[torch.Tensor] + + +# (Act Type, Weight Type, Output Type, Scale Type, ZeroPoints, +# Ch Scales Type, Tok Scales Type) +# NOTE: None "Scale Type" means the act type is floating point +# None "Output Type" means the output type is the same as the act type +TestTypeTuple = Tuple[List[torch.dtype], ScalarType, Optional[torch.dtype], + Optional[torch.dtype], bool] +TEST_TYPES = [ + # GPTQ style + *(TypeConfig(act_type=a_type, + weight_type=w_type, + output_type=None, + group_scale_type=a_type, + group_zero_type=None, + channel_scale_type=None, + token_scale_type=None) + for w_type in [scalar_types.uint4b8, scalar_types.uint8b128] + for a_type in [torch.float16, torch.bfloat16]), + # AWQ style + *(TypeConfig(act_type=a_type, + weight_type=w_type, + output_type=None, + group_scale_type=a_type, + group_zero_type=a_type, + channel_scale_type=None, + token_scale_type=None) + for w_type in [scalar_types.uint4, scalar_types.uint8] + for a_type in [torch.float16, torch.bfloat16]), + # QQQ style + *(TypeConfig(act_type=torch.int8, + weight_type=scalar_types.uint4b8, + output_type=torch.float16, + group_scale_type=group_scale_type, + group_zero_type=None, + channel_scale_type=torch.float, + token_scale_type=torch.float) + for group_scale_type in [None, torch.float16]), + *(TypeConfig(act_type=torch.float8_e4m3fn, + weight_type=scalar_types.uint4b8, + output_type=torch.float16, + group_scale_type=group_scale_type, + group_zero_type=None, + channel_scale_type=torch.float, + token_scale_type=torch.float) + for group_scale_type in [None, torch.float16]), +] + +# TODO: in future PR refactor this and `is_quant_method_supported` in the kernel +# unit tests to a common utility function. Currently the use of +# `is_quant_method_supported` conflates kernels with quantization methods +# an assumption which is breaking down as quantizations methods can have +# have kernels and some kernels support multiple quantization methods. +IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90) + + +def rand_data(shape, dtype=torch.float16, scale=1, offset=0): + if dtype.is_floating_point: + return (scale * torch.rand(shape, device="cuda") - offset).to(dtype) + else: + return torch.randint(-8, 7, shape, dtype=dtype, device="cuda") + + +def maybe_convert_zeropoints(zps: Optional[torch.Tensor], s: torch.Tensor): + return zps if zps is None else -1 * s * (zps.to(s.dtype)) + + +def group_size_valid(shape: Tuple[int, int, int], + group_size: Optional[int]) -> bool: + return group_size is None or group_size == -1 or group_size % shape[2] == 0 + + +def machete_quantize_and_pack(atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: Optional[torch.dtype], + group_size: Optional[int], + zero_points: bool = False): + assert wtype.is_integer(), "TODO: support floating point weights" + + w_ref, w_q, w_s, w_zp = quantize_weights( + w, + wtype, + group_size=group_size, + zero_points=zero_points, + # to match how the kernel applies zps + ref_zero_points_after_scales=True) + + w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) + w_q = w_q.t().contiguous().t() # convert to col major + + w_q_machete = ops.machete_prepack_B(w_q, atype, wtype, stype) + opcheck(torch.ops._C.machete_prepack_B, (w_q, atype, wtype.id, stype)) + + return w_ref, w_q_machete, w_s, w_zp + + +def create_test_tensors(shape: Tuple[int, int, int], + types: TypeConfig, + group_size: Optional[int], + subset_stride_factor: Optional[int] = None) -> Tensors: + m, n, k = shape + factor = subset_stride_factor or 1 + + print("create_test_tensors, shape:", shape, "types:", types, "group_size:", + group_size) + + a = rand_data((m * factor, k * factor), types.act_type, scale=3, offset=2) + w = rand_data((k * factor, n * factor), types.act_type, scale=3, offset=1) + + if factor > 1: + a = a[0:m, 0:k] + w = w[0:k, 0:n] + + if types.group_scale_type is not None: + w = w.to(types.group_scale_type) + if w.dtype.itemsize == 1: + w = w.to(torch.float16) + + w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( + a.dtype, w, types.weight_type, types.group_scale_type, group_size, + types.group_zero_type is not None) + + if not a.dtype.is_floating_point: + aiinfo = torch.iinfo(a.dtype) + w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max) + + a_ref = a.to(torch.float32) + w_ref = w_ref.to(torch.float32) + + w_ch_s = None if types.channel_scale_type is None else\ + rand_data((n,), types.channel_scale_type) + w_tok_s = None if types.token_scale_type is None else\ + rand_data((m,), types.token_scale_type) + + return Tensors(w_ref=w_ref, + a_ref=a_ref, + a=a, + w_q=w_q_packed, + w_g_s=w_s, + w_g_zp=maybe_convert_zeropoints(w_zp, w_s), + w_ch_s=w_ch_s, + w_tok_s=w_tok_s) + + +# None stype means scales use the same dtype as a +def machete_mm_test_helper(types: TypeConfig, + tensors: Tensors, + group_size: Optional[int] = None, + schedule: Optional[str] = None): + output_ref = torch.matmul(tensors.a_ref, tensors.w_ref) + output_ref_type = output_ref.dtype + + if tensors.w_ch_s is not None: + output_ref = (output_ref.to(tensors.w_ch_s.dtype) * + tensors.w_ch_s.unsqueeze(0)).to(output_ref_type) + if tensors.w_tok_s is not None: + output_ref = (output_ref.to(tensors.w_tok_s.dtype) * + tensors.w_tok_s.unsqueeze(1)).to(output_ref_type) + + output = ops.machete_mm( + a=tensors.a, + b_q=tensors.w_q, + b_type=types.weight_type, + b_group_scales=tensors.w_g_s, + b_group_zeros=tensors.w_g_zp, + b_group_size=group_size, + b_channel_scales=tensors.w_ch_s, + a_token_scales=tensors.w_tok_s, + out_type=types.output_type, + schedule=schedule, + ) + + print(output) + print(output_ref) + + # Relax atol as our reduction dim becomes larger (more rounding error) + # Relax atol when we have zeropoints since the way machete applies + # zeropoints (after scales) causes noise around 0 + atol = 1 if tensors.w_g_zp is not None\ + else min(5e-2 * math.sqrt(tensors.a.shape[1]), 1) + rtol = 1e-1 if tensors.a.element_size() >= 2 else 2e-1 + torch.testing.assert_close(output, + output_ref.to(output.dtype), + rtol=rtol, + atol=atol) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +@pytest.mark.parametrize("shape", + MNK_SHAPES, + ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.parametrize("types", TEST_TYPES) +def test_machete_all_schedules(shape, types: TypeConfig): + + group_sizes: List[Optional[int]] = [] + if types.group_scale_type is None: + group_sizes = [None] + else: + group_sizes = GROUP_SIZES_TO_TEST + + for group_size in group_sizes: + if not group_size_valid(shape, group_size): + continue + + tensors = create_test_tensors(shape, types, group_size) + print(f"MNK = {shape}") + for schedule in ops.machete_supported_schedules( + types.act_type, + types.weight_type, + group_scales_type=types.group_scale_type, + group_zeros_type=types.group_scale_type, + out_type=types.output_type): + print(f"Testing schedule {schedule}") + machete_mm_test_helper(types, tensors, group_size, schedule) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +@pytest.mark.parametrize("shape", + MNK_SHAPES, + ids=lambda x: "x".join(str(v) for v in x)) +@pytest.mark.parametrize("types", TEST_TYPES) +def test_machete_heuristic(shape, types: TypeConfig): + group_sizes: List[Optional[int]] = [] + if types.group_scale_type is None: + group_sizes = [None] + else: + group_sizes = GROUP_SIZES_TO_TEST + + for group_size in group_sizes: + if not group_size_valid(shape, group_size): + continue + + tensors = create_test_tensors(shape, types, group_size) + machete_mm_test_helper(types, tensors, group_size) + + +# Test working on other devices +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_machete_devices(device: str): + group_size = 128 + + type_config = TypeConfig(act_type=torch.float16, + weight_type=scalar_types.uint4b8, + output_type=None, + group_scale_type=torch.float16, + group_zero_type=None, + channel_scale_type=None, + token_scale_type=None) + + tensors = create_test_tensors((512, 4096, 4096), type_config, group_size) + + for field in fields(Tensors): + tensor = getattr(tensors, field.name) + if isinstance(tensor, torch.Tensor): + setattr(tensors, field.name, tensor.to(device)) + + machete_mm_test_helper(type_config, tensors, group_size) + + +# Test working with a subset of A and B +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +def test_machete_subset(): + group_size = 128 + + type_config = TypeConfig(act_type=torch.float16, + weight_type=scalar_types.uint4b8, + output_type=None, + group_scale_type=torch.float16, + group_zero_type=None, + channel_scale_type=None, + token_scale_type=None) + + tensors = create_test_tensors((512, 4096, 4096), + type_config, + group_size, + subset_stride_factor=2) + machete_mm_test_helper(type_config, tensors, group_size) + + +# Test to make sure cuda graphs work +class MacheteLayer(torch.nn.Module): + + def __init__(self, **kwargs): + super().__init__() + self.kwargs = kwargs + + def forward(self, a): + return ops.machete_mm(a=a, **self.kwargs) + + +@pytest.mark.skipif(not IS_SUPPORTED_BY_GPU, + reason="Machete is not supported on this GPU type.") +def test_machete_cuda_graph(): + m, n, k = 512, 4096, 4096 + + a = rand_data((m, k), torch.float16) + b = rand_data((k, n), torch.float16) + wtype = scalar_types.uint4b8 + stype = torch.float16 + group_size = 128 + zero_points = False + + w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack( + a.dtype, b, wtype, stype, group_size, zero_points) + + # Construct a trivial model with a single layer that calls a machete kernel + model = MacheteLayer( + b_q=w_q_packed, + b_type=wtype, + b_group_scales=w_s, + b_group_zeros=maybe_convert_zeropoints(w_zp, w_s), + b_group_size=group_size, + ) + + output_ref = torch.matmul(a, w_ref) + + # Run the model with a cuda graph + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + output = model(a) + output.zero_() + g.replay() + + # Relax atol as our reduction dim becomes larger (more rounding error) + # Relax atol when we have zeropoints since the way machete applies + # zeropoints (after scales) causes noise around 0 + atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1) + torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index b276b8fc25473..aa89010ca8ecd 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -444,18 +444,18 @@ def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, size_k: torch.SymInt) -> torch.Tensor: return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - @register_fake("_C::machete_gemm") - def machete_gemm_fake( + @register_fake("_C::machete_mm") + def machete_mm_fake( a: torch.Tensor, - # Should be the tensor returned by machete_prepack_B + # b_q Should be the tensor returned by machete_prepack_B b_q: torch.Tensor, b_type: ScalarType, - b_scales: Optional[torch.Tensor] = None, - b_zeros: Optional[torch.Tensor] = None, + out_type: Optional[torch.dtype] = None, + b_group_scales: Optional[torch.Tensor] = None, + b_group_zeros: Optional[torch.Tensor] = None, b_group_size: Optional[int] = None, - c: Optional[torch.Tensor] = None, - alpha: Optional[float] = None, - beta: Optional[float] = None, + b_channel_scales: Optional[torch.Tensor] = None, + a_token_scales: Optional[torch.Tensor] = None, schedule: Optional[str] = None, ) -> torch.Tensor: m = a.size(0) @@ -463,8 +463,9 @@ def machete_gemm_fake( return torch.empty((m, n), device=a.device, dtype=a.dtype) @register_fake("_C::machete_prepack_B") - def machete_prepack_B_fake(b_q_weight: torch.Tensor, - b_type: ScalarType) -> torch.Tensor: + def machete_prepack_B_fake( + b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType, + group_scales_type: Optional[torch.dtype]) -> torch.Tensor: return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) @@ -617,29 +618,41 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, # machete -def machete_supported_schedules(b_type: ScalarType) -> List[str]: - return torch.ops._C.machete_supported_schedules(b_type.id) - - -def machete_gemm( - a: torch.Tensor, - b_q: torch.Tensor, # Should be the tensor returned by machete_prepack_B - b_type: ScalarType, - b_scales: Optional[torch.Tensor] = None, - b_zeros: Optional[torch.Tensor] = None, - b_group_size: Optional[int] = None, - c: Optional[torch.Tensor] = None, - alpha: Optional[float] = None, - beta: Optional[float] = None, - schedule: Optional[str] = None, -) -> torch.Tensor: - return torch.ops._C.machete_gemm(a, b_q, b_type.id, b_scales, b_zeros, - b_group_size, c, alpha, beta, schedule) +def machete_supported_schedules( + a_type: torch.dtype, + b_type: ScalarType, + group_scales_type: Optional[torch.dtype], + group_zeros_type: Optional[torch.dtype] = None, + channel_scales_type: Optional[torch.dtype] = None, + token_scales_type: Optional[torch.dtype] = None, + out_type: Optional[torch.dtype] = None) -> List[str]: + return torch.ops._C.machete_supported_schedules( + a_type, b_type.id, group_scales_type, group_zeros_type, + channel_scales_type, token_scales_type, out_type) -def machete_prepack_B(b_q_weight: torch.Tensor, - b_type: ScalarType) -> torch.Tensor: - return torch.ops._C.machete_prepack_B(b_q_weight, b_type.id) +def machete_mm( + a: torch.Tensor, + # b_q Should be the tensor returned by machete_prepack_B + b_q: torch.Tensor, + b_type: ScalarType, + out_type: Optional[torch.dtype] = None, + b_group_scales: Optional[torch.Tensor] = None, + b_group_zeros: Optional[torch.Tensor] = None, + b_group_size: Optional[int] = None, + b_channel_scales: Optional[torch.Tensor] = None, + a_token_scales: Optional[torch.Tensor] = None, + schedule: Optional[str] = None) -> torch.Tensor: + return torch.ops._C.machete_mm(a, b_q, b_type.id, out_type, b_group_scales, + b_group_zeros, b_group_size, + b_channel_scales, a_token_scales, schedule) + + +def machete_prepack_B( + b_q_weight: torch.Tensor, a_type: torch.dtype, b_type: ScalarType, + group_scales_type: Optional[torch.dtype]) -> torch.Tensor: + return torch.ops._C.machete_prepack_B(b_q_weight, a_type, b_type.id, + group_scales_type) if hasattr(torch.ops._C, "permute_cols"): diff --git a/vllm/model_executor/layers/quantization/kernels/machete.py b/vllm/model_executor/layers/quantization/kernels/machete.py index e5696d08f30f5..15df0200f30b5 100644 --- a/vllm/model_executor/layers/quantization/kernels/machete.py +++ b/vllm/model_executor/layers/quantization/kernels/machete.py @@ -79,7 +79,9 @@ def transform_w_q(x): c.weight_type, packed_dim=0) x.data = ops.machete_prepack_B(x.data.t().contiguous().t(), - self.config.weight_type) + a_type=c.act_type, + b_type=c.weight_type, + group_scales_type=c.act_type) return x def transform_w_s(x): @@ -105,12 +107,12 @@ def apply_weights(self, if c.has_g_idx: x_2d = self.act_perm(x_2d) - output = ops.machete_gemm(a=x_2d, - b_q=w_q, - b_type=c.weight_type, - b_zeros=None, - b_scales=w_s, - b_group_size=c.group_size) + output = ops.machete_mm(a=x_2d, + b_q=w_q, + b_type=c.weight_type, + b_group_zeros=None, + b_group_scales=w_s, + b_group_size=c.group_size) if bias is not None: output.add_(bias) # In-place add diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index c217f5ca620a1..83055d6000d83 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -126,11 +126,14 @@ def permute_rows(q_w: torch.Tensor, def quantize_weights(w: torch.Tensor, quant_type: ScalarType, - group_size: int, + group_size: Optional[int], zero_points: bool = False, ref_zero_points_after_scales: bool = False): assert quant_type.is_integer(), \ "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, \ + "to have group zero points, group_size must be provided "\ + "(-1 group_size is channelwise)" orig_device = w.device orig_type = w.dtype @@ -140,10 +143,9 @@ def quantize_weights(w: torch.Tensor, if group_size == -1: group_size = size_k - assert group_size <= size_k # Reshape to [groupsize, -1] - if group_size < size_k: + if group_size is not None and group_size < size_k: w = w.reshape((-1, group_size, size_n)) w = w.permute(1, 0, 2) w = w.reshape((group_size, -1)) @@ -155,18 +157,20 @@ def quantize_weights(w: torch.Tensor, max_q_val = quant_type.max() min_q_val = quant_type.min() - if zero_points: - assert not quant_type.is_signed() and quant_type.max() > 0 - w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() - maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \ - .clamp(min_q_val, max_q_val).int() - else: - # If the bias is such that there are no possible negative/positive - # values, set the max value to inf to avoid divide by 0 - w_s = torch.max( - abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), - abs(min_val / (min_q_val if min_q_val != 0 else torch.inf))) - maybe_w_zp = None + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \ + .clamp(min_q_val, max_q_val).int() + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf))) # Quantize w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) @@ -176,7 +180,7 @@ def quantize_weights(w: torch.Tensor, # For some kernels (namely Machete) the zero-points are applied after the # scales are applied, for this case computing the reference in similar way # allows us to use tighter error tolerances in our unit tests. - if ref_zero_points_after_scales and zero_points: + if ref_zero_points_after_scales and maybe_w_zp is not None: w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s else: w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s @@ -185,7 +189,7 @@ def quantize_weights(w: torch.Tensor, w_q += quant_type.bias # Restore original shapes - if group_size < size_k: + if group_size is not None and group_size < size_k: def reshape_w(w): w = w.reshape((group_size, -1, size_n)) @@ -195,17 +199,16 @@ def reshape_w(w): w_q = reshape_w(w_q) w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() - w_s = w_s.reshape((-1, size_n)).contiguous() - - if zero_points: + if maybe_w_zp is not None: maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() maybe_w_zp = maybe_w_zp.to(device=orig_device) return ( w_ref.to(device=orig_device), w_q.to(device=orig_device), - w_s.to(device=orig_device), + w_s if group_size is not None else None, maybe_w_zp, )